OpenDAS / FAST-RNNT / Commits / 74897fd5

Commit 74897fd5
authored Jul 15, 2021 by Daniel Povey

    Test sometimes failing, think it's an older problem.

Parent: 06e369c9
Showing 2 changed files with 439 additions and 503 deletions (+439 −503):

    torch_learned_nonlin/learned_nonlin_cuda_kernel.cu   +438 −502
    torch_learned_nonlin/learned_nonlin_test.py           +1 −1
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu  (view file @ 74897fd5)
@@ -22,9 +22,9 @@
        blockDim.x, to be used as a temporary within this function.
    val: The value to be summed
   Return:
-    Threads where blockDim.x % threads_per_tile == 0 will return the sum:
+    Threads where threadIdx.x % threads_per_tile == 0 will return the sum:
       \sum_{i=0}^{threads_per_tile-1} [val in thread threadIdx.x + i]
-    Return value in other threads is undefined.
+    The return value in other threads is undefined.
  */
 template <typename scalar_t>
 __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
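To make the documented semantics concrete, here is a tiny Python sketch of what the tiled reduction computes (an illustration only, not code from this repository):

# Illustration of the documented semantics: within each tile of
# `threads_per_tile` consecutive "threads", the tile's first thread ends up
# holding the sum of the whole tile (sketch only, not repository code).
threads_per_tile = 4
vals = list(range(16))                       # one value per simulated thread
tile_sums = {t: sum(vals[t:t + threads_per_tile])
             for t in range(len(vals)) if t % threads_per_tile == 0}
print(tile_sums)                             # {0: 6, 4: 22, 8: 38, 12: 54}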
@@ -43,8 +43,9 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
 /*
-  Forward of learned_nonlin.  Each thread group handles a single channel (equal
-  to blockIdx.x); the gridDim is (C, nb) where 1 <= nb <= B (nb relates to the batch).
+  Forward of learned_nonlin.  Each thread group handles a single channel (channel
+  c = blockIdx.x); the gridDim is (C, nb, 1) where 1 <= nb <= B (nb relates to the
+  image within the batch).
   Template args:
       scalar_t: the floating-point type, e.g. float, double, maybe half.
@@ -71,7 +72,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
...
@@ -71,7 +72,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
.. this is used for a small optimization.
.. this is used for a small optimization.
This kernel is allocated with `extern_buf` containing enough memory
This kernel is allocated with `extern_buf` containing enough memory
to store 2*N values of type scalar_t.
to store 2*N
+ 3
values of type scalar_t.
The blockDim must equal (THREADS_PER_BLOCK, 1, 1)
The blockDim must equal (THREADS_PER_BLOCK, 1, 1)
@@ -80,9 +81,10 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
        1 <= gridDim.y <= B, where B is the number of blocks
        gridDim.z == 1
   When we invoke this kernel, we'll invoke it as:
-     learned_nonlin_forward<<<gridDim, blockDim, bytesShared, stream>>>
+     learned_nonlin_kernel<<<gridDim, blockDim, bytesShared, stream>>>
   where bytesShared is the number of bytes needed in `extern_buf`:
      bytesShared = sizeof(shared_t) * (2N + 3)
+  We also require N + 1 <= THREADS_PER_BLOCK.
 */
 extern __shared__ int extern_buf[];
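A small host-side sketch of the shared-memory sizing and the new requirement, with an assumed THREADS_PER_BLOCK value and float32 scalars (illustration only, not repository code):

# Host-side sketch of the shared-memory sizing described above, assuming
# float32 scalars (4 bytes) and an assumed THREADS_PER_BLOCK of 256.
def bytes_shared(N: int, sizeof_scalar: int = 4) -> int:
    # 2*N + 3 values: y_vals[N], params_buf[N], plus slots for scale,
    # inv_scale and the log-scale parameter l.
    return sizeof_scalar * (2 * N + 3)

THREADS_PER_BLOCK = 256      # assumed value, for illustration only
N = 8                        # number of linear pieces (a power of 2)
assert N + 1 <= THREADS_PER_BLOCK   # the requirement added in this commit
print(bytes_shared(N))       # 76 bytes for N == 8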
@@ -98,31 +100,33 @@ void learned_nonlin_kernel(
       C = input.size(1),
       T = input.size(2),
       N = params.size(1) - 1,
-      K = N / 2;  // Note: N and K are powers fo 2, with K >= 1.
+      K = N / 2;  // Note: N and K are powers of 2, with K >= 1.
   const int c = blockIdx.x;  // c is channel index

-  scalar_t *y_vals = (scalar_t*) extern_buf,  // [N], actually there are two
+  scalar_t *y_vals = (scalar_t*) extern_buf,  // [N], actually there are 3
                                               // spaces between here and
                                               // `params_buf` for storing scale
-                                              // and inv_scale.
+                                              // and inv_scale and l == params[c][0].
       *params_buf = (scalar_t*) y_vals + 3 + N;  // [N].  Caution: contains params[c][1] through params[c][N].
                                               // params_buf[-1] contains params[c][0] == log of scale;
                                               // params_buf[-2] and params_buf[-3] contain scale and inv_scale.

   // Load parameters
-  for (int n = threadIdx.x; n <= N; n += THREADS_PER_BLOCK) {
-    params_buf[n - 1] = params[c][n];
-  }
+  if (threadIdx.x <= N)
+    params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
   __syncthreads();

   // The easiest way to understand this code is to compare it with the CPU code
   // in learned_nonlin_cpu.cpp.
-  if ((((int)threadIdx.x & ~(int)32)) == 0) {
-    // threadIdx.x == 0 or 32.  These are in separate warps so we can
-    // allow them to do separate jobs.  This code takes linear time in K which
-    // is not at all ideal and could be improved if K is largish, but it shouldn't
-    // dominate the total time taken if we are processing a lot of data;
-    // and anyway, we doubt that K will be need to be more than 4 or 8 or so,
-    // so the potential savings are quite small.
+  if ((((int)threadIdx.x & ~(int)64)) == 0) {
+    // TODO: replace this with easier-to-understand code.
+    // threadIdx.x == 0 or 64 (we choose 64 because it's >= the max known warp
+    // size).  These are in separate warps so we can allow them to do separate
+    // jobs.  This code takes linear time in K which is not at all ideal and
+    // could be improved if K is largish, but it shouldn't dominate the total
+    // time taken if we are processing a lot of data; and anyway, we doubt that
+    // K will be need to be more than 4 or 8 or so, so the potential savings are
+    // quite small.
     scalar_t scale = exp(params_buf[-1]),
         inv_scale = 1.0 / scale;
     params_buf[-2] = scale;   // both threads write these but it's OK, it's the
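As a reading aid, the pointer arithmetic above implies the shared-buffer index map sketched below; this is inferred from the comments in the hunk, not code from the repository:

# Index map of the (2*N + 3)-element shared buffer implied by the pointer
# arithmetic above (y_vals = extern_buf; params_buf = y_vals + 3 + N, with
# params_buf[-3], params_buf[-2], params_buf[-1] = inv_scale, scale, l),
# shown here for N = 8.
N = 8
layout = ([f"y_vals[{i}]" for i in range(N)]            # indices 0 .. N-1
          + ["inv_scale", "scale", "l = params[c][0]"]  # params_buf[-3..-1]
          + [f"params[c][{n + 1}]" for n in range(N)])  # params_buf[0..N-1]
assert len(layout) == 2 * N + 3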
@@ -137,7 +141,7 @@ void learned_nonlin_kernel(
     if (threadIdx.x == 0) {  // sum_positive
       sign = 1;
       Koffset = K;
-    } else {  // threadIdx.x == 32.  sum_negative.
+    } else {  // threadIdx.x == 64.  sum_negative.
       scale *= -1;  // this is a local variable..
       sign = -1;
       Koffset = K - 1;
@@ -155,11 +159,11 @@ void learned_nonlin_kernel(
   scalar_t inv_scale = params_buf[-3];

   int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
-      image_offset = threadIdx.x / T_inc,
+      b_offset = threadIdx.x / T_inc,  // offset within batch
       t_start = threadIdx.x % T_inc;

-  for (int b = blockIdx.y * images_per_thread_block + image_offset;
+  for (int b = blockIdx.y * images_per_thread_block + b_offset;
       b < B;
       b += gridDim.y * images_per_thread_block) {
     // We do "t += THREADS_PER_BLOCK" instead of t += (THREADS_PER_BLOCK /
     // images_per_thread_block) as a small optimization because the only case we
     // really need to loop is when images_per_thread_block == 1:a we only let
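A small sketch of the thread-to-(image, time) assignment that b_offset and t_start implement, using assumed tiny sizes (illustration only, not repository code):

# Sketch of the thread-to-(image, time) assignment above, with assumed sizes.
THREADS_PER_BLOCK = 8               # assumed value, far smaller than the real one
images_per_thread_block = 2
T_inc = THREADS_PER_BLOCK // images_per_thread_block   # threads per image
for tid in range(THREADS_PER_BLOCK):
    b_offset = tid // T_inc         # which image (within this thread block)
    t_start = tid % T_inc           # first time index handled by this thread
    print(f"thread {tid}: image offset {b_offset}, starts at t = {t_start}")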
@@ -172,367 +176,375 @@ void learned_nonlin_kernel(
       else if (x_trunc >= N) x_trunc = N - 1;
       // C++ rounds toward zero.
       int n = (int) x_trunc;
-      // OK, at this point, 0 <= min < 2*K.
-      scalar_t y = (x - n) * params_buf[n] + y_vals[n];
-      output[b][c][t] = y;
+      // OK, at this point, 0 <= min < N.
+      output[b][c][t] = (x - n) * params_buf[n] + y_vals[n];
     }
   }
 }

The rest of this hunk deletes the old patch-based backward code (the docstring and
kernel learned_nonlin_kernel_backward, which took a 4-D input of shape (N, 2*C, H, W)
plus pos_add/pos_mul kernels of shape (C, kH, kW), worked on padded patches of size
ppatchH x ppatchW with threads_per_pixel / threads_per_kernel_pos, and produced
grad_input, grad_pos_add and grad_pos_mul), and adds in its place a strided summing
reduction plus a backward kernel matched to the 1-D piecewise-linear op:

/*
  Summing reduction within a one-dimensional thread block, but with a
  stride of N, so that we separately sum up the values of all threads with
  threadIdx.x % N == 0, with threadIdx.x % N == 1, and so on.  At the end,
  threads with 0 <= threadIdx.x < N contain the sums.

  So this is like tiled summing reduction except that the tiles are
  interspersed with each other.

  Args:
      N:    The number we sum modulo (must be a power of 2 with
            1 <= N <= blockDim.x), i.e. all threads with
            threadIdx.x % N == n for some 0 <= n < N have `val` summed.
      buf:  Pointer to the start of a __shared__ buffer of size
            blockDim.x, to be used as a temporary within this function.
      val:  The value to be summed
  Return:
      Threads where threadIdx.x < N will return the sums (over the threads with
      the same value of threadIdx.x % N);
      the return value in other threads is undefined.
 */
template <typename scalar_t>
__forceinline__ __device__ scalar_t strided_reduce_sum(int N,
                                                       __volatile__ scalar_t *buf,
                                                       scalar_t val) {
  // Each iteration halves the number of active threads
  // Each thread adds its partial sum[i] to sum[lane+i]
  for (int i = blockDim.x / 2; i >= N; i /= 2) {
    buf[threadIdx.x] = val;
    __syncthreads();
    if (threadIdx.x < i)
      val += buf[threadIdx.x + i];
  }
  return val;  // Only threads with threadIdx.x < N will return the full sums of
               // their groups.
}

The new backward kernel's docstring documents: input of shape (B, C, T); params of
shape (C, N+1), where params[c][0] is l, a log scale parameter that dictates how far
apart the discontinuities in the piecewise linear function are, params[c][n+1] for
0 <= n < N are the derivatives of the linear parts, and the discontinuities are at
exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]; output_grad and input_grad of
shape (B, C, T); params_grad of shape (gridDim.y, C, N + 1), which the caller sums
over dim 0; the same images_per_thread_block and gridDim requirements as the forward
kernel; a shared `extern_buf` of 2*N + 3 scalar_t values; and the extra requirement
4 <= N <= 16 for this code.  Its signature is:

template <typename scalar_t>
__global__
void learned_nonlin_backward_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3> input,        // B, C, T, i.e. batch, channels, time
    torch::PackedTensorAccessor32<scalar_t, 2> params,       // C, N + 1
    torch::PackedTensorAccessor32<scalar_t, 3> output_grad,  // B, C, T
    torch::PackedTensorAccessor32<scalar_t, 3> input_grad,   // B, C, T
    // params_grad is of dim (gridDim.y, C, N + 1), we'll sum over dim 0.
    torch::PackedTensorAccessor32<scalar_t, 3> params_grad,
    int images_per_thread_block) {

Its body mirrors the forward kernel: it loads params[c][0..N] into the shared
params_buf; two threads in separate warps recompute scale = exp(params_buf[-1]),
inv_scale and the y_vals offsets; then, for each (b, t) a thread handles, it computes
x = input[b][c][t] * inv_scale + K, the clamped piece index n and the residual x - n,
writes input_grad[b][c][t] = output_grad[b][c][t] * params_buf[n], and accumulates
this_params_grad += output_grad * x_residual and this_y_vals_grad += output_grad for
its own n == threadIdx.x % N (using the n_buf / x_residual_buf / output_grad_buf
buffers and a packed src_indexes bitfield, which is why 4 <= N <= 16 is required).
The per-n sums are combined with strided_reduce_sum(); threads 0 and 64 then backprop
through the y_vals recursion to obtain the gradients of the slopes and of the
log-scale parameter; finally threads 0..N write params_grad[blockIdx.y][c][threadIdx.x].

 torch::Tensor learned_nonlin_cuda(torch::Tensor input,
                                   torch::Tensor params) {
@@ -556,9 +568,7 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
   auto scalar_t = input.scalar_type();
   auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
-  // TODO: make this empty
-  torch::Tensor output = torch::ones({B, C, T}, opts);
+  torch::Tensor output = torch::empty({B, C, T}, opts);
   if (C * B * T == 0)
     return output;
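The output tensor allocated above is filled by the forward kernel with the piecewise-linear function described in the docstrings. As a summary, an illustrative PyTorch re-implementation is sketched below; it is inferred from the docstrings and kernel code in this diff (the authoritative reference is learned_nonlin_cpu.cpp, which is not shown here), and the name learned_nonlin_ref and the per-channel looping style are chosen for clarity, not speed:

import torch

def learned_nonlin_ref(x: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
    # x: (B, C, T); params: (C, N + 1) with params[c][0] == l, the log scale,
    # and params[c][1:] the slopes of the N linear pieces.
    B, C, T = x.shape
    N = params.shape[1] - 1
    K = N // 2
    y = torch.empty_like(x)
    for c in range(C):
        l = params[c, 0]
        scale, inv_scale = l.exp(), (-l).exp()
        slopes = params[c, 1:]
        # y_vals[n] is the offset of piece n, chosen so the function is
        # continuous at the breakpoints (cf. the y_vals computation done by
        # two threads, for the positive and negative halves, in the kernel).
        y_vals = torch.zeros(N, dtype=x.dtype)
        for i in range(1, K):
            y_vals[K + i] = y_vals[K + i - 1] + scale * slopes[K + i - 1]
            y_vals[K - i] = y_vals[K - i + 1] - scale * slopes[K - i]
        y_vals[0] = y_vals[1] - scale * slopes[0]
        xs = x[:, c] * inv_scale + K            # breakpoints now at 1, 2, ..., N-1
        n = xs.floor().clamp(min=0, max=N - 1).long()
        y[:, c] = (xs - n) * slopes[n] + y_vals[n]
    return y

With this as a reference, the per-piece slope gradient is output_grad * (xs - n) and the per-piece offset gradient is just output_grad, which is what the backward kernel accumulates into this_params_grad and this_y_vals_grad.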
@@ -592,6 +602,8 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
               images_per_thread_block == 1,
               "Code error");
+  TORCH_CHECK(N + 1 <= THREADS_PER_BLOCK,
+              "Values of N this large are not supported.");
   dim3 gridDim(C, grid_dim_y, 1);
@@ -610,165 +622,89 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,

This hunk rewrites learned_nonlin_backward_cuda().  The removed version took a third
argument named grad_output and still carried a large commented-out block from the
patch-based implementation (4-D input checks, pos_add/pos_mul, patchH/patchW and
threads_per_pixel / threads_per_kernel_pos selection, buffer_numel, the
num_blocks_patch / num_blocks_batch launch configuration, allocation of grad_input /
grad_pos_add / grad_pos_mul, the launch of learned_nonlin_kernel_backward, and the
at::sum() reductions over the block dimension).  The replacement is:

std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
                                                        torch::Tensor params,
                                                        torch::Tensor output_grad) {
  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
  TORCH_CHECK(params.size(1) >= 3 &&
              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
              "params.size(1) has invalid value, must be a power of 2 plus 1.");
  TORCH_CHECK(params.size(0) == input.size(1),
              "params vs input channels mismatch");
  TORCH_CHECK(output_grad.dim() == 3 && output_grad.size(0) == input.size(0) &&
              output_grad.size(1) == input.size(1) &&
              output_grad.size(2) == input.size(2),
              "output_grad and input have mismatched dim.");

  TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
  TORCH_CHECK(output_grad.device().is_cuda(), "output_grad must be a CUDA tensor");
  TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");

  const int B = input.size(0),
      C = input.size(1),
      T = input.size(2),
      N = params.size(1) - 1;

  TORCH_CHECK(N >= 4, "This backward code requires N >= 4");
  TORCH_CHECK(N <= 16, "This backward code currently requires N <= 16");

  auto scalar_t = input.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());

  torch::Tensor input_grad = torch::empty({B, C, T}, opts);
  if (C * B * T == 0) {
    return std::vector<torch::Tensor>({input_grad,
                                       torch::empty({C, N + 1})});
  }

  int images_per_thread_block = 1;
  while (images_per_thread_block * 2 * T <= THREADS_PER_BLOCK &&
         images_per_thread_block * 2 * N <= THREADS_PER_BLOCK)
    images_per_thread_block *= 2;

  int grid_dim_y = 1;
  // If the number of channels is quite small (<128) we can launch more thread
  // groups, splitting on the batch index.
  while (C * grid_dim_y < 128)
    grid_dim_y *= 2;

  // B_reduced is the max number of thread-groups per channel that would have
  // any work to do.  If grid_dim_y is more than this, we reduce it to avoid
  // launching kernels with nothing to do.
  int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
  if (grid_dim_y > B_reduced)
    grid_dim_y = B_reduced;

  int shared_mem_numel = 2 * N + 3;

  if (false)
    std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
              << ", images_per_thread_block = " << images_per_thread_block
              << ", grid_dim_y = " << grid_dim_y << "\n";

  TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= T ||
              images_per_thread_block == 1,
              "Code error");

  torch::Tensor params_grad = torch::empty({grid_dim_y, C, N + 1}, opts);

  dim3 gridDim(C, grid_dim_y, 1);

  // blockDim is scalar, just THREADS_PER_BLOCK.
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_backward_kernel", ([&] {
        learned_nonlin_backward_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK,
                                                   sizeof(scalar_t) * shared_mem_numel,
                                                   at::cuda::getCurrentCUDAStream()>>>(
            input.packed_accessor32<scalar_t, 3>(),
            params.packed_accessor32<scalar_t, 2>(),
            output_grad.packed_accessor32<scalar_t, 3>(),
            input_grad.packed_accessor32<scalar_t, 3>(),
            params_grad.packed_accessor32<scalar_t, 3>(),
            images_per_thread_block);
      }));
  params_grad = at::sum(params_grad, {0});
  return std::vector<torch::Tensor>({input_grad, params_grad});
}
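A short sketch of the host-side gradient reduction done after the kernel launch above (illustrative sizes only, not repository code):

# Each of the grid_dim_y thread groups writes a partial gradient of shape
# (C, N + 1); the host then sums over that leading dimension, which is what
# at::sum(params_grad, {0}) does.
import torch
grid_dim_y, C, N = 4, 2, 8
partial = torch.randn(grid_dim_y, C, N + 1)   # stand-in for what the kernel writes
params_grad = partial.sum(dim=0)              # final (C, N + 1) gradient
assert params_grad.shape == (C, N + 1)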
torch_learned_nonlin/learned_nonlin_test.py  (view file @ 74897fd5)
@@ -63,7 +63,7 @@ def test_learned_nonlin_deriv():
             device = torch.device('cuda:0')
             y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu'))
             print("Checking CUDA is same")
-            if not torch.allclose(y, y2, atol=1.0e-06):
+            if not torch.allclose(y, y2, atol=1.0e-05):
                 print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}, max-diff = {(y2-y).abs().max()}")
                 assert(0)