Commit e0bc4029 authored by Daniel Povey

CUDA backward running (but not correctly for params grad)

parent 2ccbb505
@@ -124,9 +124,6 @@ void learned_nonlin_kernel(
   }
   __syncthreads();
-  // The easiest way to understand this code is to compare it with the CPU code
-  // in learned_nonlin_cpu.cpp.
   if (threadIdx.x == 0) {
     scalar_t scale = params_buf[-2],
         sum_positive = 0.0;
@@ -294,65 +291,51 @@ void learned_nonlin_backward_kernel(
   // params_buf[-1] contains params[c][0] == log of scale;
   // params_buf[-2] and params_buf[-3] contain scale and inv_scale.
-  scalar_t x_residual_buf[THREADS_PER_BLOCK];  // x_residual, with 0 <=
-                                               // x_residual < 1 for interior
-                                               // regions, is the residual part
-                                               // of the scaled input, after
-                                               // subtracting the integer part.
+  scalar_t input_buf[THREADS_PER_BLOCK];  // input sequence
   scalar_t output_grad_buf[THREADS_PER_BLOCK];
   char n_buf[THREADS_PER_BLOCK];  // for each input in `input_buf`, this stores
                                   // the integer value 0 <= n < N which
                                   // determines which piece of the piecewise
                                   // linear function we are in.
   // Load parameters
   if (threadIdx.x <= N)
     params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
   __syncthreads();
   if (threadIdx.x == 0) {
-    scalar_t scale = exp(params_buf[-1]),
-        inv_scale = 1.0 / scale;
+    scalar_t scale = exp(params_buf[-1]);
     params_buf[-2] = scale;
-    params_buf[-3] = inv_scale;
+    params_buf[-3] = 1.0 / scale;
   }
   __syncthreads();
-  scalar_t scale = params_buf[-2];
-  // The easiest way to understand this code is to compare it with the CPU code
-  // in learned_nonlin_cpu.cpp.
   if (threadIdx.x == 0) {
-    scalar_t sum_positive = 0.0;
+    scalar_t scale = params_buf[-2],
+        sum_positive = 0.0;
     for (int i = 0; i < K; i++) {
-      y_vals[K + i] = sum_positive;
-      // versus the CPU code, the params_buf is indexed off by 1; and it already
-      // contains the factor "scale".
-      sum_positive += params_buf[K + i] * scale;
+      // params_buf is indexed with an index one less than params.
+      scalar_t pos_scaled_param = params_buf[K + i] * scale;
+      y_vals[K + i] = sum_positive - pos_scaled_param * i;
+      sum_positive += pos_scaled_param;
     }
   } else if (threadIdx.x == 64) {
-    scalar_t sum_negative = 0.0;
+    scalar_t scale = params_buf[-2],
+        sum_negative = 0.0;
     for (int i = 0; i < K; i++) {
-      y_vals[K - i] = sum_negative;
-      // versus the CPU code, the params_buf is indexed off by 1; and it already
-      // contains the factor "scale".
-      sum_negative -= params_buf[K - 1 - i] * scale;
+      scalar_t neg_scaled_param = params_buf[K - i - 1] * scale;
+      sum_negative -= neg_scaled_param;
+      y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
     }
-    y_vals[0] = sum_negative;
   }
   __syncthreads();
-  // this_params_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
+  // this_param_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
   // linear interval) corresponding to n == threadIdx.x % N.  For example, if
   // threadIdx.x == 0, this thread's gradient corresponds to the left-most
   // linear interval.
-  // "this_params_grad" actually contains the derivative w.r.t. scaled params, i.e.
-  // params[n] * scale.
-  scalar_t this_scaled_param_grad = 0.0,
+  scalar_t this_param_grad = 0.0,
       this_y_vals_grad = 0.0;
   scalar_t inv_scale = params_buf[-3];
@@ -370,7 +353,7 @@ void learned_nonlin_backward_kernel(
   // this_params_grad or this_y_vals_grad.
   for (int t_offset = 0; t_offset < T; t_offset += THREADS_PER_BLOCK) {
     int t = threadIdx.x % T_inc + t_offset;
-    scalar_t this_output_grad = 0.0, x = 0.0;
+    scalar_t this_output_grad = 0.0;
     if (t < T)
       this_output_grad = output_grad[b][c][t];
@@ -381,29 +364,24 @@ void learned_nonlin_backward_kernel(
     // the same 'n'.  It might be better to set n to an invalid value for
     // out-of-range threads, but as it is, if we are to properly handle
     // N==16 we don't have enough bits available in `src_indexes` to do this.
-    x = input[b][c][t % T] * inv_scale + K;
+    scalar_t this_input = input[b][c][t % T] * inv_scale + K;
+    input_buf[threadIdx.x] = this_input;
     output_grad_buf[threadIdx.x] = this_output_grad;
-    scalar_t x_trunc = x;
-    if (x_trunc < 0) x_trunc = 0;
-    else if (x_trunc >= N) x_trunc = N - 1;
+    scalar_t x = this_input;
+    if (x < 0) x = 0;
+    else if (x >= N) x = N - 1;
     // C++ rounds toward zero.
-    int n = (int)x_trunc;
+    int n = (int)x;
     n_buf[threadIdx.x] = (char)n;
-    scalar_t x_residual = x - n;
-    x_residual_buf[threadIdx.x] = x_residual;
     // OK, at this point, 0 <= min < N.
     // The forward code did:
-    //   output[b][c][t] = (x - n) * params_buf[n] + y_vals[n];
-    // We get the derivative for params and y_vals later.
-    if (t < T) {
-      // In a sense this expression should contain "* inv_scale * scale"...
-      // of course, their product equals 1.
+    //   output[b][c][t] = this_input * params_buf[n] + y_vals[n];
+    if (t < T)
       input_grad[b][c][t] = this_output_grad * params_buf[n];
-    }
     int this_block_start = threadIdx.x & ~(N-1),  // == N * (threadIdx.x / N),
         this_n = threadIdx.x & (N-1);  // == threadIdx.x % N.
@@ -464,10 +442,10 @@ void learned_nonlin_backward_kernel(
         int src_idx = src_indexes & 0xF,
             src_thread = this_block_start + src_idx;
         scalar_t output_grad = output_grad_buf[src_thread],
-            x_residual = x_residual_buf[src_thread];
+            this_input = input_buf[src_thread];
         // Backprop for: output = x_residual * (params_buf[n] * scale) + y_vals[n].
         // Here, n == this_n; this is how we selected these `src_idx` values.
-        this_scaled_param_grad += output_grad * x_residual;
+        this_param_grad += output_grad * this_input;
         this_y_vals_grad += output_grad;
       }
     }
@@ -476,80 +454,76 @@ void learned_nonlin_backward_kernel(
   __syncthreads();  // sync threads because we are about to re-use
                     // output_grad_buf for reduction.
-  this_scaled_param_grad = strided_reduce_sum(N, output_grad_buf, this_scaled_param_grad);
+  this_param_grad = strided_reduce_sum(N, output_grad_buf, this_param_grad);
   this_y_vals_grad = strided_reduce_sum(N, output_grad_buf, this_y_vals_grad);
   __syncthreads();  // sync threads because we are about to re-use
                     // output_grad_buf.
   // Re-use some buffers..
-  scalar_t *scaled_params_grad_buf = x_residual_buf,  // [N] ... a
+  scalar_t *params_grad_buf = input_buf + 1,  // [N] ... but element [-1] will have deriv of scale.
       *y_vals_grad_buf = output_grad_buf;  // [N]
   if (threadIdx.x < N) {
-    // There is an offset of 1 between the 'n' values and
-    // the position in 'params'.  To keep the backprop code similar to the CPU
-    // backprop code we restore that offset here, i.e. use the same layout
-    // as the params.
-    scaled_params_grad_buf[threadIdx.x] = this_scaled_param_grad;
+    // Restore the indexing offset of 1 in params_grad_buf (versus
+    // params_buf
+    params_grad_buf[threadIdx.x] = this_param_grad;
     y_vals_grad_buf[threadIdx.x] = this_y_vals_grad;
   }
   // This next block does backprop relating to `y_vals`.  Comparing with the CPU
-  // version (call this the "reference code") is the best way to understand this (this code is just a
-  // modification of that).
-  {
-    // Thread 0 is responsible for parts of the reference code that involve "sum_positive_grad";
-    // thread 64 is responsible for parts of the reference code that involve "sum_negative_grad";
-    scalar_t scale_grad = 0.0,
-        scale = params_buf[-2];
-    if (threadIdx.x == 0) {
-      scalar_t sum_positive_grad = 0.0;
-      for (int i = K - 1; i >= 0; i--) {
-        // This is like the CPU code but with an offset of -1 for indexes into 'params_buf';
-        // also there is no scale because we are dealing with pre-scaled parameters.
-        scaled_params_grad_buf[K + i] += sum_positive_grad;
-        sum_positive_grad += y_vals_grad_buf[K + i];
-      }
-    } else if (threadIdx.x == 64) {
-      scalar_t sum_negative_grad = y_vals_grad_buf[0];
-      for (int i = K - 1; i >= 0; i--) {
-        // This is like the CPU code but with an offset of 1 for 'params_buf';
-        // also there is no scale because we are dealing with pre-scaled parameters.
-        scaled_params_grad_buf[K - 1 - i] -= sum_negative_grad;
-        sum_negative_grad += y_vals_grad_buf[K - i];
-      }
-    }
-    __syncthreads();
-    if (threadIdx.x < N) {
-      // this_scaled_param_grad is the gradient w.r.t. params_buf[n] * scale
-      // which is equal to params[c][n + 1] * scale.
-      int n = threadIdx.x;
-      scalar_t this_scaled_param_grad = scaled_params_grad_buf[n],
-          this_scale_grad = this_scaled_param_grad * params_buf[n],
-          scale = params_buf[-2],
-          this_param_grad = scale * this_scaled_param_grad;
-      // re-use x_residual_buf as 'param_grad_buf'.
-      x_residual_buf[n + 1] = this_param_grad;
-      scalar_t scale_grad = tiled_warp_reduce_sum(N, y_vals_grad_buf, this_scale_grad);
-      if (threadIdx.x == 0)
-        x_residual_buf[0] = scale_grad * scale;  // deriv w.r.t. l.
-    }
-    __syncthreads();
-  }
-  if (threadIdx.x <= N) {
-    // note, we are re-using x_residual_buf for the params_grad.
-    params_grad[blockIdx.y][c][threadIdx.x] = x_residual_buf[threadIdx.x];
-  }
+  // version (call this the "reference code") is the best way to understand this
+  // (this code is just a modification of that).  The main difference is we
+  // modify the indexes into params and params_grad by -1, so the index
+  // corresponds to the 'n' value; and element -1 of params_grad_buf will have
+  // the deriv of the log scale.
+  scalar_t l_grad;
+  if (threadIdx.x == 64) {
+    // Now do the backprop for the loop above where we set y_vals_a.  This could
+    // be further optimized to replace the loop with a raking, but I doubt this
+    // will have a huge effect on the runtime since K will be fairly small,
+    // e.g. 4.
+    scalar_t scale = params_buf[-2],
+        scale_grad = 0.0,
+        sum_positive_grad = 0.0;
+    for (int i = K - 1; i >= 0; i--) {
+      // Backprop for: sum_positive += pos_scaled_param;
+      scalar_t pos_scaled_param_grad = sum_positive_grad;
+      // Backprop for: y_vals[K + i] = sum_positive - pos_scaled_param * i;
+      scalar_t y_grad_pos = y_vals_grad_buf[K + i];
+      pos_scaled_param_grad -= i * y_grad_pos;
+      sum_positive_grad += y_grad_pos;
+      // Backprop for: pos_scaled_param = params_buf[K + i] * scale,
+      params_grad_buf[K + i] += pos_scaled_param_grad * scale;
+      scale_grad += pos_scaled_param_grad * params_buf[K + i];
+    }
+    // Backprop for: scale = exp(l), where l = params[c][0].
+    params_grad_buf[-1] = scale * scale_grad;
+  } else if (threadIdx.x == 0) {
+    // Now do the backprop for the loop above where we set y_vals.
+    scalar_t scale = params_buf[-2],
+        scale_grad = 0.0,
+        sum_negative_grad = 0.0;
+    for (int i = K - 1; i >= 0; i--) {
+      // Backprop for: y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
+      scalar_t y_grad_neg = y_vals_grad_buf[K - i - 1];
+      sum_negative_grad += y_grad_neg;
+      scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
+      // Backprop for: sum_negative -= neg_scaled_param;
+      neg_scaled_param_grad -= sum_negative_grad;
+      // Backprop for: neg_scaled_param = params_buf[K - i - 1] * scale;
+      params_grad_buf[K - i - 1] += neg_scaled_param_grad * scale;
+      scale_grad += neg_scaled_param_grad * params_buf[K - i - 1];
+    }
+    l_grad = scale * scale_grad;
+  }
+  __syncthreads();
+  if (threadIdx.x == 0)
+    params_grad_buf[-1] += l_grad;  // contribution to l grad from the "negative" branch
+  __syncthreads();
+  if (threadIdx.x <= N)
+    params_grad[blockIdx.y][c][threadIdx.x] = params_grad_buf[threadIdx.x - 1];
 }
...
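Note on the new "Backprop for:" comments in the backward kernel: each y_vals loop is a small cumulative-sum recurrence over the scaled slopes, and its gradient is obtained by replaying the loop in reverse while accumulating `sum_positive_grad` and `scale_grad`. The snippet below is my own standalone Python/PyTorch illustration of just the positive-side sub-computation (it is not code from this repository); `p[i]` and `scale` stand in for the kernel's `params_buf[K + i]` and `exp(params[c][0])`, and the hand-rolled reverse pass follows the same recurrences as the kernel, cross-checked against autograd in double precision.

    import torch

    K = 4
    p = torch.randn(K, dtype=torch.double, requires_grad=True)   # slopes, like params_buf[K + i]
    l = torch.randn((), dtype=torch.double, requires_grad=True)  # log scale, like params[c][0]

    # Forward: same recurrence as the kernel's positive-side y_vals loop.
    scale = l.exp()
    sum_positive = torch.zeros((), dtype=torch.double)
    y_vals = []
    for i in range(K):
        pos_scaled_param = p[i] * scale
        y_vals.append(sum_positive - pos_scaled_param * i)
        sum_positive = sum_positive + pos_scaled_param
    y_vals = torch.stack(y_vals)

    # Arbitrary upstream gradient for y_vals; autograd gives the reference answer.
    y_grad = torch.randn(K, dtype=torch.double)
    (y_vals * y_grad).sum().backward()

    # Hand-rolled reverse pass, mirroring the kernel's "Backprop for:" lines.
    with torch.no_grad():
        p_grad = torch.zeros(K, dtype=torch.double)
        scale_grad = torch.zeros((), dtype=torch.double)
        sum_positive_grad = torch.zeros((), dtype=torch.double)
        for i in reversed(range(K)):
            pos_scaled_param_grad = sum_positive_grad.clone()  # from: sum_positive += pos_scaled_param
            pos_scaled_param_grad -= i * y_grad[i]             # from: y_vals[i] = sum_positive - pos_scaled_param * i
            sum_positive_grad += y_grad[i]
            p_grad[i] += pos_scaled_param_grad * scale         # from: pos_scaled_param = p[i] * scale
            scale_grad += pos_scaled_param_grad * p[i]
        l_grad = scale * scale_grad                            # from: scale = exp(l)

    assert torch.allclose(p_grad, p.grad) and torch.allclose(l_grad, l.grad)

The negative-side loop is handled the same way, with the per-iteration statements reversed in order and each one replaced by its adjoint.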
@@ -22,17 +22,31 @@ def test_learned_nonlin_basic():
     y = learned_nonlin(x, params, dim = 1)
     print("y = ", y)
+    y.sum().backward()

     if torch.cuda.is_available():
         # test that the CUDA forward is the same as the CPU forward.
         device = torch.device('cuda:0')
-        y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu'))
+        x2 = x.to(device).detach()
+        x2.requires_grad = True
+        params2 = params.to(device).detach()
+        params2.requires_grad = True
+        y2 = learned_nonlin(x2, params2, dim = 1).to(torch.device('cpu'))
         print("Checking CUDA is same")
         if not torch.allclose(y, y2, atol=1.0e-06):
             print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
             assert(0);
-    y.sum().backward()
+        y2.sum().backward()
+        if not torch.allclose(x.grad, x2.grad.to('cpu'), atol=1.0e-06):
+            print(f"Error: CPU x-grad versus CUDA grad not the same: {x.grad} vs. {x2.grad}, diff = {x2.grad.to('cpu')-x.grad}")
+            assert(0);
+        if not torch.allclose(params.grad, params2.grad.to('cpu'), atol=1.0e-06):
+            print(f"Error: CPU params-grad versus CUDA grad not the same: {params.grad} vs. {params2.grad}, diff = {params2.grad.to('cpu')-params.grad}")
+            assert(0);

     print("x.grad = ", x.grad)
     print("params.grad = ", params.grad)
...
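Since the commit message notes that the params grad is not yet correct, a finite-difference comparison is a convenient way to localize which gradient is off, in addition to the CPU-vs-CUDA comparison above. Below is a hedged sketch (not part of this commit) that reuses the test's existing `x` and `params` and the already-imported `learned_nonlin`; it assumes the extension also accepts float64 tensors.

    import torch

    def check_learned_nonlin_grads(x, params, device):
        # torch.autograd.gradcheck compares the analytic backward against
        # finite differences; it needs double precision and requires_grad inputs.
        xd = x.detach().to(device=device, dtype=torch.float64).requires_grad_(True)
        pd = params.detach().to(device=device, dtype=torch.float64).requires_grad_(True)
        # The function is piecewise linear, so finite differences can disagree if an
        # input lands exactly on a breakpoint; random inputs rarely do.
        return torch.autograd.gradcheck(
            lambda a, b: learned_nonlin(a, b, dim=1), (xd, pd), eps=1e-6, atol=1e-4)

Calling, e.g., `assert check_learned_nonlin_grads(x, params, device)` right after the existing CUDA checks would flag the params gradient specifically if the analytic backward disagrees with the numerical one.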