OpenDAS / FAST-RNNT

Commit a4466946, authored Jul 16, 2021 by Daniel Povey
Parent: c3e61bea

    Deriv working in one case at least..
Showing 2 changed files with 17 additions and 24 deletions:

    torch_learned_nonlin/learned_nonlin_cuda_kernel.cu   (+16, -23)
    torch_learned_nonlin/learned_nonlin_test.py           (+1, -1)
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu
@@ -361,38 +361,32 @@ void learned_nonlin_backward_kernel(
     // will be set to zero for excess threads, and thus won't contribute to
     // this_params_grad or this_y_vals_grad.
     for (int t_offset = 0; t_offset < T; t_offset += THREADS_PER_BLOCK) {
       // The following is equivalent to:
       //   int t = (threadIdx.x % T_inc) + t_offset;
       // given that T_inc is a power of 2 and t_offset >= THREADS_PER_BLOCK >= T_inc.
       int t = (threadIdx.x & (T_inc - 1)) | t_offset;
-      scalar_t this_output_grad = 0.0;
-      if (t < T) this_output_grad = output_grad[b][c][t];
-      // The reason we use t % T here rather than only invoking this in some
-      // threads, is so that the un-needed threads will have a similar
-      // distribution over 'n' to the needed threads, which will hopefully avoid
-      // excessive work for some particular 'n' value if too many x values had
-      // the same 'n'.  It might be better to set n to an invalid value for
-      // out-of-range threads, but as it is, if we are to properly handle
-      // N==16 we don't have enough bits available in `src_indexes` to do this.
-      scalar_t this_input = input[b][c][t % T] * inv_scale + K;
-      input_buf[threadIdx.x] = this_input;
-      output_grad_buf[threadIdx.x] = this_output_grad;
-      scalar_t x = this_input;
+      scalar_t this_input = 0.0, this_output_grad;
+      if (t < T) {
+        this_output_grad = output_grad[b][c][t];
+        this_input = input[b][c][t];
+        input_buf[threadIdx.x] = this_input;
+        output_grad_buf[threadIdx.x] = this_output_grad;
+      }
+      scalar_t x = this_input * inv_scale + K;
       if (x < 0) x = 0;
       else if (x >= N) x = N - 1;
-      // C++ rounds toward zero.
-      int n = (int) x;
-      n_buf[threadIdx.x] = (char) n;   // 0 <= n < N
       // The forward code did:
       //   output[b][c][t] = this_input * params_buf[n] + y_vals[n];
       // We get the derivative for params and y_vals later.
-      if (t < T)
+      if (t < T) {
+        int n = (int) x;  // C++ rounds toward zero.
+        n_buf[threadIdx.x] = (char) n;
         input_grad[b][c][t] = this_output_grad * params_buf[n];
+      } else {
+        n_buf[threadIdx.x] = 255;
+      }
       int this_block_start = threadIdx.x & ~(N-1),  // == N * (threadIdx.x / N),
                                                     // since N is power of 2
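A note on the index computation kept at the top of this loop: it relies on two power-of-two facts, namely that `i & (T_inc - 1)` equals `i % T_inc`, and that OR-ing the result into `t_offset` equals adding it, because `t_offset` is a multiple of THREADS_PER_BLOCK and so its low bits are zero. A minimal host-side sketch of the equivalence; the constants T_INC and TPB are illustrative stand-ins for the kernel's T_inc and THREADS_PER_BLOCK:

    #include <cassert>

    int main() {
      const int T_INC = 8, TPB = 32;  // stand-ins; both powers of 2, TPB >= T_INC
      for (int t_offset = 0; t_offset < 256; t_offset += TPB) {
        // t_offset is a multiple of TPB, so its low log2(T_INC) bits are zero
        // and the bitwise OR below cannot carry: OR behaves like addition.
        for (int i = 0; i < TPB; ++i)
          assert(((i & (T_INC - 1)) | t_offset) == (i % T_INC) + t_offset);
      }
      return 0;
    }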
@@ -465,9 +459,8 @@ void learned_nonlin_backward_kernel(
       }
       // TODO: remove the next lines
-      assert(n_buf[threadIdx.x] == 0);
+      assert(n_buf[threadIdx.x] == 0 ||
+             (unsigned char)n_buf[threadIdx.x] == 255);
       output_grad_buf[threadIdx.x] = 0.0;
     }
   }
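One detail behind the updated assert: `n_buf` holds `char`, and whether plain `char` is signed is implementation-defined, so storing the sentinel 255 typically yields -1 on signed-char platforms. That is why the new check casts through `unsigned char` before comparing with 255. A tiny standalone illustration, not part of the kernel:

    #include <cassert>

    int main() {
      char sentinel = (char)255;  // on signed-char platforms this holds -1
      // Comparing the raw char with 255 would fail there; converting back to
      // unsigned char is well-defined (modulo 256) and recovers 255.
      assert((unsigned char)sentinel == 255);
      return 0;
    }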
torch_learned_nonlin/learned_nonlin_test.py
@@ -90,7 +90,7 @@ def test_learned_nonlin_deriv():
     y2 = learned_nonlin(x + delta_x, params, dim = 1)
     observed_change = (y_deriv * (y2 - y)).sum()
     print(f"for input: pred_change = {pred_change}, observed_change={observed_change}")
-    if not torch.allclose(pred_change, observed_change, rtol=2.0e-02, atol=1.0e-05):
+    if not torch.allclose(pred_change, observed_change, rtol=2.0e-02, atol=3.0e-05):
         print(f"For changed input, output differs too much: params={params}, input={x}, mod_input={x + delta_x}, y={y}, y2={y2}, diff={y2 - y}")
         assert 0
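The test change loosens only the absolute tolerance of this finite-difference derivative check, from 1.0e-05 to 3.0e-05. Per the PyTorch documentation, torch.allclose(a, b, rtol, atol) tests elementwise |a - b| <= atol + rtol * |b|, so when the predicted and observed changes are near zero the atol term dominates. A small self-contained sketch of that predicate, with made-up values chosen purely to show the effect of the new tolerance:

    #include <cmath>
    #include <cstdio>

    // Same elementwise criterion torch.allclose uses: |a - b| <= atol + rtol * |b|.
    static bool close(double a, double b, double rtol, double atol) {
      return std::fabs(a - b) <= atol + rtol * std::fabs(b);
    }

    int main() {
      // Hypothetical near-zero predicted vs. observed changes.
      double pred = 1.0e-05, observed = 3.5e-05;
      std::printf("atol=1e-5: %d\n", close(pred, observed, 2.0e-02, 1.0e-05));  // 0: fails
      std::printf("atol=3e-5: %d\n", close(pred, observed, 2.0e-02, 3.0e-05));  // 1: passes
      return 0;
    }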