OpenDAS / FAST-RNNT · Commits

Commit f4081496, authored Jul 15, 2021 by Daniel Povey
parent 6a77cb45

    Work on backward kernel

Showing 1 changed file with 76 additions and 58 deletions:
  torch_learned_nonlin/learned_nonlin_cuda_kernel.cu  (+76, -58)
...
@@ -180,7 +180,6 @@ void learned_nonlin_kernel(
        output[b][c][t] = (x - n) * params_buf[n] + y_vals[n];
      }
    }
  }
...
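The unchanged forward line above is the heart of the piecewise-linear nonlinearity: `n` selects the linear piece, `(x - n)` is the position within that piece, `params_buf[n]` is the piece's (scaled) slope and `y_vals[n]` is the function value at the piece's left edge. The following is a minimal CPU-style sketch of that interpolation step only; how the kernel derives `x` and `n` from the raw input is not shown in this excerpt, so the mapping below is an assumption for illustration:

#include <algorithm>
#include <cmath>

// Hypothetical sketch (not the kernel's code): given an input x already mapped
// into [0, 2K), pick the linear piece n and interpolate from the breakpoint
// value y_vals[n] with slope slope[n].
float piecewise_linear_forward(float x, int K,
                               const float *slope,    // [2K], assumed pre-scaled
                               const float *y_vals) { // [2K], values at piece edges
  int n = std::min(std::max((int)std::floor(x), 0), 2 * K - 1);
  return (x - n) * slope[n] + y_vals[n];
}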
@@ -297,7 +296,9 @@ void learned_nonlin_backward_kernel(
                                 // spaces between here and
                                 // `params_buf` for storing scale
                                 // and inv_scale and l == params[c][0].
      *params_buf = (scalar_t*) y_vals + 3 + N;  // [N]. Caution: contains params[c][1] through params[c][N].
      *params_buf = (scalar_t*) y_vals + 3 + N;  // [N]. Contains parameters (not times scale!)
                                                 // Caution: contains params[c][1] through params[c][N],
                                                 // i.e. numbering is off by 1 versus params.
                                                 // params_buf[-1] contains params[c][0] == log of scale;
                                                 // params_buf[-2] and params_buf[-3] contain scale and inv_scale.
...
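The comments above describe how one flat scratch buffer is carved up: `params_buf` points 3 + N elements past `y_vals`, leaving three slots below it for inv_scale, scale and the log-scale parameter, and shifting the parameter indexing by one relative to `params`. A plain-C++ sketch of that layout, under the assumption that `y_vals` occupies N scalars (the exact extent is not shown in this excerpt):

#include <cmath>

// Illustrative sketch (not the kernel's code) of the scratch layout described
// above: y_vals (N scalars), then inv_scale / scale / log-scale, then
// params[c][1..N] with indexing shifted by one.
void fill_params_buf_sketch(float *scratch, const float *params_c, int N) {
  float *y_vals     = scratch;           // [N]  breakpoint values, filled elsewhere
  float *params_buf = y_vals + 3 + N;    // so params_buf[-3..-1] are the extra slots
  float scale = std::exp(params_c[0]);   // params[c][0] stores the log of the scale
  params_buf[-1] = params_c[0];          // log of scale
  params_buf[-2] = scale;
  params_buf[-3] = 1.0f / scale;
  for (int n = 0; n < N; n++)
    params_buf[n] = params_c[n + 1];     // numbering off by 1 versus params
}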
@@ -312,58 +313,56 @@ void learned_nonlin_backward_kernel(
  // determines which piece of the piecewise
  // linear function we are in.
  // this_params_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
  // linear interval) corresponding to n == threadIdx.x % N.  For example, if
  // threadIdx.x == 0, this thread's gradient corresponds to the left-most
  // linear interval.
  scalar_t this_params_grad = 0.0,
      this_y_vals_grad = 0.0;
  // Load parameters
  if (threadIdx.x <= N)
    params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
  __syncthreads();
  // The easiest way to understand this code is to compare it with the CPU code
  // in learned_nonlin_cpu.cpp.
  // This next block computes `y_vals`.
  if ((((int)threadIdx.x & ~(int)32)) == 0) {
    // threadIdx.x == 0 or 32.  These are in separate warps so we can
    // allow them to do separate jobs.  This code takes linear time in K which
    // is not at all ideal and could be improved if K is largish, but it shouldn't
    // dominate the total time taken if we are processing a lot of data;
    // and anyway, we doubt that K will need to be more than 4 or 8 or so,
    // so the potential savings are quite small.
    if (threadIdx.x == 0) {
    scalar_t scale = exp(params_buf[-1]),
        inv_scale = 1.0 / scale;
    params_buf[-2] = scale;  // both threads write these but it's OK, it's the
                             // same value.
    params_buf[-2] = scale;
    params_buf[-3] = inv_scale;
    int sign, Koffset;
    // Koffset == K for threads handling sum_positive and K - 1
    // for threads handling sum_negative, see
    // learned_nonlin_cpu.cpp for reference code.  This would be K
    // + 1 and K respectively, except our params_buf has its index
    // shifted by one versus params.
    if (threadIdx.x == 0) {  // sum_positive
      sign = 1;
      Koffset = K;
    } else {  // threadIdx.x == 32.  sum_negative.
      scale *= -1;  // this is a local variable..
      sign = -1;
      Koffset = K - 1;
    }
    __syncthreads();
    scalar_t scale = params_buf[-2];
    // The easiest way to understand this code is to compare it with the CPU code
    // in learned_nonlin_cpu.cpp.
    if (threadIdx.x == 0) {
      scalar_t sum_positive = 0.0;
      for (int i = 0; i < K; i++) {
        y_vals[K + i] = sum_positive;
        // versus the CPU code, the params_buf is indexed off by 1; and it already
        // contains the factor "scale".
        sum_positive += params_buf[K + i] * scale;
      }
      scalar_t sum = 0.0;
    } else if (threadIdx.x == 64) {
      scalar_t sum_negative = 0.0;
      for (int i = 0; i < K; i++) {
        int isign = i * sign;
        y_vals[K + isign] = sum * scale;
        sum += params_buf[Koffset + isign];
        y_vals[K - i] = sum_negative;
        // versus the CPU code, the params_buf is indexed off by 1; and it already
        // contains the factor "scale".
        sum_negative -= params_buf[K - 1 - i] * scale;
      }
      if (threadIdx.x != 0)  // sum_negative
        y_vals[0] = sum * scale;
      y_vals[0] = sum_negative;
    }
    __syncthreads();
  // this_params_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
  // linear interval) corresponding to n == threadIdx.x % N.  For example, if
  // threadIdx.x == 0, this thread's gradient corresponds to the left-most
  // linear interval.
  // "this_params_grad" actually contains the derivative w.r.t. scaled params, i.e.
  // params[n] * scale.
  scalar_t this_scaled_param_grad = 0.0,
      this_y_vals_grad = 0.0;
  scalar_t inv_scale = params_buf[-3];
  int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
...
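Per the comments in the hunk above, two threads build `y_vals` as running sums of the scaled slopes, one moving right and one moving left from the centre index K, mirroring the CPU reference in learned_nonlin_cpu.cpp (which is not shown in this diff). A rough single-threaded sketch of that construction, assuming N == 2K pieces and slopes that are already multiplied by the scale; the names here are illustrative only:

#include <vector>

// Illustrative CPU sketch: y_vals[K] = 0 at the centre; moving right adds each
// scaled slope once, moving left subtracts, so y_vals[n] ends up being the
// function value at the left edge of piece n.
std::vector<float> build_y_vals_sketch(const std::vector<float> &scaled_slope, int K) {
  std::vector<float> y_vals(2 * K, 0.0f);
  float sum_positive = 0.0f;
  for (int i = 0; i < K; i++) {          // pieces K .. 2K-1
    y_vals[K + i] = sum_positive;
    sum_positive += scaled_slope[K + i];
  }
  float sum_negative = 0.0f;
  for (int i = 0; i < K; i++) {          // pieces K-1 .. 1, working leftwards
    y_vals[K - i] = sum_negative;
    sum_negative -= scaled_slope[K - 1 - i];
  }
  y_vals[0] = sum_negative;              // value at the far-left edge
  return y_vals;
}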
@@ -408,8 +407,11 @@ void learned_nonlin_backward_kernel(
      // The forward code did:
      // output[b][c][t] = (x - n) * params_buf[n] + y_vals[n];
      if (t < T)
      if (t < T) {
        // In a sense this expression should contain "* inv_scale * scale"...
        // of course, their product equals 1.
        input_grad[b][c][t] = this_output_grad * params_buf[n];
      }
      int this_block_start = threadIdx.x & ~(N - 1),  // == N * (threadIdx.x / N),
          this_n = threadIdx.x & (N - 1);             // == threadIdx.x % N.
...
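The `& ~(N - 1)` and `& (N - 1)` expressions above rely on N being a power of two, in which case they equal `N * (threadIdx.x / N)` and `threadIdx.x % N` respectively, exactly as the inline comments note. A tiny standalone check of that identity (illustrative only; the value of N here is arbitrary):

#include <cassert>

// For a power-of-two N and non-negative tid, masking with ~(N-1) keeps the high
// bits (rounding down to a multiple of N) and masking with (N-1) keeps the low
// bits (the remainder).
int main() {
  const int N = 8;  // assumed power of two
  for (int tid = 0; tid < 64; tid++) {
    assert((tid & ~(N - 1)) == N * (tid / N));
    assert((tid & (N - 1)) == tid % N);
  }
  return 0;
}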
@@ -471,9 +473,9 @@ void learned_nonlin_backward_kernel(
src_thread
=
this_block_start
+
src_idx
;
scalar_t
output_grad
=
output_grad_buf
[
src_thread
],
x_residual
=
x_residual_buf
[
src_thread
];
// Backprop for: output = x_residual * params_buf[n] + y_vals[n].
// Backprop for: output = x_residual *
(
params_buf[n]
* scale)
+ y_vals[n].
// Here, n == this_n; this is how we selected these `src_idx` values.
this_param
s
_grad
+=
output_grad
*
x_residual
;
this_
scaled_
param_grad
+=
output_grad
*
x_residual
;
this_y_vals_grad
+=
output_grad
;
}
}
...
...
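The accumulation lines above follow directly from differentiating the forward expression `output = x_residual * w + y`: the gradient w.r.t. the slope `w` accumulates `output_grad * x_residual`, the gradient w.r.t. the offset `y` accumulates `output_grad`, and (as in the earlier hunk) the gradient w.r.t. the input is `output_grad * w`. A minimal scalar sketch with hypothetical names:

// Hypothetical scalar sketch of the per-element backprop used above.
// Forward:  output = x_residual * w + y   (w == params_buf[n] * scale, y == y_vals[n])
struct PieceGrads { float w_grad, y_grad, x_grad; };

PieceGrads backprop_one_element(float output_grad, float x_residual, float w) {
  PieceGrads g;
  g.w_grad = output_grad * x_residual;  // d output / d w == x_residual
  g.y_grad = output_grad;               // d output / d y == 1
  g.x_grad = output_grad * w;           // d output / d x_residual == w
  return g;
}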
@@ -482,14 +484,14 @@ void learned_nonlin_backward_kernel(
  __syncthreads();  // sync threads because we are about to re-use
                    // output_grad_buf for reduction.
  this_params_grad = strided_reduce_sum(N, output_grad_buf, this_params_grad);
  this_scaled_param_grad = strided_reduce_sum(N, output_grad_buf, this_scaled_param_grad);
  this_y_vals_grad = strided_reduce_sum(N, output_grad_buf, this_y_vals_grad);
  __syncthreads();  // sync threads because we are about to re-use
                    // output_grad_buf.
  // Re-use some buffers..
  scalar_t *params_grad_buf = x_residual_buf,         // [N]
  scalar_t *scaled_params_grad_buf = x_residual_buf,  // [N]
      *y_vals_grad_buf = output_grad_buf;             // [N]
  if (threadIdx.x < N) {
...
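`strided_reduce_sum` itself is not part of this diff, so its implementation isn't shown; from its use above it appears to sum, across the thread block, the values held by threads whose `threadIdx.x` agree modulo N, using the supplied buffer as scratch. A purely hypothetical sketch of one way such a strided reduction could be written (the name, signature and details below are assumptions, not the project's actual code):

// Hypothetical device helper, NOT the project's strided_reduce_sum: every thread
// contributes `val`; the return value for thread t is the sum of `val` over all
// threads u in the block with (u % N) == (t % N).  Assumes blockDim.x and N are
// powers of two with N <= blockDim.x, and that `buf` has blockDim.x elements.
template <typename scalar_t>
__device__ scalar_t strided_reduce_sum_sketch(int N, scalar_t *buf, scalar_t val) {
  buf[threadIdx.x] = val;
  __syncthreads();
  // Fold the upper half of the buffer onto the lower half until N lanes remain.
  for (int stride = blockDim.x / 2; stride >= N; stride /= 2) {
    if ((int)threadIdx.x < stride)
      buf[threadIdx.x] += buf[threadIdx.x + stride];
    __syncthreads();
  }
  return buf[threadIdx.x % N];
}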
@@ -497,7 +499,7 @@ void learned_nonlin_backward_kernel(
    // the position in 'params'.  To keep the backprop code similar to the CPU
    // backprop code we restore that offset here, i.e. use the same layout
    // as the params.
    params_grad_buf[threadIdx.x + 1] = this_params_grad;
    scaled_params_grad_buf[threadIdx.x] = this_scaled_param_grad;
    y_vals_grad_buf[threadIdx.x] = this_y_vals_grad;
  }
...
@@ -514,31 +516,47 @@ void learned_nonlin_backward_kernel(
if
(
threadIdx
.
x
==
0
)
{
scalar_t
sum_positive_grad
=
0.0
;
for
(
int
i
=
K
-
1
;
i
>=
0
;
i
--
)
{
// This is like the CPU code but with an offset of 1 for 'params_buf'
// versus 'params_a'.
params_grad_buf
[
1
+
K
+
i
]
+=
sum_positive_grad
*
scale
;
scale_grad
+=
sum_positive_grad
*
params_buf
[
K
+
i
];
// This is like the CPU code but with an offset of -1 for indexes into 'params_buf';
// also there is no scale because we are dealing with pre-scaled parameters.
scaled_params_grad_buf
[
K
+
i
]
+=
sum_positive_grad
;
sum_positive_grad
+=
y_vals_grad_buf
[
K
+
i
];
}
params_grad_buf
[
0
]
+=
scale
*
scale_grad
;
}
else
if
(
threadIdx
.
x
==
64
)
{
scalar_t
sum_negative_grad
=
y_vals_grad_buf
[
0
];
for
(
int
i
=
K
-
1
;
i
>=
0
;
i
--
)
{
// This is like the CPU code but with an offset of 1 for 'params_buf'
// versus 'params_a'.
params_grad_buf
[
K
-
i
]
-=
sum_negative_grad
*
scale
;
scale_grad
-=
sum_negative_grad
*
params_buf
[
K
-
1
-
i
];
// This is like the CPU code but with an offset of 1 for 'params_buf';
// also there is no scale because we are dealing with pre-scaled parameters.
scaled_params_grad_buf
[
K
-
1
-
i
]
-=
sum_negative_grad
;
sum_negative_grad
+=
y_vals_grad_buf
[
K
-
i
];
}
}
__syncthreads
();
if
(
threadIdx
.
x
==
64
)
params_grad_buf
[
0
]
+=
scale
*
scale_grad
;
if
(
threadIdx
.
x
<
N
)
{
// this_scaled_param_grad is the gradient w.r.t. params_buf[n] * scale
// which is equal to params[c][n + 1] * scale.
int
n
=
threadIdx
.
x
;
scalar_t
this_scaled_param_grad
=
scaled_params_grad_buf
[
n
],
this_scale_grad
=
this_scaled_param_grad
*
params_buf
[
n
],
scale
=
params_buf
[
-
2
],
this_param_grad
=
scale
*
this_scaled_param_grad
;
// re-use x_residual_buf as 'param_grad_buf'.
x_residual_buf
[
n
+
1
]
=
this_param_grad
;
scalar_t
scale_grad
=
tiled_warp_reduce_sum
(
N
,
y_vals_grad_buf
,
this_scale_grad
);
if
(
threadIdx
.
x
==
0
)
x_residual_buf
[
0
]
=
scale_grad
*
scale
;
// deriv w.r.t. l.
}
__syncthreads
();
}
if
(
threadIdx
.
x
<=
N
)
{
params_grad
[
blockIdx
.
y
][
c
][
threadIdx
.
x
]
=
params_grad_buf
[
threadIdx
.
x
];
// note, we are re-using x_residual_buf for the params_grad.
params_grad
[
blockIdx
.
y
][
c
][
threadIdx
.
x
]
=
x_residual_buf
[
threadIdx
.
x
];
}
}
...
...
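The final lines tie the scale gradient back to the raw parameter: since `params[c][0]` stores the log of the scale (per the comments in the earlier hunk, `scale = exp(params[c][0])`), the derivative of the loss w.r.t. `params[c][0]` is `scale * d(loss)/d(scale)`, which is why the code accumulates `scale * scale_grad` into position 0 of the parameter gradient. A one-line sketch of that chain rule, with hypothetical names:

#include <cmath>

// Chain rule for the log-scale parameter l == params[c][0]:
//   scale = exp(l)  =>  d(loss)/d(l) = d(loss)/d(scale) * d(scale)/d(l)
//                                    = scale_grad * scale.
float log_scale_grad(float l, float scale_grad) {
  float scale = std::exp(l);
  return scale_grad * scale;
}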