"cacheflow/vscode:/vscode.git/clone" did not exist on "0deacbce6e96a1af5885babc4e470ce2a0cecf95"
Commit c3e61bea authored by Daniel Povey
Browse files

Nearly-working backprop on CUDA

parent e0bc4029
......@@ -260,6 +260,14 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
We also require that N <= THREADS_PER_BLOCK (for best performance,
N should be quite small, like no larger than 8 or so).
We also require 4 <= N <= 16 for this code!
And we require that
N <= (THREADS_PER_BLOCK / images_per_thread_block)
(both sides will be powers of 2).  This ensures that each block of N
threads summing the N values always lies within a single image, which
avoids a problem where some loops over 'b' would be done earlier
than others, and we'd end up counting certain pixels twice because their
output_grad would stay nonzero.
*/
template <typename scalar_t>
......@@ -291,12 +299,13 @@ void learned_nonlin_backward_kernel(
// params_buf[-1] contains params[c][0] == log of scale;
// params_buf[-2] and params_buf[-3] contain scale and inv_scale.
scalar_t input_buf[THREADS_PER_BLOCK]; // input sequence
scalar_t output_grad_buf[THREADS_PER_BLOCK];
char n_buf[THREADS_PER_BLOCK]; // for each input in `input_buf`, this stores
// the integer value 0 <= n < N which
// determines which piece of the piecewise
// linear function we are in.
__shared__ scalar_t input_buf[THREADS_PER_BLOCK]; // input sequence
__shared__ scalar_t output_grad_buf[THREADS_PER_BLOCK];
__shared__ char n_buf[THREADS_PER_BLOCK]; // for each input in `input_buf`,
// this stores the integer value 0
// <= n < N which determines which
// piece of the piecewise linear
// function we are in.
// Load parameters
if (threadIdx.x <= N)
......@@ -352,7 +361,12 @@ void learned_nonlin_backward_kernel(
// will be set to zero for excess threads, and thus won't contribute to
// this_params_grad or this_y_vals_grad.
for (int t_offset = 0; t_offset < T; t_offset += THREADS_PER_BLOCK) {
int t = threadIdx.x % T_inc + t_offset;
// The following is equivalent to:
// int t = (threadIdx.x % T_inc) + t_offset;
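// given that T_inc is a power of 2 and t_offset is always a multiple of
// THREADS_PER_BLOCK (0 on the first iteration), which in turn must be a
// multiple of T_inc, so t_offset has no bits set below T_inc.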
int t = (threadIdx.x & (T_inc - 1)) | t_offset;
scalar_t this_output_grad = 0.0;
if (t < T)
this_output_grad = output_grad[b][c][t];
......@@ -373,25 +387,24 @@ void learned_nonlin_backward_kernel(
else if (x >= N) x = N - 1;
// C++ rounds toward zero.
int n = (int)x;
n_buf[threadIdx.x] = (char)n;
// OK, at this point, 0 <= n < N.
n_buf[threadIdx.x] = (char)n; // 0 <= n < N
// The forward code did:
// output[b][c][t] = this_input * params_buf[n] + y_vals[n];
// We get the derivative for params and y_vals later.
if (t < T)
input_grad[b][c][t] = this_output_grad * params_buf[n];
int this_block_start = threadIdx.x & ~(N-1), // == N * (threadIdx.x / N),
int this_block_start = threadIdx.x & ~(N-1), // == N * (threadIdx.x / N),
// since N is power of 2
this_n = threadIdx.x & (N-1); // == threadIdx.x % N.
// this_n is the n value that this thread accumulates gradients for;
// it is responsible for output_grads in the block of threads
// from this_block_start to this_block_start+N-1.
// SYNC POINT At this point there is an implicit within-warp
// synchronization (Note: implicit warp synchronization is considered not
// __syncthreads(); // <- not really needed.
// At this point there is an implicit within-warp
// synchronization (Note: implicit warp synchronization is not considered
// future-proof). Threads above have written to n_buf, and threads below
// will read from it; but we don't need to explicitly synchronize for now
// because the reads/writes are among threads in a group of N threads with
......@@ -399,7 +412,7 @@ void learned_nonlin_backward_kernel(
// src_indexes will contain up to 16 16-bit numbers, stored starting in its
// least significant bits. It will store all the offsets within this
// block of N, where the 'n' value equals this_n.
// block of N threads, whose chosen 'n' value equals this_n.
uint64_t src_indexes = 0;
// num_src is the number of numbers in `src_indexes`. We need to store a
// separate counter because zero is a valid index and if we are to support
......@@ -407,11 +420,12 @@ void learned_nonlin_backward_kernel(
// of marker.
int num_src = 0;
// This loop always does N statements, but they should be relatively fast
// ones since the computation per n value is minimal and there is little
// I/O. We are figuring out the subset of our block of N elements,
// which this particular thread value is responsible for (because they
// have n == this_n), and storing them in `src_indexes` and `num_src`.
// This loop always does at least N statements, but they should be
// relatively fast ones since the computation per n value is minimal and
// there is little I/O. We are figuring out the subset of our block of N
// elements, which this particular thread value is responsible for
// (because they have n == this_n), and storing them in `src_indexes` and
// `num_src`.
for (int i = 0; i < N; i += 4) {
uint32_t n_block_of_4 = *reinterpret_cast<uint32_t*>(n_buf + this_block_start + i);
#pragma unroll
......@@ -438,38 +452,45 @@ void learned_nonlin_backward_kernel(
// number of images, and the hope is that different warps will reach the
// end of the outer loop at around the same time because their variations
// in speed will average out.
for (; num_src > 0; --num_src, src_indexes >>= 4) {
int src_idx = src_indexes & 0xF,
src_thread = this_block_start + src_idx;
scalar_t output_grad = output_grad_buf[src_thread],
this_input = input_buf[src_thread];
// Backprop for: output = x_residual * (params_buf[n] * scale) + y_vals[n].
for (; num_src > 0; --num_src, (src_indexes >>= 4)) {
int src_thread = this_block_start | (src_indexes & 0xF);
scalar_t src_output_grad = output_grad_buf[src_thread],
src_input = input_buf[src_thread];
assert(n_buf[src_thread] == this_n);
n_buf[src_thread] = 0;
// Backprop for: output = input * params_buf[n] + y_vals[n].
// Here, n == this_n; this is how we selected these `src_idx` values.
this_param_grad += output_grad * this_input;
this_y_vals_grad += output_grad;
this_param_grad += src_output_grad * src_input;
this_y_vals_grad += src_output_grad;
}
// TODO: remove the next lines
assert(n_buf[threadIdx.x] == 0);
output_grad_buf[threadIdx.x] = 0.0;
}
}
__syncthreads(); // sync threads because we are about to re-use
// output_grad_buf for reduction.
// output_grad_buf for reduction, and, later, input_buf.
this_param_grad = strided_reduce_sum(N, output_grad_buf, this_param_grad);
__syncthreads();
this_y_vals_grad = strided_reduce_sum(N, output_grad_buf, this_y_vals_grad);
__syncthreads(); // sync threads because we are about to re-use
// output_grad_buf.
// output_grad_buf as y_vals_grad_buf.
// Re-use some buffers..
scalar_t *params_grad_buf = input_buf + 1, // [N] ... but element [-1] will have deriv of scale.
*y_vals_grad_buf = output_grad_buf; // [N]
if (threadIdx.x < N) {
// Restore the indexing offset of 1 in params_grad_buf (versus
// params_buf).
params_grad_buf[threadIdx.x] = this_param_grad;
y_vals_grad_buf[threadIdx.x] = this_y_vals_grad;
}
__syncthreads(); // other threads are about to read params_grad_buf and
// y_vals_grad_buf.
// This next block does backprop relating to `y_vals`. Comparing with the CPU
// version (call this the "reference code") is the best way to understand this
......@@ -479,7 +500,7 @@ void learned_nonlin_backward_kernel(
// the deriv of the log scale.
scalar_t l_grad;
if (threadIdx.x == 64) {
if (threadIdx.x == 0) {
// Now do the backprop for the loop above where we set y_vals_a. This could
// be further optimized to replace the loop with a raking, but I doubt this
// will have a huge effect on the runtime since K will be fairly small,
......@@ -499,9 +520,11 @@ void learned_nonlin_backward_kernel(
scale_grad += pos_scaled_param_grad * params_buf[K + i];
}
// Backprop for: scale = exp(l), where l = params[c][0].
params_grad_buf[-1] = scale * scale_grad;
} else if (threadIdx.x == 0) {
l_grad = scale * scale_grad;
} else if (threadIdx.x == 64) {
// Now do the backprop for the loop above where we set y_vals.
// Make this one threadIdx.x == 0 so it's possibly quicker to test
//
scalar_t scale = params_buf[-2],
scale_grad = 0.0,
sum_negative_grad = 0.0;
......@@ -516,14 +539,17 @@ void learned_nonlin_backward_kernel(
params_grad_buf[K - i - 1] += neg_scaled_param_grad * scale;
scale_grad += neg_scaled_param_grad * params_buf[K - i - 1];
}
l_grad = scale * scale_grad;
params_grad_buf[-1] = scale * scale_grad;
}
__syncthreads();
if (threadIdx.x == 0)
if (threadIdx.x == 0) {
params_grad_buf[-1] += l_grad; // contribution to l grad from the "negative" branch
}
__syncthreads();
if (threadIdx.x <= N)
if (threadIdx.x <= N) {
params_grad[blockIdx.y][c][threadIdx.x] = params_grad_buf[threadIdx.x - 1];
}
}
......@@ -623,7 +649,6 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
TORCH_CHECK(output_grad.device().is_cuda(), "output_grad must be a CUDA tensor");
TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");
const int B = input.size(0),
C = input.size(1),
T = input.size(2),
......@@ -631,6 +656,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
TORCH_CHECK(N >= 4, "This backward code requires N >= 4");
TORCH_CHECK(N <= 16, "This backward code currently requires N <= 16");
TORCH_CHECK((N & (N-1)) == 0, "N must be a power of 2")
auto scalar_t = input.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
......@@ -663,7 +689,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
int shared_mem_numel = 2 * N + 3;
if (false)
if (true)
std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
<< ", images_per_thread_block = " << images_per_thread_block
<< ", grid_dim_y = " << grid_dim_y
......@@ -673,8 +701,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
images_per_thread_block == 1,
"Code error");
TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= N);
torch::Tensor params_grad = torch::empty({grid_dim_y, C, N + 1}, opts);
torch::Tensor params_grad = torch::zeros({grid_dim_y, C, N + 1}, opts);
dim3 gridDim(C, grid_dim_y, 1);
......
......@@ -13,7 +13,7 @@ def test_learned_nonlin_basic():
K = 4
N = K * 2
params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1)
params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1) - 3
x.requires_grad = True
params.requires_grad = True
print("x = ", x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment