Commit 74897fd5 authored by Daniel Povey's avatar Daniel Povey
Browse files

Test sometimes failing, think it's an older problem.

parent 06e369c9
...@@ -22,9 +22,9 @@ ...@@ -22,9 +22,9 @@
blockDim.x, to be used as a temporary within this function. blockDim.x, to be used as a temporary within this function.
val: The value to be summed val: The value to be summed
Return: Return:
Threads where blockDim.x % threads_per_tile == 0 will return the sum: Threads where threadIdx.x % threads_per_tile == 0 will return the sum:
\sum_{i=0}^{threads_per_tile-1} [val in thread threadIdx.x + i] \sum_{i=0}^{threads_per_tile-1} [val in thread threadIdx.x + i]
Return value in other threads is undefined. The return value in other threads is undefined.
*/ */
template <typename scalar_t> template <typename scalar_t>
__forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile, __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
...@@ -43,8 +43,9 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile, ...@@ -43,8 +43,9 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
/* /*
Forward of learned_nonlin. Each thread group handles a single channel (equal Forward of learned_nonlin. Each thread group handles a single channel (channel
to blockIdx.x); the gridDim is (C, nb) where 1 <= nb <= B (nb relates to the batch). c = blockIdx.x); the gridDim is (C, nb, 1) where 1 <= nb <= B (nb relates to the
image within the batch).
Template args: Template args:
scalar_t: the floating-point type, e.g. float, double, maybe half. scalar_t: the floating-point type, e.g. float, double, maybe half.
...@@ -71,7 +72,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile, ...@@ -71,7 +72,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
.. this is used for a small optimization. .. this is used for a small optimization.
This kernel is allocated with `extern_buf` containing enough memory This kernel is allocated with `extern_buf` containing enough memory
to store 2*N values of type scalar_t. to store 2*N + 3 values of type scalar_t.
The blockDim must equal (THREADS_PER_BLOCK, 1, 1) The blockDim must equal (THREADS_PER_BLOCK, 1, 1)
...@@ -80,9 +81,10 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile, ...@@ -80,9 +81,10 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
1 <= gridDim.y <= B, where B is the number of blocks 1 <= gridDim.y <= B, where B is the number of blocks
gridDim.z == 1 gridDim.z == 1
When we invoke this kernel, we'll invoke it as: When we invoke this kernel, we'll invoke it as:
learned_nonlin_forward<<<gridDim, blockDim, bytesShared, stream>>> learned_nonlin_kernel<<<gridDim, blockDim, bytesShared, stream>>>
where bytesShared is the number of bytes needed in `extern_buf`: where bytesShared is the number of bytes needed in `extern_buf`:
bytesShared = sizeof(shared_t) * (2N + 3) bytesShared = sizeof(shared_t) * (2N + 3)
We also require N + 1 <= THREADS_PER_BLOCK.
*/ */
extern __shared__ int extern_buf[]; extern __shared__ int extern_buf[];
...@@ -98,31 +100,33 @@ void learned_nonlin_kernel( ...@@ -98,31 +100,33 @@ void learned_nonlin_kernel(
C = input.size(1), C = input.size(1),
T = input.size(2), T = input.size(2),
N = params.size(1) - 1, N = params.size(1) - 1,
K = N / 2; // Note: N and K are powers fo 2, with K >= 1. K = N / 2; // Note: N and K are powers of 2, with K >= 1.
const int c = blockIdx.x; // c is channel index const int c = blockIdx.x; // c is channel index
scalar_t *y_vals = (scalar_t*) extern_buf, // [N], actually there are two scalar_t *y_vals = (scalar_t*) extern_buf, // [N], actually there are 3
// spaces between here and // spaces between here and
// `params_buf` for storing scale // `params_buf` for storing scale
// and inv_scale. // and inv_scale and l == params[c][0].
*params_buf = (scalar_t*) y_vals + 3 + N; // [N]. Caution: contains params[c][1] through params[c][N]. *params_buf = (scalar_t*) y_vals + 3 + N; // [N]. Caution: contains params[c][1] through params[c][N].
// params_buf[-1] contains params[c][0] == log of scale; // params_buf[-1] contains params[c][0] == log of scale;
// params_buf[-2] and params_buf[-3] contain scale and inv_scale. // params_buf[-2] and params_buf[-3] contain scale and inv_scale.
// Load parameters // Load parameters
for (int n = threadIdx.x; n <= N; n += THREADS_PER_BLOCK) { if (threadIdx.x <= N)
params_buf[n - 1] = params[c][n]; params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
}
__syncthreads(); __syncthreads();
// The easiest way to understand this code is to compare it with the CPU code // The easiest way to understand this code is to compare it with the CPU code
// in learned_nonlin_cpu.cpp. // in learned_nonlin_cpu.cpp.
if ((((int)threadIdx.x & ~(int)32)) == 0) { // TODO: replace this with easier-to-understand code.
// threadIdx.x == 0 or 32. These are in separate warps so we can if ((((int)threadIdx.x & ~(int)64)) == 0) {
// allow them to do separate jobs. This code takes linear time in K which // threadIdx.x == 0 or 64 (we choose 64 because it's >= the max known warp
// is not at all ideal and could be improved if K is largish, but it shouldn't // size). These are in separate warps so we can allow them to do separate
// dominate the total time taken if we are processing a lot of data; // jobs. This code takes linear time in K which is not at all ideal and
// and anyway, we doubt that K will be need to be more than 4 or 8 or so, // could be improved if K is largish, but it shouldn't dominate the total
// so the potential savings are quite small. // time taken if we are processing a lot of data; and anyway, we doubt that
// K will be need to be more than 4 or 8 or so, so the potential savings are
// quite small.
scalar_t scale = exp(params_buf[-1]), scalar_t scale = exp(params_buf[-1]),
inv_scale = 1.0 / scale; inv_scale = 1.0 / scale;
params_buf[-2] = scale; // both threads write these but it's OK, it's the params_buf[-2] = scale; // both threads write these but it's OK, it's the
...@@ -137,7 +141,7 @@ void learned_nonlin_kernel( ...@@ -137,7 +141,7 @@ void learned_nonlin_kernel(
if (threadIdx.x == 0) { // sum_positive if (threadIdx.x == 0) { // sum_positive
sign = 1; sign = 1;
Koffset = K; Koffset = K;
} else { // threadIdx.x == 32. sum_negative. } else { // threadIdx.x == 64. sum_negative.
scale *= -1; // this is a local variable.. scale *= -1; // this is a local variable..
sign = -1; sign = -1;
Koffset = K - 1; Koffset = K - 1;
...@@ -155,11 +159,11 @@ void learned_nonlin_kernel( ...@@ -155,11 +159,11 @@ void learned_nonlin_kernel(
scalar_t inv_scale = params_buf[-3]; scalar_t inv_scale = params_buf[-3];
int T_inc = THREADS_PER_BLOCK / images_per_thread_block, int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
image_offset = threadIdx.x / T_inc, b_offset = threadIdx.x / T_inc, // offset within batch
t_start = threadIdx.x % T_inc; t_start = threadIdx.x % T_inc;
for (int b = blockIdx.y * images_per_thread_block + image_offset; for (int b = blockIdx.y * images_per_thread_block + b_offset; b < B;
b < B; b += gridDim.y * images_per_thread_block) { b += gridDim.y * images_per_thread_block) {
// We do "t += THREADS_PER_BLOCK" instead of t += (THREADS_PER_BLOCK / // We do "t += THREADS_PER_BLOCK" instead of t += (THREADS_PER_BLOCK /
// images_per_thread_block) as a small optimization because the only case we // images_per_thread_block) as a small optimization because the only case we
// really need to loop is when images_per_thread_block == 1:a we only let // really need to loop is when images_per_thread_block == 1:a we only let
...@@ -172,367 +176,375 @@ void learned_nonlin_kernel( ...@@ -172,367 +176,375 @@ void learned_nonlin_kernel(
else if (x_trunc >= N) x_trunc = N - 1; else if (x_trunc >= N) x_trunc = N - 1;
// C++ rounds toward zero. // C++ rounds toward zero.
int n = (int) x_trunc; int n = (int) x_trunc;
// OK, at this point, 0 <= min < 2*K. // OK, at this point, 0 <= min < N.
scalar_t y = (x - n) * params_buf[n] + y_vals[n]; output[b][c][t] = (x - n) * params_buf[n] + y_vals[n];
output[b][c][t] = y;
} }
} }
} }
/* /*
Backward of learned_nonlin. Each thread group handles a single channel (equal Summing reduction within a one-dimensional thread block, but with a
to blockIdx.x), and loops over patches of the output and over the image n stride of N, so that we separately sum up the values of all threads with
within the batch (different thread groups may be responsible for different threadIdx.x % N == 0, with threadIdx.x % N == 1, and so on. At the end,
subsets of patches and/or images, see docs of gridDim below). threads with 0 <= threadIdx.x < N contain the sums.
If you want to understand this code, you should first understand the forward So this is like tiled summing reduction except that the tiles are
code. Here are some points to understand how this works: interspersed with each other.
First, understand the difference between the patch of size patchH by
patchW, which is the basic patch size that is related to the blockDim.x, Args:
and the padded patch size ppatchH and ppatchW, where: N: The number we sum modulo (must be a power of 2 with
ppatchH = patchH + kH - 1 1 <= N <= blockDim.x), i.e. all threads with
ppatchW = patchW + kW - 1. threadIdx.x % N == n for some 0 <= n < N have `val` summed.
buf: Pointer to the start of a __shared__ buffer of size
In the forward pass, we dealt with a patch of output and a padded patch of blockDim.x, to be used as a temporary within this function.
input. In this backward-pass code, when computing the `grad_input` we deal val: The value to be summed
with a patch of input and a padded patch of output (this ensures that Return:
different thread-blocks write to distinct patches of `grad_input`). But this Threads where threadIdx.x < N will return the sums (over the threads with
approach is not sufficient to update `grad_pos_add` and `grad_pos_mul`, the same value of threadIdx.x % N);
because it's possible for elements of the zero-padding of `input` to the return value in other threads is undefined.
contribute to `grad_pos_add` and `grad_pos_mul`. So when computing the */
gradients for those quantities, we actually use a padded input patch and an template <typename scalar_t>
un-padded output patch. This requires that we load into shared memory the __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
padded versions of both input and grad_output. __volatile__ scalar_t *buf,
scalar_t val) {
// Each iteration halves the number of active threads
// Each thread adds its partial sum[i] to sum[lane+i]
for (int i = blockDim.x / 2; i >= N; i /= 2) {
buf[threadIdx.x] = val;
__syncthreads();
if (threadIdx.x < i)
val += buf[threadIdx.x + i];
}
return val; // Only threads with threadIdx.x < N will return the full sums of
// their groups.
}
/*
Backward of learned_nonlin. Each thread group handles a single channel (channel
c = blockIdx.x); the gridDim is (C, nb, 1) where 1 <= nb <= B (nb relates to the
image within the batch).
Template args: Template args:
scalar_t: the floating-point type, e.g. float, double, maybe half. scalar_t: the floating-point type, e.g. float, double, maybe half.
Args: Args:
input [in]: input image, shape (N, 2*C, H, W) input: input image, shape (B, C, T) where B is batch size, C is
pos_add [in]: positional encoding, additive part, shape (C, kH, kW) the number of channels and T is the time axis. (For more-than-1d
pos_mul [in]: positional encoding, multiplicative part, shape (C, kH, kW) convolution setups, T would really be more than 1 axis, reshaped).
grad_output [in]: the gradient w.r.t. the output of the forward pass, shape (N, C, H, W) params: of shape (C, N+1) where N is the number of linear regions in the
grad_input [out]: the gradient w.r.t. the input, of shape N, 2*C, H, W piecewise linear function; params[c][0] is l which is
grad_pos_add [out]: the gradient w.r.t. pos_add, indexed [block][c][kh][kw], a log scale parameter that dictates how far apart
of shape num_blocks, C, kH, kW, the discontinuities in the piecewise linear function are,
where `block` is an index we'll later sum over, that corresponds to and params[c][n+1] for 0 <= n < N are the derivatives
the identity of the thread-block (except, not including the channel of the linear parts of the piecewise linear function.
dimension == gridDim.x). So, block == blockIdx.z * gridDim.y + blockIdx.y, The discontinuities of the function are at:
and num_blocks == gridDim.y * gridDim.z. exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
grad_pos_mul [out]: the gradient w.r.t. pos_mul, like grad_pos_add above. output: The transformed input, shape (B , C, T)
patchH: the height of the patch size this kernel operates on (prior to padding) images_per_thread_block: The number of images processed by each thread
patchW: the width of the patch size this kernel operates on (prior to padding) block. The calling code must guarantee that this is a power
threads_per_pixel: the number of threads assigned to compute each pixel of 2, and that EITHER:
of grad_input. Require patchH * patchW * threads_per_pixel <= blockDim.x (THREADS_PER_BLOCK / images_per_thread_block >= T AND
and threads_per_pixel must be a power of 2 in the interval [1,32]. THREADS_PER_BLOCK / images_per_thread_block >= N),
threads_per_kernel_pos: the number of threads assigned to compute each kernel OR
position of grad_pos_add and grad_pos_mul. images_per_thread_block == 1
Require kH * kW * threads_per_kernel_pos <= blockDim.x, .. this is used for a small optimization.
and threads_per_kernel_pos must be a power of 2 in the interval [1,32].
This requires that kH * kW must not be greater than 1024. ALSO,
Note: kH and kW must both be odd so that it's clear how to zero-pad. This kernel is allocated with `extern_buf` containing enough memory
to store 2*N + 3 values of type scalar_t.
The thread-block should have one dimension (x); see docs for threads_per_pixel
and threads_per_kernel_pos for requirements on blockDim.x. Also, blockDim.x The blockDim must equal (THREADS_PER_BLOCK, 1, 1)
must be an exact multiple of 64, so we can divide the threads by 2 and they
will be in different warps.
The requirements on the grid dimension are: The requirements on the grid dimension are:
gridDim.x == num-channels C (required) gridDim.x == num-channels C (required)
gridDim.y <= num-patches per image (recommended) 1 <= gridDim.y <= B, where B is the number of blocks
gridDim.z <= batch-size N (recommended) gridDim.z == 1
When we invoke this kernel, we'll invoke it as: When we invoke this kernel, we'll invoke it as:
learned_nonlin_forward<<<gridDim, blockDim, bytesShared, stream>>> learned_nonlin_backward_kernel<<<gridDim, blockDim, bytesShared, stream>>>
where bytesShared is the number of bytes needed in `extern_buf`: where bytesShared is the number of bytes needed in `extern_buf`:
bytesShared = sizeof(shared_t) * (2N + 3)
We also require that N <= THREADS_PER_BLOCK (for best performance,
N should be quite small, like no larger than 8 or so).
We also require 4 <= N <= 16 for this code!
bytesShared = sizeof(shared_t) * numel, where
numel = 4 * (kH * kW) + 3 * (ppatchH * ppatchW) + blockDim.x
*/ */
template <typename scalar_t> template <typename scalar_t>
__global__ __global__
void learned_nonlin_kernel_backward( void learned_nonlin_backward_kernel(
torch::PackedTensorAccessor32<scalar_t, 4> input, // N, 2*C, H, W torch::PackedTensorAccessor32<scalar_t, 3> input, // B, C, T, i.e. batch, channels, time
torch::PackedTensorAccessor32<scalar_t, 3> pos_add, // C, kH, kW torch::PackedTensorAccessor32<scalar_t, 2> params, // C, N + 1
torch::PackedTensorAccessor32<scalar_t, 3> pos_mul, // C, kH, kW torch::PackedTensorAccessor32<scalar_t, 3> output_grad, // B, C, T
torch::PackedTensorAccessor32<scalar_t, 4> grad_output, // N, C, H, W torch::PackedTensorAccessor32<scalar_t, 3> input_grad, // B, C, T
torch::PackedTensorAccessor32<scalar_t, 4> grad_input, // N, 2*C, H, W // params_grad is of dim (gridDim.y, C, N + 1), we'll sum over dim 0.
torch::PackedTensorAccessor32<scalar_t, 4> grad_pos_add, // block, C, kH, kW, see above for `block` torch::PackedTensorAccessor32<scalar_t, 3> params_grad,
torch::PackedTensorAccessor32<scalar_t, 4> grad_pos_mul, // block, C, kH, kW, see above for `block` int images_per_thread_block) { // B, C, T
int patchH, // non-padded patch height
int patchW, // non-padded patch width const int B = input.size(0),
int threads_per_pixel, C = input.size(1),
int threads_per_kernel_pos) { T = input.size(2),
N = params.size(1) - 1,
const int H = input.size(2), K = N / 2; // Note: N and K are powers fo 2, with K >= 1.
W = input.size(3),
kH = pos_add.size(1), const int c = blockIdx.x; // c is channel index
kW = pos_add.size(2),
npatchH = (H + patchH - 1) / patchH, // num patches in vertical dim scalar_t *y_vals = (scalar_t*) extern_buf, // [N], actually there are three
npatchW = (W + patchW - 1) / patchW, // num patches in horizontal dim // spaces between here and
npatch = npatchH * npatchW; // total number of patches per image // `params_buf` for storing scale
// and inv_scale and l == params[c][0].
// Channel index. *params_buf = (scalar_t*) y_vals + 3 + N; // [N]. Caution: contains params[c][1] through params[c][N].
const int c = blockIdx.x; // params_buf[-1] contains params[c][0] == log of scale;
// We don't need to check the range of `c` because we set gridDim.x to the // params_buf[-2] and params_buf[-3] contain scale and inv_scale.
// exact number of channels.
scalar_t x_residual_buf[THREADS_PER_BLOCK]; // x_residual, with 0 <=
const int ppatchH = patchH + kH - 1, // ppatchH is the padded patch height. // x_residual < 1 for interior
ppatchW = patchW + kW - 1, // ppatchW is the padded patch width // regions, is the residual part
patch_size = patchH * patchW, // un-padded patch size // of the scaled input, after
ppatch_size = ppatchH * ppatchW; // padded patch size // subtracting the integer part.
scalar_t output_grad_buf[THREADS_PER_BLOCK];
// `extern_buf` is general-purpose shared memory, which we'll divide between char n_buf[THREADS_PER_BLOCK]; // for each input in `input_buf`, this stores
// various buffers. // the integer value 0 <= n < N which
// determines which piece of the piecewise
// these are pointers to __shared__ memory; the compiler should // linear function we are in.
// be able to figure this out.
scalar_t // this_params_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
*pos_add_buf = (scalar_t*)extern_buf, // pos_add positional-encoding / kernel parameters, // linear interval) corresponding to n == threadIdx.x % N. For example, if
// indexed [kh*kW + kw] where kh and kw are vertical // threadIdx.x == 0, this thread's gradient corresponds to the left-most
// and horizontal positions in the kernel. // linear interval.
*pos_mul_buf = pos_add_buf + (kH * kW), // pos_mul positional-encoding / kernel parameters, scalar_t this_params_grad = 0.0,
// indexed [kh*kW + kw] where kh and kw are vertical this_y_vals_grad = 0.0;
// and horizontal positions in the kernel.
*src_img_buf = pos_mul_buf + (kH * kW), // version of input image that relates to source position, // Load parameters
// of size [ppatch_size], indexed [h*ppatchW + w], if (threadIdx.x <= N)
// where the 'h' and 'w' indexes are into the zero-padded input params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
// image.
*dest_img_buf = src_img_buf + ppatch_size, // version of input image that relates to destinatioon position
*grad_output_buf = dest_img_buf + ppatch_size, // output gradient for padded patch, indexed [h*ppatchW + w]
*grad_pos_add_buf = grad_output_buf + ppatch_size, // total grad for pos_add for this thread block, indexed [kh*kW + kw]
*grad_pos_mul_buf = grad_pos_add_buf + (kH * kW), // total grad for pos_mul for this thread block, indexed [kh*kW + kw]
*reduce_buf = grad_pos_mul_buf + (kH * kW); // buffer for reduction over threads, size == blockDim.x
// pos_in_patch will be interpreted as h_in_patch * patchW + w_in_patch.
int pos_in_patch = threadIdx.x / threads_per_pixel;
// Load parts of the kernel parameters pos_add and pos_mul into shared memory,
// in pos_add_buf and pos_mul_buf; zero the corresponding gradient buffers.
// We know that blockDim.x >= kH * kW, see threads_per_kernel_pos.
for (int i = threadIdx.x % (blockDim.x / 2); i < kH * kW; i += (blockDim.x / 2)) {
int kh = i / kW, kw = i % kW;
if (threadIdx.x < blockDim.x / 2) { // First half of threads take care of pos_add..
pos_add_buf[i] = pos_add[c][kh][kw];
grad_pos_add_buf[i] = 0.0;
} else { // Second half take care of pos_mul... there is no warp divergence
// because we make sure blockDim.x is a multiple of 64.
pos_mul_buf[i] = pos_mul[c][kh][kw];
grad_pos_mul_buf[i] = 0.0;
}
}
// n is the index within the batch of images. Loop to make sure we cover all
// images in the batch. input.size(0) is the batch size N. All threads in
// the thread-block loop the same number of times.
for (int n = blockIdx.z; n < input.size(0); n += gridDim.z) {
// Loop over the patch within the output image. All threads in the
// thread-block loop the same number of times.
for (int patch_idx = blockIdx.y; patch_idx < npatch; patch_idx += gridDim.y) {
// (patch_h_offset, patch_w_offset) are the (vertical, horizontal) indexes
// of the lowest-numbered pixel in the *un-padded* patch that this thread
// block is responsible for. (We'll actualy be loading the padded patches
// into memory, so be careful).
int patch_h_offset = (patch_idx / npatchW) * patchH,
patch_w_offset = (patch_idx % npatchW) * patchW;
// This __syncthreads() is only necessary if we have already looped at
// least once over n or patch_idx: it's in case other threads are still
// using the `src_img_buf` or `dst_img_buf` buffers for a previous patch.
__syncthreads(); __syncthreads();
// The easiest way to understand this code is to compare it with the CPU code
// in learned_nonlin_cpu.cpp.
// Load the 'src' and 'dest' versions of the padded patch into // This next block computes `y_vals`.
// shared-memory buffers, and also the output gradient. if ((((int)threadIdx.x & ~(int)32)) == 0) {
for (int i = threadIdx.x % (blockDim.x / 2); // threadIdx.x == 0 or 32. These are in separate warps so we can
i < ppatch_size; i += (blockDim.x / 2)) { // allow them to do separate jobs. This code takes linear time in K which
int h_in_ppatch = i / ppatchW, // is not at all ideal and could be improved if K is largish, but it shouldn't
w_in_ppatch = i % ppatchW; // dominate the total time taken if we are processing a lot of data;
int h = patch_h_offset + h_in_ppatch - (kH / 2), // kH / 2 is offset due to padding // and anyway, we doubt that K will be need to be more than 4 or 8 or so,
w = patch_w_offset + w_in_ppatch - (kW / 2); // so the potential savings are quite small.
scalar_t scale = exp(params_buf[-1]),
if (threadIdx.x < blockDim.x / 2) { // The first half of the threads of the block inv_scale = 1.0 / scale;
// load `input` params_buf[-2] = scale; // both threads write these but it's OK, it's the
scalar_t src_val = scalar_t(0), // same value.
dest_val = scalar_t(0); params_buf[-3] = inv_scale;
if ((unsigned int)h < (unsigned int)H && // h >= 0 && h < H int sign,
(unsigned int)w < (unsigned int)W) { // w >= 0 && w < W Koffset; // Koffset == K for threads handling sum_positive and K - 1
int C = grad_output.size(1); // for threads handling sum_negative, see
src_val = input[n][c][h][w]; // learned_nonlin_cpu.cpp for reference code. This would be K
dest_val = input[n][c + C][h][w]; // + 1 and K respectively, except our params_buf has its index
// shifted by one versus params.
if (threadIdx.x == 0) { // sum_positive
sign = 1;
Koffset = K;
} else { // threadIdx.x == 32. sum_negative.
scale *= -1; // this is a local variable..
sign = -1;
Koffset = K - 1;
} }
src_img_buf[i] = src_val; scalar_t sum = 0.0;
dest_img_buf[i] = dest_val; for (int i = 0; i < K; i++) {
} else { // second half of threads load `grad_output`. We require int isign = i * sign;
// blockDim.x be an even multiple of the warp size, so there y_vals[K + isign] = sum * scale;
// is no warp divergence here. sum += params_buf[Koffset + isign];
scalar_t grad_output_val = scalar_t(0);
if ((unsigned int)h < (unsigned int)H &&
(unsigned int)w < (unsigned int)W)
grad_output_val = grad_output[n][c][h][w];
grad_output_buf[i] = grad_output_val;
} }
if (threadIdx.x != 0) // sum_negative
y_vals[0] = sum * scale;
} }
// make sure all threads haave written to `src_img_buf`, `dest_img_buf` and
// `grad_output_buf`.
__syncthreads(); __syncthreads();
scalar_t inv_scale = params_buf[-3];
scalar_t grad_input_src_sum = 0.0, // grad for channel c, for our pixel int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
// of `input` (contribution of this b_offset = threadIdx.x / T_inc; // offset within batch
// thread)
grad_input_dest_sum = 0.0; // grad for channel c + C, for our pixel for (int b = blockIdx.y * images_per_thread_block + b_offset; b < B;
// of `input` (contribution of this thread) b += gridDim.y * images_per_thread_block) {
if (pos_in_patch < patch_size) {
// This block computes `grad_input_src_sum` and `grad_input_dest_sum` // The following will loop just once if images_per_thread_block > 1. If
// The num-threads for the backward kernel may not be an exact multiple // images_per_thread_block == 1 and T > THREADS_PER_BLOCK, we will loop
// of patch_size, wo we need the if-guard. // multiple times. We want to keep all threads active so that output_grad
// will be set to zero for excess threads, and thus won't contribute to
int h_in_patch = pos_in_patch / patchW, // this_params_grad or this_y_vals_grad.
w_in_patch = pos_in_patch % patchW, for (int t_offset = 0; t_offset < T; t_offset += THREADS_PER_BLOCK) {
h_in_ppatch = h_in_patch + kH / 2, int t = threadIdx.x % T_inc + t_offset;
w_in_ppatch = w_in_patch + kW / 2, scalar_t this_output_grad = 0.0, x = 0.0;
pos_in_ppatch = h_in_ppatch * ppatchW + w_in_ppatch; if (t < T)
this_output_grad = output_grad[b][c][t];
// this_dest_val is the `destination` version of our current pixel; this
// is an input. It gets added to each src pixel, prior to the relu, in // The reason we use t % T here rather than only invoking this in some
// the loop below. // threads, is so that the un-needed threads will have a similar
// this_src_val is the `src` version of our current pixel; it contributes // distribution over 'n' to the needed threads, which will hopefully avoid
// to the outputs of other pixels. // excessive work for some particular 'n' value if too many x values had
scalar_t this_dest_val = dest_img_buf[pos_in_ppatch], // the same 'n'. It might be better to set n to an invalid value for
this_src_val = src_img_buf[pos_in_ppatch]; // out-of-range threads, but as it is, if we are to properly handle
// N==16 we don't have enough bits available in `src_indexes` to do this.
for (int pos_in_kernel = threadIdx.x % threads_per_pixel; x = input[b][c][t % T] * inv_scale + K;
pos_in_kernel < (kH * kW);
pos_in_kernel += threads_per_pixel) { output_grad_buf[threadIdx.x] = this_output_grad;
int h_in_kernel = pos_in_kernel / kW, scalar_t x_trunc = x;
w_in_kernel = pos_in_kernel % kW; if (x_trunc < 0) x_trunc = 0;
else if (x_trunc >= N) x_trunc = N - 1;
// This is actually more like cross-correlation, as we don't have a // C++ rounds toward zero.
// negative sign on the h and w indexes in the kernel. int n = (int)x_trunc;
int src_h_in_ppatch = h_in_patch + h_in_kernel, n_buf[threadIdx.x] = (char)n;
src_w_in_ppatch = w_in_patch + w_in_kernel;
int src_pos_in_ppatch = src_h_in_ppatch * ppatchW + src_w_in_ppatch; scalar_t x_residual = x - n;
x_residual_buf[threadIdx.x] = x_residual;
scalar_t src_val = src_img_buf[src_pos_in_ppatch],
pos_add_val = pos_add_buf[pos_in_kernel], // OK, at this point, 0 <= min < N.
pos_mul_val = pos_mul_buf[pos_in_kernel]; // The forward code did:
scalar_t relu = (src_val + this_dest_val + pos_add_val); // output[b][c][t] = (x - n) * params_buf[n] + y_vals[n];
if (relu >= 0.0) {
scalar_t this_grad_output = grad_output_buf[pos_in_ppatch]; if (t < T)
grad_input_dest_sum += this_grad_output * pos_mul_val; input_grad[b][c][t] = this_output_grad * params_buf[n];
int this_block_start = threadIdx.x & ~(N-1), // == N * (threadIdx.x / N),
this_n = threadIdx.x & (N-1); // == threadIdx.x % N.
// this_n is the n value that this thread accumulates gradients for;
// it is responsible for output_grads in the block of threads
// from this_block_start to this_block_start+N-1.
// SYNC POINT At this point there is an implicit within-warp
// synchronization (Note: implicit warp synchronization is considered not
// future-proof). Threads above have written to n_buf, and threads below
// will read from it; but we don't need to explicitly synchronize for now
// because the reads/writes are among threads in a group of N threads with
// (4 <= N <= 16); and 16 is less than the warp size which is 32 or 64.
// src_indexes will contain up to 16 16-bit numbers, stored starting in its
// least significant bits. It will store all the offsets within this
// block of N, where the 'n' value equals this_n.
uint64_t src_indexes = 0;
// num_src is the number of numbers in `src_indexes`. We need to store a
// separate counter because zero is a valid index and if we are to support
// N == 16 we don't have bits to spare in src_indexes to store some kind
// of marker.
int num_src = 0;
// This loop always does N statements, but they should be relatively fast
// ones since the computation per n value is minimal and there is little
// I/O. We are figuring out the subset of our block of N elements,
// which this particular thread value is responsible for (because they
// have n == this_n), and storing them in `src_indexes` and `num_src`.
for (int i = 0; i < N; i += 4) {
uint32_t n_block_of_4 = *reinterpret_cast<uint32_t*>(n_buf + this_block_start + i);
#pragma unroll
for (int j = 0; j < 4; ++j) {
// CUDA is little endian
char n = (char)(n_block_of_4 >> (8*j));
if (n == this_n) {
// We require that N <= 16, so 4 bits is enough to store src_idx.
src_indexes = (src_indexes << 4) | (i + j);
++num_src;
} }
// To compute a contribution to "this_input_src_grad", we need to // Note: if, for out-of-range threads, we had values not in [0..N-1] in
// consider the contribution to the destination pixel that it would // n_buf they won't end up mattering even though they are read here,
// have contributed to with this same offset. // because they won't equal this_n. For values 0 <= n < N originating
// We have to flip the offsets: instead of "+ h_in_kernel", // in out-of-range threads, the value won't matter because the
// we use (kH - 1) - h_in_kernel,. // corresponding value in output_grad_buf will be zero.
int dest_h_in_ppatch = h_in_patch + (kH - 1) - h_in_kernel,
dest_w_in_ppatch = w_in_patch + (kW - 1) - w_in_kernel,
dest_pos_in_ppatch = dest_h_in_ppatch * ppatchW + dest_w_in_ppatch;
scalar_t dest_val = dest_img_buf[dest_pos_in_ppatch];
relu = dest_val + this_src_val + pos_add_val;
if (relu >= 0.0) {
scalar_t dest_grad_output = grad_output_buf[dest_pos_in_ppatch];
grad_input_src_sum += dest_grad_output * pos_mul_val;
} }
} }
// While num_src could theoretically be as large as N, the hope is that no
// thread in any given warp actually loops that many times. Once all
// threads in the warp are finished looping, we can continue. It is OK
// for different warps to get out of sync here; we could be looping over a
// number of images, and the hope is that different warps will reach the
// end of the outer loop at around the same time because their variations
// in speed will average out.
for (; num_src > 0; --num_src, src_indexes >>= 4) {
int src_idx = src_indexes & 0xF,
src_thread = this_block_start + src_idx;
scalar_t output_grad = output_grad_buf[src_thread],
x_residual = x_residual_buf[src_thread];
// Backprop for: output = x_residual * params_buf[n] + y_vals[n].
// Here, n == this_n; this is how we selected these `src_idx` values.
this_params_grad += output_grad * x_residual;
this_y_vals_grad += output_grad;
} }
// Aggregate `grad_input_src_sum` over threads, if needed; and write the
// result to `grad_input`.
// h and w are un-padded indexes into the entire image.
int h = patch_h_offset + pos_in_patch / patchW,
w = patch_w_offset + pos_in_patch % patchW;
if (h < H && w < W) {
grad_input_src_sum = tiled_warp_reduce_sum(threads_per_pixel,
reduce_buf,
grad_input_src_sum);
grad_input_dest_sum = tiled_warp_reduce_sum(threads_per_pixel,
reduce_buf,
grad_input_dest_sum);
if (threadIdx.x % threads_per_pixel == 0) {
grad_input[n][c][h][w] = grad_input_src_sum;
int C = grad_output.size(1);
grad_input[n][c + C][h][w] = grad_input_dest_sum;
} }
} }
// OK, we are done computing grad_input for this patch. Now __syncthreads(); // sync threads because we are about to re-use
// we need to contribute the contributions to grad_pos_add_buf // output_grad_buf for reduction.
// and grad_pos_mul_buf for this patch.
// 0 <= pos_in_kernel < (kH * kW). this_params_grad = strided_reduce_sum(N, output_grad_buf, this_params_grad);
int pos_in_kernel = threadIdx.x / threads_per_kernel_pos; this_y_vals_grad = strided_reduce_sum(N, output_grad_buf, this_y_vals_grad);
scalar_t this_grad_pos_add = 0.0,
this_grad_pos_mul = 0.0; __syncthreads(); // sync threads because we are about to re-use
if (pos_in_kernel < (kH * kW)) { // output_grad_buf.
int kh = pos_in_kernel / kW,
kw = pos_in_kernel % kW; // Re-use some buffers..
scalar_t *params_grad_buf = x_residual_buf, // [N]
// This group of (threads_per_kernel_pos) threads is responsible *y_vals_grad_buf = output_grad_buf; // [N]
// for position (kh, kw) in the kernel; we iterate over the patch
// (an un-padded patch of output). if (threadIdx.x < N) {
scalar_t pos_add_val = pos_add_buf[pos_in_kernel], // There is an offset of 1 between the 'n' values and
pos_mul_val = pos_mul_buf[pos_in_kernel]; // the position in 'params'. To keep the backprop code similar to the CPU
// backprop code we restore that offset here, i.e. use the same layout
for (int pos_in_patch = threadIdx.x % threads_per_kernel_pos; // as the params.
pos_in_patch < patch_size; pos_in_patch += threads_per_kernel_pos) { params_grad_buf[threadIdx.x + 1] = this_params_grad;
// We are working out the contribution to the gradients for pos_add y_vals_grad_buf[threadIdx.x] = this_y_vals_grad;
// and pos_mul; we let `pos_in_patch` correspond to the *output*
// position, and work out the input position based on gthe kernel position.
int h_in_patch = pos_in_patch / patchW,
w_in_patch = pos_in_patch % patchW;
// pos_in_ppatch is the position in the padded patch corresponding to
// `pos_in_patch`.
int pos_in_ppatch = (h_in_patch + kH / 2) * ppatchW + (w_in_patch + kW / 2);
scalar_t dest_val = dest_img_buf[pos_in_ppatch];
int src_pos_in_ppatch = (h_in_patch + kh) * ppatchW + (w_in_patch + kw);
scalar_t src_val = src_img_buf[src_pos_in_ppatch];
scalar_t relu = dest_val + src_val + pos_add_val;
if (relu >= 0.0) {
scalar_t this_grad_output = grad_output_buf[pos_in_ppatch];
this_grad_pos_add += this_grad_output * pos_mul_val;
this_grad_pos_mul += this_grad_output * relu;
}
} }
this_grad_pos_add = tiled_warp_reduce_sum(
threads_per_kernel_pos, reduce_buf, this_grad_pos_add);
this_grad_pos_mul = tiled_warp_reduce_sum( // This next block does backprop relating to `y_vals`. Comparing with the CPU
threads_per_kernel_pos, reduce_buf, this_grad_pos_mul); // version (call this the "reference code") is the best way to understand this (this code is just a
if (threadIdx.x % threads_per_kernel_pos == 0) { // modification of that).
grad_pos_add_buf[pos_in_kernel] += this_grad_pos_add; {
grad_pos_mul_buf[pos_in_kernel] += this_grad_pos_mul; // Thread 0 is responsible for parts of the reference code that involve "sum_positive_grad";
// thread 64 is responsible for parts of the reference code that involve "sum_negative_grad";
scalar_t scale_grad = 0.0,
scale = params_buf[-2];
if (threadIdx.x == 0) {
scalar_t sum_positive_grad = 0.0;
for (int i = K - 1; i >= 0; i--) {
// This is like the CPU code but with an offset of 1 for 'params_buf'
// versus 'params_a'.
params_grad_buf[1 + K + i] += sum_positive_grad * scale;
scale_grad += sum_positive_grad * params_buf[K + i];
sum_positive_grad += y_vals_grad_buf[K + i];
} }
params_grad_buf[0] += scale * scale_grad;
} else if (threadIdx.x == 64) {
scalar_t sum_negative_grad = y_vals_grad_buf[0];
for (int i = K - 1; i >= 0; i--) {
// This is like the CPU code but with an offset of 1 for 'params_buf'
// versus 'params_a'.
params_grad_buf[K - i] -= sum_negative_grad * scale;
scale_grad -= sum_negative_grad * params_buf[K - 1 - i];
sum_negative_grad += y_vals_grad_buf[K - i];
} }
} }
__syncthreads();
if (threadIdx.x == 64)
params_grad_buf[0] += scale * scale_grad;
__syncthreads();
} }
__syncthreads(); // make sure all threads have written to grad_pos_add_buf and if (threadIdx.x <= N) {
// grad_pos_mul_buf. params_grad[blockIdx.y][c][threadIdx.x] = params_grad_buf[threadIdx.x];
int block = blockIdx.z * gridDim.y + blockIdx.y;
int kernel_pos = threadIdx.x;
if (kernel_pos < (kH * kW)) {
int kh = kernel_pos / kW,
kw = kernel_pos % kW;
grad_pos_add[block][c][kh][kw] = grad_pos_add_buf[kernel_pos];
grad_pos_mul[block][c][kh][kw] = grad_pos_mul_buf[kernel_pos];
} }
} }
torch::Tensor learned_nonlin_cuda(torch::Tensor input, torch::Tensor learned_nonlin_cuda(torch::Tensor input,
torch::Tensor params) { torch::Tensor params) {
...@@ -556,9 +568,7 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input, ...@@ -556,9 +568,7 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
auto scalar_t = input.scalar_type(); auto scalar_t = input.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device()); auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
// TODO: make this empty torch::Tensor output = torch::empty({B, C, T}, opts);
torch::Tensor output = torch::ones({B, C, T}, opts);
if (C * B * T == 0) if (C * B * T == 0)
return output; return output;
...@@ -592,6 +602,8 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input, ...@@ -592,6 +602,8 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
images_per_thread_block == 1, images_per_thread_block == 1,
"Code error"); "Code error");
TORCH_CHECK(N + 1 <= THREADS_PER_BLOCK,
"Values of N this large are not supported.");
dim3 gridDim(C, grid_dim_y, 1); dim3 gridDim(C, grid_dim_y, 1);
...@@ -610,165 +622,89 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input, ...@@ -610,165 +622,89 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input, std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
torch::Tensor params, torch::Tensor params,
torch::Tensor grad_output) { torch::Tensor output_grad) {
TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
TORCH_CHECK(params.size(1) >= 3 &&
((params.size(1) - 1) & (params.size(1) - 2)) == 0,
"params.size(1) has invalid value, must be a power of 2 plus 1.");
TORCH_CHECK(params.size(0) == input.size(1),
"params vs input channels mismatch");
TORCH_CHECK(output_grad.dim() == 3 && output_grad.size(0) == input.size(0) &&
output_grad.size(1) == input.size(1) &&
output_grad.size(2) == input.size(2),
"output_grad and input have mismatched dim.");
/*
TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
TORCH_CHECK(pos_mul.dim() == 3, "pos_add must be 3-dimensional.");
TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor"); TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
const int N = input.size(0), TORCH_CHECK(output_grad.device().is_cuda(), "output_grad must be a CUDA tensor");
C = input.size(1) / 2, TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");
H = input.size(2),
W = input.size(3),
kH = pos_add.size(1), const int B = input.size(0),
kW = pos_add.size(2); C = input.size(1),
TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1); T = input.size(2),
TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels"); N = params.size(1) - 1;
TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
pos_mul.size(1) == kH && pos_mul.size(2) == kW, TORCH_CHECK(N >= 4, "This backward code requires N >= 4");
"Input sizes mismatch."); TORCH_CHECK(N <= 16, "This backward code currently requires N <= 16");
TORCH_CHECK(pos_add.device() == input.device() &&
pos_mul.device() == pos_add.device(),
"Input devices mismatch");
auto scalar_t = input.scalar_type(); auto scalar_t = input.scalar_type();
TORCH_CHECK(pos_add.scalar_type() == scalar_t && auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
pos_mul.scalar_type() == scalar_t,
"Input dtypes mismatch");
TORCH_CHECK(grad_output.dim() == 4 && grad_output.size(0) == N
&& grad_output.size(1) == C && grad_output.size(2) == H
&& grad_output.size(3) == W);
// Work out the configuration to call the kernel with..
int patchH = std::min(H, kH), // output patch height
patchW = std::min(W, kW); // output patch width
// We don't want the height or width of the patch to be less than the kernel
// width, or the padding will make the input-patch size more than twice the
// output-patch size.
// We aim for the output-patch size to be more than 128; this is not something
// very exact, but it roughly corresponds to us wanting to have up to 4 threads
// per output pixel, and the limitation of 512 threads per thread-block which
// we impose so that we can run on architectures with little shared memory.
while (patchW < W && patchH * (patchW + 1) <= 128)
patchW++;
while(patchH < H && (patchH + 1) * patchW <= 128)
patchH++;
// We are assuming that the thread-block size can be as large as 512; this
// works even on old CUDA architectures.
int threads_per_pixel;
if (patchH * patchW * 4 <= 512 && (kH * kW) > 8)
threads_per_pixel = 4;
else if (patchH * patchW * 2 <= 512 && (kH * kW) > 4)
threads_per_pixel = 2;
else
threads_per_pixel = 1;
int threads_per_block = patchH * patchW * threads_per_pixel;
// round threads_per_block up to a multiple of 64. We need it to be
// equivalent to an even number of warps, because at one point we divide the
// threads into two halves and we want them to be an even number of warps.
threads_per_block = 64 * ((threads_per_block + 63) / 64);
{
// If it's possible to increase the patch width or height while not exceeding
// this number of threads, do so. (This is a small optimization).
int patchW_old = patchW;
while (patchH * (patchW + 1) * threads_per_pixel <= threads_per_block)
patchW++;
// If the above change to patchW did not actually reduce the number of patches
// needed to cover the image, gthen there is no point to the change; and it
// increases the shared-memory requirement, so revert it.
if ((W + patchW_old - 1) / patchW_old == (W + patchW - 1) / patchW)
patchW = patchW_old;
int patchH_old = patchH;
while ((patchH + 1) * patchW * threads_per_pixel <= threads_per_block)
patchH++;
if ((H + patchH_old - 1) / patchH_old == (H + patchH - 1) / patchH)
patchH = patchH_old;
}
torch::Tensor input_grad = torch::empty({B, C, T}, opts);
int threads_per_kernel_pos = 1; if (C * B * T == 0) {
while (threads_per_kernel_pos < 32 && return std::vector<torch::Tensor>({input_grad,
threads_per_kernel_pos * 2 * kH * kW <= threads_per_block) torch::empty({C, N + 1})});
threads_per_kernel_pos *= 2;
// dimensions of padded patches
int ppatchH = patchH + kH - 1,
ppatchW = patchW + kW - 1,
ppatch_size = ppatchH * ppatchW;
int buffer_numel = 4 * (kH * kW) + 3 * ppatch_size + threads_per_block;
int num_patches_H = (H + patchH - 1) / patchH,
num_patches_W = (W + patchW - 1) / patchW,
num_patches = num_patches_H * num_patches_W;
// gridDim.x == C.
int num_blocks_patch = 1, // gridDim.y. should not be more
num_blocks_batch = 1; // gridDim.z
// We have a rough target of no more than 256 thread-groups.
while (C * num_blocks_patch * 2 <= 256 &&
num_blocks_patch * 2 <= num_patches)
num_blocks_patch *= 2;
if (C * num_patches <= 512)
num_blocks_patch = num_patches;
while (C * num_blocks_patch * num_blocks_batch * 2 <= 256 &&
num_blocks_batch * 2 <= N)
num_blocks_batch *= 2;
assert(num_blocks_patch <= num_patches && num_blocks_batch <= N);
assert(patchH * patchW * threads_per_pixel <= threads_per_block);
assert(kH * kW * threads_per_kernel_pos <= threads_per_block);
static int debug_count = 50;
if (debug_count > 0) {
debug_count--;
std::cout << "[backward:] N,C,H,W=" << N << "," << C << "," << H << "," << W
<< "; kW,kH=" << kW << "," << kH
<< "; patchH,patchW=" << patchH << ","
<< patchW << ", num_blocks_patch="
<< num_blocks_patch << ", num_blocks_batch="
<< num_blocks_batch
<< ", threads_per_pixel=" << threads_per_pixel
<< ", threads_per_kernel_pos=" << threads_per_kernel_pos
<< ", threads_per_block=" << threads_per_block
<< ", buffer_numel=" << buffer_numel
<< std::endl;
} }
int num_blocks = num_blocks_patch * num_blocks_batch; int images_per_thread_block = 1;
while (images_per_thread_block * 2 * T <= THREADS_PER_BLOCK &&
torch::Tensor grad_input = torch::zeros({N, 2*C, H, W}, images_per_thread_block * 2 * N <= THREADS_PER_BLOCK)
torch::TensorOptions().dtype(scalar_t).device(input.device())), images_per_thread_block *= 2;
grad_pos_add = torch::zeros({num_blocks, C, kH, kW},
torch::TensorOptions().dtype(scalar_t).device(input.device())), int grid_dim_y = 1;
grad_pos_mul = torch::zeros({num_blocks, C, kH, kW}, // If the number of channels is quite small (<128) we can launch more thread
torch::TensorOptions().dtype(scalar_t).device(input.device())); // groups, splitting on the batch index.
while (C * grid_dim_y < 128)
grid_dim_y *= 2;
dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
// blockDim is scalar, just threads_per_block. // B_reduced is the max number of thread-groups per channel that would have
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_kernel_backward", ([&] { // any work to do. If grid_dim_y is more than this, we reduce it to avoid
learned_nonlin_kernel_backward<scalar_t><<<gridDim, threads_per_block, // launching kernels with nothing to do.
sizeof(scalar_t) * buffer_numel, int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
at::cuda::getCurrentCUDAStream()>>>( if (grid_dim_y > B_reduced)
input.packed_accessor32<scalar_t, 4>(), grid_dim_y = B_reduced;
pos_add.packed_accessor32<scalar_t, 3>(),
pos_mul.packed_accessor32<scalar_t, 3>(), int shared_mem_numel = 2 * N + 3;
grad_output.packed_accessor32<scalar_t, 4>(),
grad_input.packed_accessor32<scalar_t, 4>(), if (false)
grad_pos_add.packed_accessor32<scalar_t, 4>(), std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
grad_pos_mul.packed_accessor32<scalar_t, 4>(), << ", images_per_thread_block = " << images_per_thread_block
patchH, << ", grid_dim_y = " << grid_dim_y
patchW, << "\n";
threads_per_pixel,
threads_per_kernel_pos); TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= T ||
images_per_thread_block == 1,
"Code error");
torch::Tensor params_grad = torch::empty({grid_dim_y, C, N + 1}, opts);
dim3 gridDim(C, grid_dim_y, 1);
// blockDim is scalar, just THREADS_PER_BLOCK.
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_backward_kernel", ([&] {
learned_nonlin_backward_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
input.packed_accessor32<scalar_t, 3>(),
params.packed_accessor32<scalar_t, 2>(),
output_grad.packed_accessor32<scalar_t, 3>(),
input_grad.packed_accessor32<scalar_t, 3>(),
params_grad.packed_accessor32<scalar_t, 3>(),
images_per_thread_block);
})); }));
grad_pos_add = at::sum(grad_pos_add, {0});
grad_pos_mul = at::sum(grad_pos_mul, {0});
return std::vector<torch::Tensor>({grad_input, grad_pos_add, grad_pos_mul}); */ params_grad = at::sum(params_grad, {0});
return std::vector<torch::Tensor>({input_grad, params_grad});
} }
...@@ -63,7 +63,7 @@ def test_learned_nonlin_deriv(): ...@@ -63,7 +63,7 @@ def test_learned_nonlin_deriv():
device = torch.device('cuda:0') device = torch.device('cuda:0')
y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu')) y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu'))
print("Checking CUDA is same") print("Checking CUDA is same")
if not torch.allclose(y, y2, atol=1.0e-06): if not torch.allclose(y, y2, atol=1.0e-05):
print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}, max-diff = {(y2-y).abs().max()}") print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}, max-diff = {(y2-y).abs().max()}")
assert(0) assert(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment