OpenDAS / FAST-RNNT

Commit d6081b04, authored Jul 08, 2021 by Daniel Povey
parent c80ebba6

Test CPU derivative code
Showing 2 changed files with 92 additions and 29 deletions (+92 -29):

    torch_learned_nonlin/learned_nonlin_cpu.cpp    +49 -29
    torch_learned_nonlin/learned_nonlin_test.py    +43 -0
torch_learned_nonlin/learned_nonlin_cpu.cpp
@@ -60,8 +60,7 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
           // so in a sense -K and +K are not special, but we include those
           // extra values as an easy way to handle the semi-infinite regions
           // that are < -(K-1) and > (K-1)
-          scalar_t x = input_a[b][c][t] * inv_scale + K,
-              y;
+          scalar_t x = input_a[b][c][t] * inv_scale + K;
           int min = 0, diff = K;
           while (diff > 0) {
             int mid = min + diff;
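The min/diff loop shared by both forward hunks is a short binary search for the linear piece that contains x. A rough Python equivalent is sketched below; the comparison inside the loop is elided by the diff, so the test shown here is an assumption about what the hidden lines do, not a quote of them:

    def find_piece(x, K):
        # Returns min with 0 <= min < 2*K, the index of the piece containing x
        # (x is already shifted so that the pieces cover [0, 2*K)).
        min_, diff = 0, K
        while diff > 0:
            mid = min_ + diff
            if x >= mid:          # assumed comparison; not shown in the hunk
                min_ = mid
            diff >>= 1
        return min_

For K a power of two (the new test below only uses K = 2 ** random.randrange(0, 4)), this returns floor(x) clamped to the range [0, 2*K - 1].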
@@ -70,7 +69,7 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
             diff = diff >> 1;
           }
           // OK, at this point, 0 <= min < 2*K.
-          y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
+          scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
           // printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f+ %f\n", x, y, min,
           // x, min, params_a[c][min + 1], y_vals_a[c][min - 1]);
           output_a[b][c][t] = y * scale;
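Taken together, the two forward hunks compute the following per element, writing l = params[c][0] for the log-scale (that reading comes from the backward code further down, which treats scale = exp(l) and inv_scale = exp(-l)) and m for the piece index found by the binary search:

$$x' = \text{input}\cdot e^{-l} + K, \qquad y = (x' - m)\,\text{params}[c][m+1] + \text{y\_vals}[c][m], \qquad \text{output} = y\cdot e^{l}, \qquad 0 \le m < 2K.$$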
@@ -116,18 +115,21 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
   AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_backward_cpu_loop", ([&] {
       auto params_a = params.accessor<scalar_t, 2>(),
-          params_grad_a = params.accessor<scalar_t, 2>(),
+          params_grad_a = params_grad.accessor<scalar_t, 2>(),
           y_vals_a = y_vals.accessor<scalar_t, 2>(),
-          y_vals_grad_a = y_vals.accessor<scalar_t, 2>();
+          y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
       for (int c = 0; c < C; c++) {
         scalar_t sum_negative = 0.0,
             sum_positive = 0.0;
         for (int i = 0; i < K; i++) {
-          y_vals_a[c][K - 1 + i] = sum_positive;
-          y_vals_a[c][K - 1 - i] = sum_negative;
+          y_vals_a[c][K + i] = sum_positive;
+          y_vals_a[c][K - i] = sum_negative;
           sum_positive += params_a[c][1 + K + i];
           sum_negative -= params_a[c][K - i];
         }
+        // the reference point for the lowest, half-infinite interval (the one
+        // starting at x=-(K-1) is still x=-(K-1); this value is repeated in y_vals.
+        y_vals_a[c][0] = y_vals_a[c][1];
       }
       auto input_a = input.accessor<scalar_t, 3>(),
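The loop above precomputes, per channel, the offset of each linear piece as a running sum of the slopes. A rough NumPy transliteration follows; it is illustrative only, the helper name is made up, and params_c stands for one row of params laid out as in the C++ (entry 0 the log-scale, entries 1..2K the slopes):

    import numpy as np

    def make_y_vals(params_c, K):
        y_vals = np.zeros(2 * K)
        sum_positive = sum_negative = 0.0
        for i in range(K):
            y_vals[K + i] = sum_positive   # offset of piece K+i (right of center)
            y_vals[K - i] = sum_negative   # offset of piece K-i (left of center)
            sum_positive += params_c[1 + K + i]
            sum_negative -= params_c[K - i]
        # the lowest, half-infinite piece reuses the offset of piece 1,
        # mirroring the "y_vals_a[c][0] = y_vals_a[c][1];" line added above
        y_vals[0] = y_vals[1]
        return y_vals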
@@ -147,10 +149,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
           // so in a sense -K and +K are not special, but we include those
           // extra values as an easy way to handle the semi-infinite regions
           // that are < -(K-1) and > (K-1)
-          scalar_t x = input_a[b][c][t] * inv_scale + K,
-              output_grad = output_grad_a[b][c][t],
-              x_grad,
-              y;
+          scalar_t input = input_a[b][c][t],
+              x = input * inv_scale + K,
+              output_grad = output_grad_a[b][c][t];
           int min = 0, diff = K;
           while (diff > 0) {
             int mid = min + diff;
@@ -159,26 +160,45 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
             diff = diff >> 1;
           }
           // OK, at this point, 0 <= min < 2*K.
-          // The "+ 1" is to get (input_a[b][c][t] * inv_scale) - (-(K+1))
-          if (min == 0) {
-            y = (x + 1) * params_a[c][1] + y_vals_a[c][0];
-            // output_a[b][c][t] = y * scale;
+          scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
+          // backprop for: output_a[b][c][t] = y * scale;
           scale_grad += y * output_grad;
           scalar_t y_grad = scale * output_grad;
-            x_grad = y_grad * params_a[c][1];
-            //y_vals_grad_a[c][0] +=
-          } else {
-            y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min - 1];
-            // printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f+ %f\n", x, y, min,
-            // x, min, params_a[c][min + 1], y_vals_a[c][min - 1]);
-          }
-          //output_a[b][c][t] = y * scale;
+          // backprop for:
+          // scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
+          scalar_t x_grad = y_grad * params_a[c][min + 1];
+          params_grad_a[c][min + 1] += y_grad * (x - (scalar_t)min);
+          y_vals_grad_a[c][min] += y_grad;
+          // backprop for: x = input * inv_scale + K,
+          inv_scale_grad += x_grad * input;
+          input_grad_a[b][c][t] = x_grad * inv_scale;
         }
+        // Do the backprop to l as if we had done:
+        // scale = exp(l); inv_scale = exp(-l);
+        scalar_t l_grad = scale * scale_grad - inv_scale * inv_scale_grad;
+        params_grad_a[c][0] += l_grad;
       }
-      //return output;
-    }));
-  //
-  return std::vector<torch::Tensor>({
-      grad_input, grad_pos_add, grad_pos_mul});
+      // Now do the backprop for the loop above where we set y_vals_a.
+      for (int c = 0; c < C; c++) {
+        // backprop for: y_vals_a[c][0] = y_vals_a[c][1];
+        y_vals_grad_a[c][1] += y_vals_grad_a[c][0];
+        scalar_t sum_negative_grad = 0.0, sum_positive_grad = 0.0;
+        for (int i = K - 1; i >= 0; i--) {
+          // backprop for: sum_negative -= params_a[c][K - i];
+          params_grad_a[c][K - i] -= sum_negative_grad;
+          // backprop for: sum_positive += params_a[c][1 + K + i];
+          params_grad_a[c][1 + K + i] += sum_positive_grad;
+          // backprop for: y_vals_a[c][K - i] = sum_negative;
+          sum_negative_grad += y_vals_grad_a[c][K - i];
+          // backprop for: y_vals_a[c][K + i] = sum_positive;
+          sum_positive_grad += y_vals_grad_a[c][K + i];
+        }
+      }
     }));
+  return std::vector<torch::Tensor>({input_grad, params_grad});
 }
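The params_grad_a[c][0] += l_grad; line in the last hunk treats params[c][0] as a log-scale l with scale = exp(l) and inv_scale = exp(-l), as the comment states, so the expression for l_grad is just the chain rule through both factors:

$$\frac{\partial \mathcal{L}}{\partial l}
 = \text{scale\_grad}\cdot\frac{\partial\, e^{l}}{\partial l}
 + \text{inv\_scale\_grad}\cdot\frac{\partial\, e^{-l}}{\partial l}
 = \text{scale}\cdot\text{scale\_grad} - \text{inv\_scale}\cdot\text{inv\_scale\_grad}.$$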
torch_learned_nonlin/learned_nonlin_test.py
@@ -14,12 +14,54 @@ def test_learned_nonlin_basic():
     K = 4
     N = K * 2
     params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1)
+    x.requires_grad = True
+    params.requires_grad = True
     print("x = ", x)
     print("params = ", params)
     print("x.shape = ", x.shape)
     y = learned_nonlin(x, params, dim = 1)
     print("y = ", y)
+    y.sum().backward()
+    print("x.grad = ", x.grad)
+    print("params.grad = ", params.grad)
+    # Just eyeballing the above tgo make sure it looks reasonable.
+
+
+def test_learned_nonlin_deriv():
+    """ Tests derivatives in randomized way """
+    for _ in range(10):
+        for dtype in [torch.float32, torch.float64]:
+            B = random.randrange(1, 10)
+            C = random.randrange(1, 10)
+            T = random.randrange(1, 20)
+            x = torch.randn(B, C, T, dtype=dtype)
+
+            K = 2 ** random.randrange(0, 4)
+            N = K * 2
+            params = torch.randn(C, N + 1, dtype=dtype)
+            x.requires_grad = True
+            params.requires_grad = True
+            print(f"B,C,T,K = {B},{C},{T},{K}")
+            y = learned_nonlin(x, params, dim = 1)
+
+            y_deriv = torch.rand_like(y)
+            y.backward(gradient=y_deriv)
+
+            delta = 1.0e-04
+            delta_x = torch.randn_like(x) * delta
+            pred_change = (x.grad * delta_x).sum()
+            observed_change = (y_deriv * (learned_nonlin(x + delta_x, params, dim = 1) - y)).sum()
+            print(f"for input: pred_change = {pred_change}, observed_change={observed_change}")
+            assert torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05)
+
+            delta_params = torch.randn_like(params) * delta
+            pred_change = (params.grad * delta_params).sum()
+            observed_change = (y_deriv * (learned_nonlin(x, params + delta_params, dim = 1) - y)).sum()
+            print(f"for params: pred_change = {pred_change}, observed_change={observed_change}")
+            assert torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05)
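test_learned_nonlin_deriv checks the backward pass with a first-order finite-difference probe: for a small random perturbation delta_x, the change predicted from the autograd gradient, (x.grad * delta_x).sum(), should match the observed change in the weighted output. A minimal self-contained sketch of the same pattern, using torch.sin as a stand-in for learned_nonlin purely for illustration:

    import torch

    x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    y = torch.sin(x)                  # stand-in for learned_nonlin(x, params, dim=1)
    y_deriv = torch.rand_like(y)      # random weighting of the outputs
    y.backward(gradient=y_deriv)      # x.grad = d((y * y_deriv).sum()) / dx

    delta = 1.0e-04
    delta_x = torch.randn_like(x) * delta
    pred_change = (x.grad * delta_x).sum()
    observed_change = (y_deriv * (torch.sin(x + delta_x) - y)).sum()
    assert torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05)

The rtol/atol values mirror the ones in the test; with a perturbation this small, the first-order approximation error for a smooth function should sit well inside that tolerance.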
@@ -225,6 +267,7 @@ def test_learned_nonlin_rand_grad():
 if __name__ == "__main__":
     test_learned_nonlin_basic()
+    test_learned_nonlin_deriv()
     if False:
         test_learned_nonlin_rand_grad()
         test_learned_nonlin_zeros()