Commit c3e61bea authored by Daniel Povey

Nearly-working backprop on CUDA

parent e0bc4029
@@ -260,6 +260,14 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
    We also require that N <= THREADS_PER_BLOCK (for best performance,
    N should be quite small, like no larger than 8 or so).
    We also require 4 <= N <= 16 for this code!
+   And we require that
+       N <= (THREADS_PER_BLOCK / images_per_thread_block)
+   (both sides will be powers of 2); this ensures that blocks of threads
+   summing the N values are always within the same image, which helps
+   avoid a problem where some loops over 'b' would be done earlier
+   than others, and we'd end up counting certain pixels twice as their
+   output_grad would stay nonzero.
 */
 template <typename scalar_t>
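The added requirement is easy to sanity-check on the host: when N, THREADS_PER_BLOCK and images_per_thread_block are all powers of 2 and N <= THREADS_PER_BLOCK / images_per_thread_block, an aligned group of N consecutive threads can never straddle two images. A minimal standalone sketch (the concrete THREADS_PER_BLOCK value and loop ranges are made up for illustration):

```c++
#include <cassert>

int main() {
  const int THREADS_PER_BLOCK = 256;  // assumed value, for illustration only
  for (int images = 1; images <= 16; images *= 2) {
    int threads_per_image = THREADS_PER_BLOCK / images;
    for (int N = 4; N <= 16 && N <= threads_per_image; N *= 2) {
      for (int t = 0; t < THREADS_PER_BLOCK; t++) {
        int group_start = t & ~(N - 1);       // first thread of t's group of N
        int group_end = group_start + N - 1;  // last thread of that group
        // Both ends of the group must belong to the same image.
        assert(group_start / threads_per_image == group_end / threads_per_image);
      }
    }
  }
  return 0;
}
```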
@@ -291,12 +299,13 @@ void learned_nonlin_backward_kernel(
   // params_buf[-1] contains params[c][0] == log of scale;
   // params_buf[-2] and params_buf[-3] contain scale and inv_scale.
-  scalar_t input_buf[THREADS_PER_BLOCK];  // input sequence
-  scalar_t output_grad_buf[THREADS_PER_BLOCK];
-  char n_buf[THREADS_PER_BLOCK];  // for each input in `input_buf`, this stores
-                                  // the integer value 0 <= n < N which
-                                  // determines which piece of the piecewise
-                                  // linear function we are in.
+  __shared__ scalar_t input_buf[THREADS_PER_BLOCK];  // input sequence
+  __shared__ scalar_t output_grad_buf[THREADS_PER_BLOCK];
+  __shared__ char n_buf[THREADS_PER_BLOCK];  // for each input in `input_buf`,
+                                             // this stores the integer value 0
+                                             // <= n < N which determines which
+                                             // piece of the piecewise linear
+                                             // function we are in.
   // Load parameters
   if (threadIdx.x <= N)
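Adding __shared__ here is likely the substantive fix: without it, each thread gets its own private copy of input_buf, output_grad_buf and n_buf in local memory, so the cross-thread reads later in the kernel would see uninitialized data. A toy kernel (hypothetical, not from this repo) showing the difference:

```c++
#include <cstdio>

// Writes through a __shared__ array are visible to the whole block;
// writes to a plain local array are private to each thread.
__global__ void shared_vs_local(int *out) {
  __shared__ int shared_buf[32];  // one copy per block
  int local_buf[32];              // one copy per thread
  shared_buf[threadIdx.x] = threadIdx.x;
  local_buf[threadIdx.x] = threadIdx.x;  // sets only this thread's copy
  __syncthreads();
  if (threadIdx.x == 0) {
    out[0] = shared_buf[5];  // 5: sees thread 5's write
    out[1] = local_buf[5];   // garbage: thread 0 never wrote its own element 5
  }
}

int main() {
  int *out;
  cudaMallocManaged(&out, 2 * sizeof(int));
  shared_vs_local<<<1, 32>>>(out);
  cudaDeviceSynchronize();
  printf("shared: %d, local: %d\n", out[0], out[1]);
  cudaFree(out);
  return 0;
}
```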
@@ -352,7 +361,12 @@ void learned_nonlin_backward_kernel(
   // will be set to zero for excess threads, and thus won't contribute to
   // this_params_grad or this_y_vals_grad.
   for (int t_offset = 0; t_offset < T; t_offset += THREADS_PER_BLOCK) {
-    int t = threadIdx.x % T_inc + t_offset;
+    // The following is equivalent to:
+    //   int t = (threadIdx.x % T_inc) + t_offset;
+    // given that T_inc is a power of 2 and t_offset is a multiple of
+    // THREADS_PER_BLOCK, which is >= T_inc.
+    int t = (threadIdx.x & (T_inc - 1)) | t_offset;
     scalar_t this_output_grad = 0.0;
     if (t < T)
       this_output_grad = output_grad[b][c][t];
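The replacement index computation is a standard bit trick: the low bits (threadIdx.x & (T_inc - 1)) and t_offset occupy disjoint bit ranges, so OR acts as addition. A quick standalone check (the concrete sizes are assumptions):

```c++
#include <cassert>

int main() {
  const int THREADS_PER_BLOCK = 256, T_inc = 32;  // assumed example sizes
  for (int t_offset = 0; t_offset < 4096; t_offset += THREADS_PER_BLOCK)
    for (int x = 0; x < THREADS_PER_BLOCK; x++)
      // t_offset is a multiple of THREADS_PER_BLOCK, so its low
      // log2(T_inc) bits are zero and the OR cannot carry.
      assert(((x & (T_inc - 1)) | t_offset) == (x % T_inc) + t_offset);
  return 0;
}
```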
@@ -373,10 +387,7 @@ void learned_nonlin_backward_kernel(
       else if (x >= N) x = N - 1;
       // C++ rounds toward zero.
       int n = (int)x;
-      n_buf[threadIdx.x] = (char)n;
-      // OK, at this point, 0 <= n < N.
+      n_buf[threadIdx.x] = (char)n;  // 0 <= n < N
       // The forward code did:
       //   output[b][c][t] = this_input * params_buf[n] + y_vals[n];
       // We get the derivative for params and y_vals later.
@@ -384,14 +395,16 @@ void learned_nonlin_backward_kernel(
       input_grad[b][c][t] = this_output_grad * params_buf[n];
       int this_block_start = threadIdx.x & ~(N-1),  // == N * (threadIdx.x / N),
+                                                    // since N is power of 2
           this_n = threadIdx.x & (N-1);  // == threadIdx.x % N.
       // this_n is the n value that this thread accumulates gradients for;
       // it is responsible for output_grads in the block of threads
       // from this_block_start to this_block_start+N-1.
-      // SYNC POINT At this point there is an implicit within-warp
-      // synchronization (Note: implicit warp synchronization is considered not
+      // __syncthreads();  // <- not really needed.
+      // At this point there is an implicit within-warp
+      // synchronization (Note: implicit warp synchronization is not considered
       // future-proof). Threads above have written to n_buf, and threads below
       // will read from it; but we don't need to explicitly synchronize for now
       // because the reads/writes are among threads in a group of N threads with
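The two identities relied on above hold for any power-of-2 N and are easy to verify in isolation:

```c++
#include <cassert>

int main() {
  for (int N = 4; N <= 16; N *= 2)  // the power-of-2 range this kernel allows
    for (int x = 0; x < 256; x++) {
      assert((x & ~(N - 1)) == N * (x / N));  // round x down to a multiple of N
      assert((x & (N - 1)) == x % N);         // remainder of x mod N
    }
  return 0;
}
```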
@@ -399,7 +412,7 @@ void learned_nonlin_backward_kernel(
       // src_indexes will contain up to 16 4-bit numbers, stored starting in its
       // least significant bits.  It will store all the offsets within this
-      // block of N, where the 'n' value equals this_n.
+      // block of N threads, whose chosen 'n' value equals this_n.
       uint64_t src_indexes = 0;
       // num_src is the number of numbers in `src_indexes`.  We need to store a
       // separate counter because zero is a valid index and if we are to support
@@ -407,11 +420,12 @@ void learned_nonlin_backward_kernel(
       // of marker.
       int num_src = 0;
-      // This loop always does N statements, but they should be relatively fast
-      // ones since the computation per n value is minimal and there is little
-      // I/O.  We are figuring out the subset of our block of N elements,
-      // which this particular thread value is responsible for (because they
-      // have n == this_n), and storing them in `src_indexes` and `num_src`.
+      // This loop always does at least N statements, but they should be
+      // relatively fast ones since the computation per n value is minimal and
+      // there is little I/O.  We are figuring out the subset of our block of N
+      // elements, which this particular thread value is responsible for
+      // (because they have n == this_n), and storing them in `src_indexes` and
+      // `num_src`.
       for (int i = 0; i < N; i += 4) {
         uint32_t n_block_of_4 = *reinterpret_cast<uint32_t*>(n_buf + this_block_start + i);
         #pragma unroll
@@ -438,38 +452,45 @@ void learned_nonlin_backward_kernel(
       // number of images, and the hope is that different warps will reach the
       // end of the outer loop at around the same time because their variations
       // in speed will average out.
-      for (; num_src > 0; --num_src, src_indexes >>= 4) {
-        int src_idx = src_indexes & 0xF,
-            src_thread = this_block_start + src_idx;
-        scalar_t output_grad = output_grad_buf[src_thread],
-            this_input = input_buf[src_thread];
-        // Backprop for: output = x_residual * (params_buf[n] * scale) + y_vals[n].
+      for (; num_src > 0; --num_src, (src_indexes >>= 4)) {
+        int src_thread = this_block_start | (src_indexes & 0xF);
+        scalar_t src_output_grad = output_grad_buf[src_thread],
+            src_input = input_buf[src_thread];
+        assert(n_buf[src_thread] == this_n);
+        n_buf[src_thread] = 0;
+        // Backprop for: output = input * params_buf[n] + y_vals[n].
         // Here, n == this_n; this is how we selected these `src_idx` values.
-        this_param_grad += output_grad * this_input;
-        this_y_vals_grad += output_grad;
+        this_param_grad += src_output_grad * src_input;
+        this_y_vals_grad += src_output_grad;
       }
+      // TODO: remove the next lines
+      assert(n_buf[threadIdx.x] == 0);
+      output_grad_buf[threadIdx.x] = 0.0;
     }
   }
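The scheme here packs up to 16 4-bit offsets into a single uint64_t, least-significant nibble first, with num_src kept separately because offset 0 is valid. The packing loop is collapsed in the diff above, so the exact packing order below is an assumption inferred from the consumption loop; a host-side round-trip sketch:

```c++
#include <cassert>
#include <cstdint>

int main() {
  int offsets[] = {0, 3, 7, 15};  // example offsets, each 0 <= offset < 16
  // Pack, starting at the least significant nibble.
  uint64_t src_indexes = 0;
  int num_src = 0;
  for (int o : offsets) {
    src_indexes |= (uint64_t)o << (4 * num_src);
    num_src++;
  }
  // Unpack in the same order, mirroring the kernel's consumption loop.
  int i = 0;
  for (; num_src > 0; --num_src, src_indexes >>= 4)
    assert((int)(src_indexes & 0xF) == offsets[i++]);
  return 0;
}
```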
   __syncthreads();  // sync threads because we are about to re-use
-                    // output_grad_buf for reduction.
+                    // output_grad_buf for reduction, and, later, input_buf.
   this_param_grad = strided_reduce_sum(N, output_grad_buf, this_param_grad);
+  __syncthreads();
   this_y_vals_grad = strided_reduce_sum(N, output_grad_buf, this_y_vals_grad);
   __syncthreads();  // sync threads because we are about to re-use
-                    // output_grad_buf.
+                    // output_grad_buf as y_vals_grad_buf.
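The newly inserted __syncthreads() between the two reductions is needed if strided_reduce_sum uses output_grad_buf as scratch: the second call must not overwrite the buffer while some threads are still reading results of the first. The helper's definition is outside this diff; a plausible sketch of such a reduction (an assumption, not the repo's actual implementation):

```c++
// Sums `val` over all threads whose indices are congruent mod N, using `buf`
// (size blockDim.x) as scratch; assumes blockDim.x and N are powers of 2.
template <typename scalar_t>
__forceinline__ __device__ scalar_t strided_reduce_sum_sketch(
    int N, scalar_t *buf, scalar_t val) {
  buf[threadIdx.x] = val;
  __syncthreads();
  // Each step folds the top half onto the bottom half; the stride is always
  // a multiple of N, so only values with equal (index mod N) are combined.
  for (int stride = blockDim.x / 2; stride >= N; stride /= 2) {
    if (threadIdx.x < stride)
      buf[threadIdx.x] += buf[threadIdx.x + stride];
    __syncthreads();
  }
  // buf[0..N-1] now hold the N strided sums.  The caller must __syncthreads()
  // before writing to `buf` again, which is exactly what the kernel does.
  return buf[threadIdx.x & (N - 1)];
}
```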
   // Re-use some buffers.
   scalar_t *params_grad_buf = input_buf + 1,  // [N] ... but element [-1] will have deriv of scale.
       *y_vals_grad_buf = output_grad_buf;  // [N]
   if (threadIdx.x < N) {
+    // Note the indexing offset of 1 in params_grad_buf (versus params_buf).
     params_grad_buf[threadIdx.x] = this_param_grad;
     y_vals_grad_buf[threadIdx.x] = this_y_vals_grad;
   }
+  __syncthreads();  // other threads are about to read params_grad_buf and
+                    // y_vals_grad_buf.
   // This next block does backprop relating to `y_vals`.  Comparing with the CPU
   // version (call this the "reference code") is the best way to understand this
@@ -479,7 +500,7 @@ void learned_nonlin_backward_kernel(
   // the deriv of the log scale.
   scalar_t l_grad;
-  if (threadIdx.x == 64) {
+  if (threadIdx.x == 0) {
     // Now do the backprop for the loop above where we set y_vals_a.  This could
     // be further optimized to replace the loop with a raking, but I doubt this
     // will have a huge effect on the runtime since K will be fairly small,
@@ -499,9 +520,11 @@ void learned_nonlin_backward_kernel(
       scale_grad += pos_scaled_param_grad * params_buf[K + i];
     }
     // Backprop for: scale = exp(l), where l = params[c][0].
-    params_grad_buf[-1] = scale * scale_grad;
-  } else if (threadIdx.x == 0) {
+    l_grad = scale * scale_grad;
+  } else if (threadIdx.x == 64) {
     // Now do the backprop for the loop above where we set y_vals.
+    // (We made the other branch threadIdx.x == 0 since that's possibly
+    // quicker to test.)
     scalar_t scale = params_buf[-2],
         scale_grad = 0.0,
         sum_negative_grad = 0.0;
@@ -516,14 +539,17 @@ void learned_nonlin_backward_kernel(
       params_grad_buf[K - i - 1] += neg_scaled_param_grad * scale;
       scale_grad += neg_scaled_param_grad * params_buf[K - i - 1];
     }
-    l_grad = scale * scale_grad;
+    params_grad_buf[-1] = scale * scale_grad;
   }
   __syncthreads();
-  if (threadIdx.x == 0)
+  if (threadIdx.x == 0) {
     params_grad_buf[-1] += l_grad;  // add the other branch's contribution to the l grad
+  }
   __syncthreads();
-  if (threadIdx.x <= N)
+  if (threadIdx.x <= N) {
     params_grad[blockIdx.y][c][threadIdx.x] = params_grad_buf[threadIdx.x - 1];
+  }
 }
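The l_grad computation in both branches is the chain rule for scale = exp(l): since d(scale)/dl = exp(l) = scale, the gradient with respect to l is scale times the gradient with respect to scale. A standalone finite-difference check (the loss function is arbitrary):

```c++
#include <cassert>
#include <cmath>

int main() {
  auto loss = [](double scale) { return 2.5 * scale * scale; };  // arbitrary
  double l = 0.3, eps = 1e-6;
  double scale = std::exp(l);
  // Gradient w.r.t. scale, by central differences.
  double scale_grad = (loss(scale + eps) - loss(scale - eps)) / (2 * eps);
  double l_grad = scale * scale_grad;  // chain rule, as in the kernel
  // Gradient w.r.t. l, measured directly.
  double l_grad_fd =
      (loss(std::exp(l + eps)) - loss(std::exp(l - eps))) / (2 * eps);
  assert(std::fabs(l_grad - l_grad_fd) < 1e-4);
  return 0;
}
```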
@@ -623,7 +649,6 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
   TORCH_CHECK(output_grad.device().is_cuda(), "output_grad must be a CUDA tensor");
   TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");
   const int B = input.size(0),
       C = input.size(1),
       T = input.size(2),
@@ -631,6 +656,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
   TORCH_CHECK(N >= 4, "This backward code requires N >= 4");
   TORCH_CHECK(N <= 16, "This backward code currently requires N <= 16");
+  TORCH_CHECK((N & (N-1)) == 0, "N must be a power of 2");
   auto scalar_t = input.scalar_type();
   auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
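The new check uses the usual single-bit test: a power of 2 has exactly one set bit, and N & (N-1) clears the lowest set bit, leaving zero only for powers of 2. (It would also accept N == 0, but the N >= 4 check above already excludes that.)

```c++
#include <cassert>

int main() {
  for (int n = 1; n <= 32; n++) {
    bool is_pow2 = (n & (n - 1)) == 0;  // clears the lowest set bit
    assert(is_pow2 == (n == 1 || n == 2 || n == 4 || n == 8 ||
                       n == 16 || n == 32));
  }
  return 0;
}
```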
@@ -663,7 +689,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
   int shared_mem_numel = 2 * N + 3;
-  if (false)
+  if (true)
     std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
               << ", images_per_thread_block = " << images_per_thread_block
               << ", grid_dim_y = " << grid_dim_y
@@ -673,8 +701,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
               images_per_thread_block == 1,
               "Code error");
-  torch::Tensor params_grad = torch::empty({grid_dim_y, C, N + 1}, opts);
+  TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= N);
+  torch::Tensor params_grad = torch::zeros({grid_dim_y, C, N + 1}, opts);
   dim3 gridDim(C, grid_dim_y, 1);
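Switching params_grad from torch::empty to torch::zeros makes sense given that the tensor holds one partial gradient per blockIdx.y slice: zero-initialization keeps any entry a block never writes from polluting the result. Presumably the slices are reduced over dim 0 afterwards, along these lines (a hypothetical usage sketch, not code from this commit):

```c++
#include <torch/torch.h>

int main() {
  int64_t grid_dim_y = 4, C = 2, N = 8;
  // One partial gradient per blockIdx.y slice, zero-initialized.
  torch::Tensor params_grad = torch::zeros({grid_dim_y, C, N + 1});
  // ... the kernel would accumulate into the slices here ...
  torch::Tensor final_grad = params_grad.sum(0);  // reduce over the slice dim
  TORCH_CHECK(final_grad.size(0) == C && final_grad.size(1) == N + 1);
  return 0;
}
```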
@@ -13,7 +13,7 @@ def test_learned_nonlin_basic():
     K = 4
     N = K * 2
-    params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1)
+    params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1) - 3
     x.requires_grad = True
     params.requires_grad = True
     print("x = ", x)