OpenDAS / FAST-RNNT / Commits

Commit e0bc4029, authored Jul 16, 2021 by Daniel Povey
CUDA backward running (but not correctly for params grad)
parent 2ccbb505

Showing 2 changed files with 98 additions and 110 deletions:

  torch_learned_nonlin/learned_nonlin_cuda_kernel.cu   +82  -108
  torch_learned_nonlin/learned_nonlin_test.py           +16  -2
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu  (view file @ e0bc4029)
@@ -124,9 +124,6 @@ void learned_nonlin_kernel(
   }
   __syncthreads();
-  // The easiest way to understand this code is to compare it with the CPU code
-  // in learned_nonlin_cpu.cpp.
   if (threadIdx.x == 0) {
     scalar_t scale = params_buf[-2],
         sum_positive = 0.0;
@@ -294,65 +291,51 @@ void learned_nonlin_backward_kernel(
   // params_buf[-1] contains params[c][0] == log of scale;
   // params_buf[-2] and params_buf[-3] contain scale and inv_scale.
-  scalar_t x_residual_buf[THREADS_PER_BLOCK];  // x_residual, with 0 <= x_residual < 1 for interior regions,
-                                               // is the residual part of the scaled input, after subtracting
-                                               // the integer part.
+  scalar_t input_buf[THREADS_PER_BLOCK];       // input sequence
   scalar_t output_grad_buf[THREADS_PER_BLOCK];
   char n_buf[THREADS_PER_BLOCK];               // for each input in `input_buf`, this stores the integer value
                                                // 0 <= n < N which determines which piece of the piecewise
                                                // linear function we are in.

   // Load parameters
   if (threadIdx.x <= N)
     params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
   __syncthreads();

   if (threadIdx.x == 0) {
-    scalar_t scale = exp(params_buf[-1]),
-        inv_scale = 1.0 / scale;
+    scalar_t scale = exp(params_buf[-1]);
     params_buf[-2] = scale;
-    params_buf[-3] = inv_scale;
+    params_buf[-3] = 1.0 / scale;
   }
   __syncthreads();

-  scalar_t scale = params_buf[-2];
+  // The easiest way to understand this code is to compare it with the CPU code
+  // in learned_nonlin_cpu.cpp.
   if (threadIdx.x == 0) {
-    scalar_t sum_positive = 0.0;
+    scalar_t scale = params_buf[-2],
+        sum_positive = 0.0;
     for (int i = 0; i < K; i++) {
-      y_vals[K + i] = sum_positive;
-      // versus the CPU code, the params_buf is indexed off by 1; and it already
-      // contains the factor "scale".
-      sum_positive += params_buf[K + i] * scale;
+      // params_buf is indexed with an index one less than params.
+      scalar_t pos_scaled_param = params_buf[K + i] * scale;
+      y_vals[K + i] = sum_positive - pos_scaled_param * i;
+      sum_positive += pos_scaled_param;
     }
   } else if (threadIdx.x == 64) {
-    scalar_t sum_negative = 0.0;
+    scalar_t scale = params_buf[-2],
+        sum_negative = 0.0;
     for (int i = 0; i < K; i++) {
-      y_vals[K - i] = sum_negative;
-      // versus the CPU code, the params_buf is indexed off by 1; and it already
-      // contains the factor "scale".
-      sum_negative -= params_buf[K - 1 - i] * scale;
+      scalar_t neg_scaled_param = params_buf[K - i - 1] * scale;
+      sum_negative -= neg_scaled_param;
+      y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
     }
-    y_vals[0] = sum_negative;
   }
   __syncthreads();

-  // this_params_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
+  // this_param_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
   // linear interval) corresponding to n == threadIdx.x % N.  For example, if
   // threadIdx.x == 0, this thread's gradient corresponds to the left-most
   // linear interval.
-  // "this_params_grad" actually contains the derivative w.r.t. scaled params, i.e.
-  // params[n] * scale.
-  scalar_t this_scaled_param_grad = 0.0,
+  scalar_t this_param_grad = 0.0,
       this_y_vals_grad = 0.0;
   scalar_t inv_scale = params_buf[-3];
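The rewritten y_vals loops above fold the per-piece offset into y_vals itself, using the scaled slopes params_buf[k] * scale with scale = exp(params[c][0]). A minimal Python sketch of that construction, assuming the parameter layout described in the kernel comments (params[0] is the log scale, params[1:] are the N = 2*K slopes); the function name and signature are illustrative, not part of this repo:

import torch

def build_y_vals(params: torch.Tensor, K: int) -> torch.Tensor:
    # params: 1-D tensor for one channel; params[0] = log(scale), params[1:] = slopes.
    scale = params[0].exp()
    slopes = params[1:] * scale          # the kernel's pos/neg_scaled_param values
    y_vals = torch.zeros(2 * K)
    # Positive half (thread 0 in the kernel):
    sum_positive = torch.zeros(())
    for i in range(K):
        y_vals[K + i] = sum_positive - slopes[K + i] * i
        sum_positive = sum_positive + slopes[K + i]
    # Negative half (thread 64 in the kernel):
    sum_negative = torch.zeros(())
    for i in range(K):
        sum_negative = sum_negative - slopes[K - i - 1]
        y_vals[K - i - 1] = sum_negative + slopes[K - i - 1] * (i + 1)
    return y_vals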
@@ -370,7 +353,7 @@ void learned_nonlin_backward_kernel(
   // this_params_grad or this_y_vals_grad.
   for (int t_offset = 0; t_offset < T; t_offset += THREADS_PER_BLOCK) {
     int t = threadIdx.x % T_inc + t_offset;
-    scalar_t this_output_grad = 0.0, x = 0.0;
+    scalar_t this_output_grad = 0.0;
     if (t < T)
       this_output_grad = output_grad[b][c][t];
@@ -381,29 +364,24 @@ void learned_nonlin_backward_kernel(
     // the same 'n'.  It might be better to set n to an invalid value for
     // out-of-range threads, but as it is, if we are to properly handle
     // N==16 we don't have enough bits available in `src_indexes` to do this.
-    x = input[b][c][t % T] * inv_scale + K;
+    scalar_t this_input = input[b][c][t % T] * inv_scale + K;
+    input_buf[threadIdx.x] = this_input;
     output_grad_buf[threadIdx.x] = this_output_grad;
-    scalar_t x_trunc = x;
-    if (x_trunc < 0) x_trunc = 0;
-    else if (x_trunc >= N) x_trunc = N - 1;
+    scalar_t x = this_input;
+    if (x < 0) x = 0;
+    else if (x >= N) x = N - 1;
     // C++ rounds toward zero.
-    int n = (int) x_trunc;
+    int n = (int) x;
     n_buf[threadIdx.x] = (char) n;
-    scalar_t x_residual = x - n;
-    x_residual_buf[threadIdx.x] = x_residual;
-    // OK, at this point, 0 <= min < N.
-    // The forward code did:
-    //   output[b][c][t] = (x - n) * params_buf[n] + y_vals[n];
-    if (t < T) {
-      // In a sense this expression should contain "* inv_scale * scale"...
-      // of course, their product equals 1.
+    // output[b][c][t] = this_input * params_buf[n] + y_vals[n];
+    // We get the derivative for params and y_vals later.
+    if (t < T)
       input_grad[b][c][t] = this_output_grad * params_buf[n];
-    }
     int this_block_start = threadIdx.x & ~(N-1),  // == N * (threadIdx.x / N),
         this_n = threadIdx.x & (N-1);             // == threadIdx.x % N.
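Per element, the block above scales the input, clamps and truncates it to select the linear piece n, and passes output_grad straight through that piece's slope; the "* inv_scale * scale" remark means those two factors cancel, so the input gradient is just output_grad * params_buf[n]. A hedged per-element sketch in Python, under the same assumed parameter layout as above (names are illustrative only):

import math

def input_grad_one_element(x, params, K, output_grad):
    # Assumed layout: params[0] = log(scale); params[1:] = the N = 2*K slopes.
    N = 2 * K
    scale = math.exp(params[0])
    x_scaled = x / scale + K                   # "this_input" in the kernel
    n = int(min(max(x_scaled, 0.0), N - 1))    # clamp, then truncate toward zero
    # d(output)/dx = (params[n + 1] * scale) * (1 / scale) = params[n + 1],
    # which is what the kernel reads back as params_buf[n].
    return output_grad * params[n + 1]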
@@ -464,10 +442,10 @@ void learned_nonlin_backward_kernel(
       int src_idx = src_indexes & 0xF,
           src_thread = this_block_start + src_idx;
       scalar_t output_grad = output_grad_buf[src_thread],
-          x_residual = x_residual_buf[src_thread];
+          this_input = input_buf[src_thread];
       // Backprop for: output = x_residual * (params_buf[n] * scale) + y_vals[n].
       // Here, n == this_n; this is how we selected these `src_idx` values.
-      this_scaled_param_grad += output_grad * x_residual;
+      this_param_grad += output_grad * this_input;
       this_y_vals_grad += output_grad;
     }
   }
@@ -476,80 +454,76 @@ void learned_nonlin_backward_kernel(
   __syncthreads();  // sync threads because we are about to re-use
                     // output_grad_buf for reduction.
-  this_scaled_param_grad = strided_reduce_sum(N, output_grad_buf, this_scaled_param_grad);
+  this_param_grad = strided_reduce_sum(N, output_grad_buf, this_param_grad);
   this_y_vals_grad = strided_reduce_sum(N, output_grad_buf, this_y_vals_grad);
   __syncthreads();  // sync threads because we are about to re-use
                     // output_grad_buf.

   // Re-use some buffers..
-  scalar_t *scaled_params_grad_buf = x_residual_buf,  // [N] ...
+  scalar_t *params_grad_buf = input_buf + 1,  // [N] ... but element [-1] will have deriv of scale.
       *y_vals_grad_buf = output_grad_buf;  // [N]
   if (threadIdx.x < N) {
-    // There is an offset of 1 between the 'n' values and
-    // the position in 'params'.  To keep the backprop code similar to the CPU
-    // backprop code we restore that offset here, i.e. use the same layout
-    // as the params.
-    scaled_params_grad_buf[threadIdx.x] = this_scaled_param_grad;
+    // Restore the indexing offset of 1 in params_grad_buf (versus
+    // params_buf)
+    params_grad_buf[threadIdx.x] = this_param_grad;
     y_vals_grad_buf[threadIdx.x] = this_y_vals_grad;
   }

   // This next block does backprop relating to `y_vals`.  Comparing with the CPU
-  // version (call this the "reference code") is the best way to understand this (this code is just a
-  // modification of that).
-  {
-    // Thread 0 is responsible for parts of the reference code that involve "sum_positive_grad";
-    // thread 64 is responsible for parts of the reference code that involve "sum_negative_grad";
-    scalar_t scale_grad = 0.0,
-        scale = params_buf[-2];
-    if (threadIdx.x == 0) {
-      scalar_t sum_positive_grad = 0.0;
-      for (int i = K - 1; i >= 0; i--) {
-        // This is like the CPU code but with an offset of -1 for indexes into 'params_buf';
-        // also there is no scale because we are dealing with pre-scaled parameters.
-        scaled_params_grad_buf[K + i] += sum_positive_grad;
-        sum_positive_grad += y_vals_grad_buf[K + i];
-      }
-    } else if (threadIdx.x == 64) {
-      scalar_t sum_negative_grad = y_vals_grad_buf[0];
-      for (int i = K - 1; i >= 0; i--) {
-        // This is like the CPU code but with an offset of 1 for 'params_buf';
-        // also there is no scale because we are dealing with pre-scaled parameters.
-        scaled_params_grad_buf[K - 1 - i] -= sum_negative_grad;
-        sum_negative_grad += y_vals_grad_buf[K - i];
-      }
+  // version (call this the "reference code") is the best way to understand this
+  // (this code is just a modification of that).  The main difference is we
+  // modify the indexes into params and params_grad by -1, so the index
+  // corresponds to the 'n' value; and element -1 of params_grad_buf will have
+  // the deriv of the log scale.
+  scalar_t l_grad;
+  if (threadIdx.x == 64) {
+    // Now do the backprop for the loop above where we set y_vals_a.  This could
+    // be further optimized to replace the loop with a raking, but I doubt this
+    // will have a huge effect on the runtime since K will be fairly small,
+    // e.g. 4.
+    scalar_t scale = params_buf[-2],
+        scale_grad = 0.0, sum_positive_grad = 0.0;
+    for (int i = K - 1; i >= 0; i--) {
+      // Backprop for: sum_positive += pos_scaled_param;
+      scalar_t pos_scaled_param_grad = sum_positive_grad;
+      // Backprop for: y_vals[K + i] = sum_positive - pos_scaled_param * i;
+      scalar_t y_grad_pos = y_vals_grad_buf[K + i];
+      pos_scaled_param_grad -= i * y_grad_pos;
+      sum_positive_grad += y_grad_pos;
+      // Backprop for: pos_scaled_param = params_buf[K + i] * scale,
+      params_grad_buf[K + i] += pos_scaled_param_grad * scale;
+      scale_grad += pos_scaled_param_grad * params_buf[K + i];
     }
   __syncthreads();
-  if (threadIdx.x < N) {
-    // this_scaled_param_grad is the gradient w.r.t. params_buf[n] * scale
-    // which is equal to params[c][n + 1] * scale.
-    int n = threadIdx.x;
-    scalar_t this_scaled_param_grad = scaled_params_grad_buf[n],
-        this_scale_grad = this_scaled_param_grad * params_buf[n],
-        scale = params_buf[-2],
-        this_param_grad = scale * this_scaled_param_grad;
-    // re-use x_residual_buf as 'param_grad_buf'.
-    x_residual_buf[n + 1] = this_param_grad;
-    scalar_t scale_grad = tiled_warp_reduce_sum(N, y_vals_grad_buf, this_scale_grad);
-    if (threadIdx.x == 0)
-      x_residual_buf[0] = scale_grad * scale;  // deriv w.r.t. l.
+    // Backprop for: scale = exp(l), where l = params[c][0].
+    params_grad_buf[-1] = scale * scale_grad;
   }
+  else if (threadIdx.x == 0) {
+    // Now do the backprop for the loop above where we set y_vals.
+    scalar_t scale = params_buf[-2],
+        scale_grad = 0.0, sum_negative_grad = 0.0;
+    for (int i = K - 1; i >= 0; i--) {
+      // Backprop for: y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
+      scalar_t y_grad_neg = y_vals_grad_buf[K - i - 1];
+      sum_negative_grad += y_grad_neg;
+      scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
+      // Backprop for: sum_negative -= neg_scaled_param;
+      neg_scaled_param_grad -= sum_negative_grad;
+      // Backprop for: neg_scaled_param = params_buf[K - i - 1] * scale;
+      params_grad_buf[K - i - 1] += neg_scaled_param_grad * scale;
+      scale_grad += neg_scaled_param_grad * params_buf[K - i - 1];
+    }
   __syncthreads();
-  }
-  if (threadIdx.x <= N) {
-    // note, we are re-using x_residual_buf for the params_grad.
-    params_grad[blockIdx.y][c][threadIdx.x] = x_residual_buf[threadIdx.x];
+    l_grad = scale * scale_grad;
   }
   __syncthreads();
+  if (threadIdx.x == 0)
+    params_grad_buf[-1] += l_grad;  // contribution to l grad from the "negative" branch
+  __syncthreads();
+  if (threadIdx.x <= N)
+    params_grad[blockIdx.y][c][threadIdx.x] = params_grad_buf[threadIdx.x - 1];
 }
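The new code above accumulates, for each raw slope, scaled_slope_grad * scale, and for the log-scale parameter l = params[c][0] (with scale = exp(l)) it applies dL/dl = dL/dscale * scale, stored as element -1 of params_grad_buf. A small Python sketch of just that final chain-rule step, under the same assumed parameter layout (names are illustrative; the y_vals contributions to the slope gradients are omitted):

import math

def split_scaled_slope_grads(params, scaled_slope_grads):
    # Assumed layout: params[0] = l (log scale); params[1:] = raw slopes; each
    # scaled slope is s_k = params[k + 1] * exp(l), and scaled_slope_grads[k] = dL/ds_k.
    scale = math.exp(params[0])
    slope_grads = [g * scale for g in scaled_slope_grads]            # dL/dparams[k + 1]
    scale_grad = sum(g * p for g, p in zip(scaled_slope_grads, params[1:]))
    l_grad = scale_grad * scale                                      # since scale = exp(l)
    return l_grad, slope_grads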
torch_learned_nonlin/learned_nonlin_test.py  (view file @ e0bc4029)
@@ -22,17 +22,31 @@ def test_learned_nonlin_basic():
     y = learned_nonlin(x, params, dim=1)
     print("y = ", y)
-    y.sum().backward()

     if torch.cuda.is_available():
         # test that the CUDA forward is the same as the CPU forward.
         device = torch.device('cuda:0')
-        y2 = learned_nonlin(x.to(device), params.to(device), dim=1).to(torch.device('cpu'))
+        x2 = x.to(device).detach()
+        x2.requires_grad = True
+        params2 = params.to(device).detach()
+        params2.requires_grad = True
+        y2 = learned_nonlin(x2, params2, dim=1).to(torch.device('cpu'))
         print("Checking CUDA is same")
         if not torch.allclose(y, y2, atol=1.0e-06):
             print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
             assert(0);
+        y.sum().backward()
+        y2.sum().backward()
+        if not torch.allclose(x.grad, x2.grad.to('cpu'), atol=1.0e-06):
+            print(f"Error: CPU x-grad versus CUDA grad not the same: {x.grad} vs. {x2.grad}, diff = {x2.grad.to('cpu')-x.grad}")
+            assert(0);
+        if not torch.allclose(params.grad, params2.grad.to('cpu'), atol=1.0e-06):
+            print(f"Error: CPU params-grad versus CUDA grad not the same: {params.grad} vs. {params2.grad}, diff = {params2.grad.to('cpu')-params.grad}")
+            assert(0);

     print("x.grad = ", x.grad)
     print("params.grad = ", params.grad)
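The test above compares the CUDA extension against the CPU extension. An additional, independent sanity check could compare both against a plain PyTorch re-implementation and let autograd produce reference gradients. The sketch below is an assumption pieced together from the kernel comments (params[:, 0] is a log scale, params[:, 1:] are the 2*K slopes, and the function is linear on unit-width pieces of x/scale + K); it is not the project's documented definition of learned_nonlin, so treat it only as a starting point for such a check:

import torch

def learned_nonlin_reference(x: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
    # x: (B, C, T); params: (C, N + 1) with N = 2*K (assumed layout, see above).
    B, C, T = x.shape
    N = params.shape[1] - 1
    K = N // 2
    scale = params[:, 0].exp()                     # (C,)
    slopes = params[:, 1:] * scale.unsqueeze(1)    # scaled slopes, (C, N)
    # Offsets of each linear piece, matching the kernel's two y_vals loops:
    # y_vals[c, n] = sum(slopes[c, :n]) - sum(slopes[c, :K]) - slopes[c, n] * (n - K)
    cum = torch.cumsum(slopes, dim=1)
    sum_below = torch.cat([torch.zeros_like(cum[:, :1]), cum[:, :-1]], dim=1)
    n_idx = torch.arange(N, device=x.device)
    y_vals = sum_below - cum[:, K - 1:K] - slopes * (n_idx - K)
    # Pick the piece for each element and evaluate it.
    u = x / scale.view(1, C, 1)                    # scaled input
    n = (u + K).clamp(min=0, max=N - 1).long()     # piece index, truncation toward zero
    slope_n = torch.gather(slopes.unsqueeze(0).expand(B, C, N), 2, n)
    y_n = torch.gather(y_vals.unsqueeze(0).expand(B, C, N), 2, n)
    return u * slope_n + y_n

If dim=1 really does select the channel dimension (as the test above suggests), a comparison like torch.allclose(learned_nonlin_reference(x, params), learned_nonlin(x, params, dim=1), atol=1.0e-06) would flag any disagreement in the forward, and autograd through this reference would give independent x and params gradients to compare against both extensions.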