Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
c3e61bea
Commit
c3e61bea
authored
Jul 16, 2021
by
Daniel Povey
Browse files
Nearly-working backprop on CUDA
parent
e0bc4029
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
71 additions
and
42 deletions
+71
-42
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu
+70
-41
torch_learned_nonlin/learned_nonlin_test.py
torch_learned_nonlin/learned_nonlin_test.py
+1
-1
No files found.
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu
View file @
c3e61bea
...
@@ -260,6 +260,14 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
...
@@ -260,6 +260,14 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
We also require that N <= THREADS_PER_BLOCK (for best performance,
We also require that N <= THREADS_PER_BLOCK (for best performance,
N should be quite small, like no larger than 8 or so).
N should be quite small, like no larger than 8 or so).
We also require 4 <= N <= 16 for this code!
We also require 4 <= N <= 16 for this code!
And we require that
N <= (THREADS_PER_BLOCK / images_per_thread_block)
(both sides will be powers of 2).. this ensures that blocks of threads
summing the N values are always within the same image, which helps
avoid a problem where some loops over 'b' would be done earlier
than others, and we'd end up counting certain pixels twice as their
output_grad would stay nonzero.
*/
*/
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
...
@@ -291,12 +299,13 @@ void learned_nonlin_backward_kernel(
...
@@ -291,12 +299,13 @@ void learned_nonlin_backward_kernel(
// params_buf[-1] contains params[c][0] == log of scale;
// params_buf[-1] contains params[c][0] == log of scale;
// params_buf[-2] and params_buf[-3] contain scale and inv_scale.
// params_buf[-2] and params_buf[-3] contain scale and inv_scale.
scalar_t
input_buf
[
THREADS_PER_BLOCK
];
// input sequence
__shared__
scalar_t
input_buf
[
THREADS_PER_BLOCK
];
// input sequence
scalar_t
output_grad_buf
[
THREADS_PER_BLOCK
];
__shared__
scalar_t
output_grad_buf
[
THREADS_PER_BLOCK
];
char
n_buf
[
THREADS_PER_BLOCK
];
// for each input in `input_buf`, this stores
__shared__
char
n_buf
[
THREADS_PER_BLOCK
];
// for each input in `input_buf`,
// the integer value 0 <= n < N which
// this stores the integer value 0
// determines which piece of the piecewise
// <= n < N which determines which
// linear function we are in.
// piece of the piecewise linear
// function we are in.
// Load parameters
// Load parameters
if
(
threadIdx
.
x
<=
N
)
if
(
threadIdx
.
x
<=
N
)
...
@@ -352,7 +361,12 @@ void learned_nonlin_backward_kernel(
...
@@ -352,7 +361,12 @@ void learned_nonlin_backward_kernel(
// will be set to zero for excess threads, and thus won't contribute to
// will be set to zero for excess threads, and thus won't contribute to
// this_params_grad or this_y_vals_grad.
// this_params_grad or this_y_vals_grad.
for
(
int
t_offset
=
0
;
t_offset
<
T
;
t_offset
+=
THREADS_PER_BLOCK
)
{
for
(
int
t_offset
=
0
;
t_offset
<
T
;
t_offset
+=
THREADS_PER_BLOCK
)
{
int
t
=
threadIdx
.
x
%
T_inc
+
t_offset
;
// The following is equivalent to:
// int t = (threadIdx.x % T_inc) + t_offset;
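// given that T_inc is a power of 2 and t_offset is a multiple of
// THREADS_PER_BLOCK (note it starts at 0), which must itself be a multiple of
// T_inc; then t_offset has no bits set below T_inc and `|` acts like `+`.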
int
t
=
(
threadIdx
.
x
&
(
T_inc
-
1
))
|
t_offset
;
scalar_t
this_output_grad
=
0.0
;
scalar_t
this_output_grad
=
0.0
;
if
(
t
<
T
)
if
(
t
<
T
)
this_output_grad
=
output_grad
[
b
][
c
][
t
];
this_output_grad
=
output_grad
[
b
][
c
][
t
];
...
@@ -373,25 +387,24 @@ void learned_nonlin_backward_kernel(
...
@@ -373,25 +387,24 @@ void learned_nonlin_backward_kernel(
else
if
(
x
>=
N
)
x
=
N
-
1
;
else
if
(
x
>=
N
)
x
=
N
-
1
;
// C++ rounds toward zero.
// C++ rounds toward zero.
int
n
=
(
int
)
x
;
int
n
=
(
int
)
x
;
n_buf
[
threadIdx
.
x
]
=
(
char
)
n
;
n_buf
[
threadIdx
.
x
]
=
(
char
)
n
;
// 0 <= n < N
// OK, at this point, 0 <= n < N.
// The forward code did:
// The forward code did:
// output[b][c][t] = this_input * params_buf[n] + y_vals[n];
// output[b][c][t] = this_input * params_buf[n] + y_vals[n];
// We get the derivative for params and y_vals later.
// We get the derivative for params and y_vals later.
if
(
t
<
T
)
if
(
t
<
T
)
input_grad
[
b
][
c
][
t
]
=
this_output_grad
*
params_buf
[
n
];
input_grad
[
b
][
c
][
t
]
=
this_output_grad
*
params_buf
[
n
];
int
this_block_start
=
threadIdx
.
x
&
~
(
N
-
1
),
// == N * (threadIdx.x / N),
int
this_block_start
=
threadIdx
.
x
&
~
(
N
-
1
),
// == N * (threadIdx.x / N),
// since N is power of 2
this_n
=
threadIdx
.
x
&
(
N
-
1
);
// == threadIdx.x % N.
this_n
=
threadIdx
.
x
&
(
N
-
1
);
// == threadIdx.x % N.
// this_n is the n value that this thread accumulates gradients for;
// this_n is the n value that this thread accumulates gradients for;
// it is responsible for output_grads in the block of threads
// it is responsible for output_grads in the block of threads
// from this_block_start to this_block_start+N-1.
// from this_block_start to this_block_start+N-1.
// SYNC POINT At this point there is an implicit within-warp
// __syncthreads(); // <- not really needed.
// synchronization (Note: implicit warp synchronization is considered not
// At this point there is an implicit within-warp
// synchronization (Note: implicit warp synchronization is not considered
// future-proof). Threads above have written to n_buf, and threads below
// future-proof). Threads above have written to n_buf, and threads below
// will read from it; but we don't need to explicitly synchronize for now
// will read from it; but we don't need to explicitly synchronize for now
// because the reads/writes are among threads in a group of N threads with
// because the reads/writes are among threads in a group of N threads with
...
@@ -399,7 +412,7 @@ void learned_nonlin_backward_kernel(
...
@@ -399,7 +412,7 @@ void learned_nonlin_backward_kernel(
// src_indexes will contain up to 16 16-bit numbers, stored starting in its
// src_indexes will contain up to 16 16-bit numbers, stored starting in its
// least significant bits. It will store all the offsets within this
// least significant bits. It will store all the offsets within this
// block of N
, where the
'n' value equals this_n.
// block of N
threads, whose chosen
'n' value equals this_n.
uint64_t
src_indexes
=
0
;
uint64_t
src_indexes
=
0
;
// num_src is the number of numbers in `src_indexes`. We need to store a
// num_src is the number of numbers in `src_indexes`. We need to store a
// separate counter because zero is a valid index and if we are to support
// separate counter because zero is a valid index and if we are to support
...
@@ -407,11 +420,12 @@ void learned_nonlin_backward_kernel(
...
@@ -407,11 +420,12 @@ void learned_nonlin_backward_kernel(
// of marker.
// of marker.
int
num_src
=
0
;
int
num_src
=
0
;
// This loop always does N statements, but they should be relatively fast
// This loop always does at least N statements, but they should be
// ones since the computation per n value is minimal and there is little
// relatively fast ones since the computation per n value is minimal and
// I/O. We are figuring out the subset of our block of N elements,
// there is little I/O. We are figuring out the subset of our block of N
// which this particular thread value is responsible for (because they
// elements, which this particular thread value is responsible for
// have n == this_n), and storing them in `src_indexes` and `num_src`.
// (because they have n == this_n), and storing them in `src_indexes` and
// `num_src`.
for
(
int
i
=
0
;
i
<
N
;
i
+=
4
)
{
for
(
int
i
=
0
;
i
<
N
;
i
+=
4
)
{
uint32_t
n_block_of_4
=
*
reinterpret_cast
<
uint32_t
*>
(
n_buf
+
this_block_start
+
i
);
uint32_t
n_block_of_4
=
*
reinterpret_cast
<
uint32_t
*>
(
n_buf
+
this_block_start
+
i
);
#pragma unroll
#pragma unroll
...
@@ -438,38 +452,45 @@ void learned_nonlin_backward_kernel(
...
@@ -438,38 +452,45 @@ void learned_nonlin_backward_kernel(
// number of images, and the hope is that different warps will reach the
// number of images, and the hope is that different warps will reach the
// end of the outer loop at around the same time because their variations
// end of the outer loop at around the same time because their variations
// in speed will average out.
// in speed will average out.
for
(;
num_src
>
0
;
--
num_src
,
src_indexes
>>=
4
)
{
for
(;
num_src
>
0
;
--
num_src
,
(
src_indexes
>>=
4
))
{
int
src_idx
=
src_indexes
&
0xF
,
int
src_thread
=
this_block_start
|
(
src_indexes
&
0xF
);
src_thread
=
this_block_start
+
src_idx
;
scalar_t
src_output_grad
=
output_grad_buf
[
src_thread
],
scalar_t
output_grad
=
output_grad_buf
[
src_thread
],
src_input
=
input_buf
[
src_thread
];
this_input
=
input_buf
[
src_thread
];
assert
(
n_buf
[
src_thread
]
==
this_n
);
// Backprop for: output = x_residual * (params_buf[n] * scale) + y_vals[n].
n_buf
[
src_thread
]
=
0
;
// Backprop for: output = input * params_buf[n] + y_vals[n].
// Here, n == this_n; this is how we selected these `src_idx` values.
// Here, n == this_n; this is how we selected these `src_idx` values.
this_param_grad
+=
output_grad
*
this
_input
;
this_param_grad
+=
src_
output_grad
*
src
_input
;
this_y_vals_grad
+=
output_grad
;
this_y_vals_grad
+=
src_
output_grad
;
}
}
// TODO: remove the next lines
assert
(
n_buf
[
threadIdx
.
x
]
==
0
);
output_grad_buf
[
threadIdx
.
x
]
=
0.0
;
}
}
}
}
__syncthreads
();
// sync threads because we are about to re-use
__syncthreads
();
// sync threads because we are about to re-use
// output_grad_buf for reduction.
// output_grad_buf for reduction
, and, later, input_buf
.
this_param_grad
=
strided_reduce_sum
(
N
,
output_grad_buf
,
this_param_grad
);
this_param_grad
=
strided_reduce_sum
(
N
,
output_grad_buf
,
this_param_grad
);
__syncthreads
();
this_y_vals_grad
=
strided_reduce_sum
(
N
,
output_grad_buf
,
this_y_vals_grad
);
this_y_vals_grad
=
strided_reduce_sum
(
N
,
output_grad_buf
,
this_y_vals_grad
);
__syncthreads
();
// sync threads because we are about to re-use
__syncthreads
();
// sync threads because we are about to re-use
// output_grad_buf.
// output_grad_buf
as y_vals_grad_buf
.
// Re-use some buffers..
// Re-use some buffers..
scalar_t
*
params_grad_buf
=
input_buf
+
1
,
// [N] ... but element [-1] will have deriv of scale.
scalar_t
*
params_grad_buf
=
input_buf
+
1
,
// [N] ... but element [-1] will have deriv of scale.
*
y_vals_grad_buf
=
output_grad_buf
;
// [N]
*
y_vals_grad_buf
=
output_grad_buf
;
// [N]
if
(
threadIdx
.
x
<
N
)
{
if
(
threadIdx
.
x
<
N
)
{
// Restore the indexing offset of 1 in params_grad_buf (versus
// params_buf).
params_grad_buf
[
threadIdx
.
x
]
=
this_param_grad
;
params_grad_buf
[
threadIdx
.
x
]
=
this_param_grad
;
y_vals_grad_buf
[
threadIdx
.
x
]
=
this_y_vals_grad
;
y_vals_grad_buf
[
threadIdx
.
x
]
=
this_y_vals_grad
;
}
}
__syncthreads
();
// other threads are about to read params_grad_buf and
// y_vals_grad_buf.
// This next block does backprop relating to `y_vals`. Comparing with the CPU
// This next block does backprop relating to `y_vals`. Comparing with the CPU
// version (call this the "reference code") is the best way to understand this
// version (call this the "reference code") is the best way to understand this
...
@@ -479,7 +500,7 @@ void learned_nonlin_backward_kernel(
...
@@ -479,7 +500,7 @@ void learned_nonlin_backward_kernel(
// the deriv of the log scale.
// the deriv of the log scale.
scalar_t
l_grad
;
scalar_t
l_grad
;
if
(
threadIdx
.
x
==
64
)
{
if
(
threadIdx
.
x
==
0
)
{
// Now do the backprop for the loop above where we set y_vals_a. This could
// Now do the backprop for the loop above where we set y_vals_a. This could
// be further optimized to replace the loop with a raking, but I doubt this
// be further optimized to replace the loop with a raking, but I doubt this
// will have a huge effect on the runtime since K will be fairly small,
// will have a huge effect on the runtime since K will be fairly small,
...
@@ -499,9 +520,11 @@ void learned_nonlin_backward_kernel(
...
@@ -499,9 +520,11 @@ void learned_nonlin_backward_kernel(
scale_grad
+=
pos_scaled_param_grad
*
params_buf
[
K
+
i
];
scale_grad
+=
pos_scaled_param_grad
*
params_buf
[
K
+
i
];
}
}
// Backprop for: scale = exp(l), where l = params[c][0].
// Backprop for: scale = exp(l), where l = params[c][0].
params_grad_buf
[
-
1
]
=
scale
*
scale_grad
;
l_grad
=
scale
*
scale_grad
;
}
else
if
(
threadIdx
.
x
==
0
)
{
}
else
if
(
threadIdx
.
x
==
64
)
{
// Now do the backprop for the loop above where we set y_vals.
// Now do the backprop for the loop above where we set y_vals.
// Make this one threadIdx.x == 0 so it's possibly quicker to test
//
scalar_t
scale
=
params_buf
[
-
2
],
scalar_t
scale
=
params_buf
[
-
2
],
scale_grad
=
0.0
,
scale_grad
=
0.0
,
sum_negative_grad
=
0.0
;
sum_negative_grad
=
0.0
;
...
@@ -516,14 +539,17 @@ void learned_nonlin_backward_kernel(
...
@@ -516,14 +539,17 @@ void learned_nonlin_backward_kernel(
params_grad_buf
[
K
-
i
-
1
]
+=
neg_scaled_param_grad
*
scale
;
params_grad_buf
[
K
-
i
-
1
]
+=
neg_scaled_param_grad
*
scale
;
scale_grad
+=
neg_scaled_param_grad
*
params_buf
[
K
-
i
-
1
];
scale_grad
+=
neg_scaled_param_grad
*
params_buf
[
K
-
i
-
1
];
}
}
l_grad
=
scale
*
scale_grad
;
params_grad_buf
[
-
1
]
=
scale
*
scale_grad
;
}
}
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
if
(
threadIdx
.
x
==
0
)
{
params_grad_buf
[
-
1
]
+=
l_grad
;
// contribution to l grad from the "negative" branch
params_grad_buf
[
-
1
]
+=
l_grad
;
// contribution to l grad from the "negative" branch
}
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
<=
N
)
if
(
threadIdx
.
x
<=
N
)
{
params_grad
[
blockIdx
.
y
][
c
][
threadIdx
.
x
]
=
params_grad_buf
[
threadIdx
.
x
-
1
];
params_grad
[
blockIdx
.
y
][
c
][
threadIdx
.
x
]
=
params_grad_buf
[
threadIdx
.
x
-
1
];
}
}
}
...
@@ -623,7 +649,6 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
...
@@ -623,7 +649,6 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
TORCH_CHECK
(
output_grad
.
device
().
is_cuda
(),
"output_grad must be a CUDA tensor"
);
TORCH_CHECK
(
output_grad
.
device
().
is_cuda
(),
"output_grad must be a CUDA tensor"
);
TORCH_CHECK
(
params
.
device
().
is_cuda
(),
"Params must be a CUDA tensor"
);
TORCH_CHECK
(
params
.
device
().
is_cuda
(),
"Params must be a CUDA tensor"
);
const
int
B
=
input
.
size
(
0
),
const
int
B
=
input
.
size
(
0
),
C
=
input
.
size
(
1
),
C
=
input
.
size
(
1
),
T
=
input
.
size
(
2
),
T
=
input
.
size
(
2
),
...
@@ -631,6 +656,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
...
@@ -631,6 +656,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
TORCH_CHECK
(
N
>=
4
,
"This backward code requires N >= 4"
);
TORCH_CHECK
(
N
>=
4
,
"This backward code requires N >= 4"
);
TORCH_CHECK
(
N
<=
16
,
"This backward code currently requires N <= 16"
);
TORCH_CHECK
(
N
<=
16
,
"This backward code currently requires N <= 16"
);
TORCH_CHECK
((
N
&
(
N
-
1
))
==
0
,
"N must be a power of 2"
)
auto
scalar_t
=
input
.
scalar_type
();
auto
scalar_t
=
input
.
scalar_type
();
auto
opts
=
torch
::
TensorOptions
().
dtype
(
scalar_t
).
device
(
input
.
device
());
auto
opts
=
torch
::
TensorOptions
().
dtype
(
scalar_t
).
device
(
input
.
device
());
...
@@ -663,7 +689,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
...
@@ -663,7 +689,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
int
shared_mem_numel
=
2
*
N
+
3
;
int
shared_mem_numel
=
2
*
N
+
3
;
if
(
false
)
if
(
true
)
std
::
cout
<<
"C,B,T,N = "
<<
C
<<
","
<<
B
<<
","
<<
T
<<
","
<<
N
std
::
cout
<<
"C,B,T,N = "
<<
C
<<
","
<<
B
<<
","
<<
T
<<
","
<<
N
<<
", images_per_thread_block = "
<<
images_per_thread_block
<<
", images_per_thread_block = "
<<
images_per_thread_block
<<
", grid_dim_y = "
<<
grid_dim_y
<<
", grid_dim_y = "
<<
grid_dim_y
...
@@ -673,8 +701,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
...
@@ -673,8 +701,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
images_per_thread_block
==
1
,
images_per_thread_block
==
1
,
"Code error"
);
"Code error"
);
TORCH_CHECK
(
THREADS_PER_BLOCK
/
images_per_thread_block
>=
N
);
torch
::
Tensor
params_grad
=
torch
::
empty
({
grid_dim_y
,
C
,
N
+
1
},
opts
);
torch
::
Tensor
params_grad
=
torch
::
zeros
({
grid_dim_y
,
C
,
N
+
1
},
opts
);
dim3
gridDim
(
C
,
grid_dim_y
,
1
);
dim3
gridDim
(
C
,
grid_dim_y
,
1
);
...
...
torch_learned_nonlin/learned_nonlin_test.py
View file @
c3e61bea
...
@@ -13,7 +13,7 @@ def test_learned_nonlin_basic():
...
@@ -13,7 +13,7 @@ def test_learned_nonlin_basic():
K
=
4
K
=
4
N
=
K
*
2
N
=
K
*
2
params
=
torch
.
arange
(
N
+
1
,
dtype
=
dtype
).
unsqueeze
(
0
)
+
torch
.
arange
(
C
,
dtype
=
dtype
).
unsqueeze
(
1
)
params
=
torch
.
arange
(
N
+
1
,
dtype
=
dtype
).
unsqueeze
(
0
)
+
torch
.
arange
(
C
,
dtype
=
dtype
).
unsqueeze
(
1
)
-
3
x
.
requires_grad
=
True
x
.
requires_grad
=
True
params
.
requires_grad
=
True
params
.
requires_grad
=
True
print
(
"x = "
,
x
)
print
(
"x = "
,
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment