Commit f4081496 authored by Daniel Povey

Work on backward kernel

parent 6a77cb45
@@ -180,7 +180,6 @@ void learned_nonlin_kernel(
output[b][c][t] = (x - n) * params_buf[n] + y_vals[n];
}
}
}
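For orientation, the line above is the entire forward nonlinearity: choose a linear piece n from x, then evaluate that piece. Below is a minimal serial sketch, assuming x has already been scaled and shifted so that piece n covers [n, n + 1) and that n is clamped into [0, N - 1]; the clamping and shift are inferred from the surrounding comments, not shown in this hunk, and the function name is ours.

template <typename scalar_t>
scalar_t eval_piecewise(scalar_t x, const scalar_t *params_buf,
                        const scalar_t *y_vals, int N) {
  int n = (int)x;                  // index of the linear piece
  if (n < 0) n = 0;                // clamp into the outermost pieces
  if (n > N - 1) n = N - 1;
  // slope of piece n is params_buf[n]; y_vals[n] is its value at x == n.
  return (x - n) * params_buf[n] + y_vals[n];
}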
@@ -297,7 +296,9 @@ void learned_nonlin_backward_kernel(
// spaces between here and
// `params_buf` for storing scale
// and inv_scale and l == params[c][0].
*params_buf = (scalar_t*) y_vals + 3 + N; // [N]. Caution: contains params[c][1] through params[c][N].
*params_buf = (scalar_t*) y_vals + 3 + N; // [N]. Contains parameters (not times scale!)
// Caution: contains params[c][1] through params[c][N],
// i.e. numbering is off by 1 versus params.
// params_buf[-1] contains params[c][0] == log of scale;
// params_buf[-2] and params_buf[-3] contain scale and inv_scale.
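To make the pointer arithmetic above concrete: the relative offsets in the sketch below come from the comments, while the picture of one underlying shared buffer is an assumption.

// y_vals and params_buf carved out of one shared-memory buffer:
//   scalar_t *y_vals     = <base of buffer>;   // y_vals[0 .. N-1]
//   scalar_t *params_buf = y_vals + 3 + N;     // params_buf[0 .. N-1]
// which leaves three slots between them, addressed from params_buf as:
//   params_buf[-3] == inv_scale
//   params_buf[-2] == scale
//   params_buf[-1] == params[c][0]   (the log of scale)
// and params_buf[n] == params[c][n + 1], i.e. numbering off by one.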
@@ -312,58 +313,56 @@ void learned_nonlin_backward_kernel(
// determines which piece of the piecewise
// linear function we are in.
// this_params_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
// linear interval) corresponding to n == threadIdx.x % N. For example, if
// threadIdx.x == 0, this thread's gradient corresponds to the left-most
// linear interval.
scalar_t this_params_grad = 0.0,
this_y_vals_grad = 0.0;
// Load parameters
if (threadIdx.x <= N)
params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
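// i.e. threads 0..N load params[c][0..N] into params_buf[-1..N-1]; thread 0
// deposits params[c][0] (the log of scale) at params_buf[-1], matching the
// layout described above.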
__syncthreads();
// The easiest way to understand this code is to compare it with the CPU code
// in learned_nonlin_cpu.cpp.
// This next block computes `y_vals`.
if ((((int)threadIdx.x & ~(int)32)) == 0) {
// threadIdx.x == 0 or 32. These are in separate warps so we can
// allow them to do separate jobs. This code takes linear time in K which
// is not at all ideal and could be improved if K is largish, but it shouldn't
// dominate the total time taken if we are processing a lot of data;
// and anyway, we doubt that K will need to be more than 4 or 8 or so,
// so the potential savings are quite small.
if (threadIdx.x == 0) {
scalar_t scale = exp(params_buf[-1]),
inv_scale = 1.0 / scale;
params_buf[-2] = scale; // both threads write these but it's OK, it's the
// same value.
params_buf[-2] = scale;
params_buf[-3] = inv_scale;
int sign,
Koffset; // Koffset == K for threads handling sum_positive and K - 1
// for threads handling sum_negative, see
// learned_nonlin_cpu.cpp for reference code. This would be K
// + 1 and K respectively, except our params_buf has its index
// shifted by one versus params.
if (threadIdx.x == 0) { // sum_positive
sign = 1;
Koffset = K;
} else { // threadIdx.x == 32. sum_negative.
scale *= -1; // this is a local variable..
sign = -1;
Koffset = K - 1;
}
scalar_t sum = 0.0;
__syncthreads();
scalar_t scale = params_buf[-2];
// The easiest way to understand this code is to compare it with the CPU code
// in learned_nonlin_cpu.cpp.
if (threadIdx.x == 0) {
scalar_t sum_positive = 0.0;
for (int i = 0; i < K; i++) {
y_vals[K + i] = sum_positive;
// versus the CPU code, the params_buf is indexed off by 1; and it already
// contains the factor "scale".
sum_positive += params_buf[K + i] * scale;
}
} else if (threadIdx.x == 64) {
scalar_t sum_negative = 0.0;
for (int i = 0; i < K; i++) {
int isign = i * sign;
y_vals[K + isign] = sum * scale;
sum += params_buf[Koffset + isign];
y_vals[K - i] = sum_negative;
// versus the CPU code, the params_buf is indexed off by 1; and it already
// contains the factor "scale".
sum_negative -= params_buf[K - 1 - i] * scale;
}
if (threadIdx.x != 0) // sum_negative
y_vals[0] = sum * scale;
y_vals[0] = sum_negative;
}
__syncthreads();
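// A sanity check on the two prefix-sum loops above: together they make the
// piecewise function continuous, because at each interior breakpoint
//   y_vals[n] + params_buf[n] * scale == y_vals[n + 1].
// (Note y_vals[K] is written by both threads, with the same value 0.)
// Standalone serial sketch of the same construction, with our own names and
// a float tolerance for rounding (needs <cassert>, <cmath>, <vector>):
//
//   void build_y_vals(const std::vector<float> &params_buf, float scale,
//                     std::vector<float> &y_vals) {
//     int N = (int)params_buf.size(), K = N / 2;
//     y_vals.assign(N, 0.0f);
//     float sum_positive = 0.0f;
//     for (int i = 0; i < K; i++) {          // pieces K .. N-1
//       y_vals[K + i] = sum_positive;
//       sum_positive += params_buf[K + i] * scale;
//     }
//     float sum_negative = 0.0f;
//     for (int i = 0; i < K; i++) {          // pieces K-1 .. 1 (and K again)
//       y_vals[K - i] = sum_negative;
//       sum_negative -= params_buf[K - 1 - i] * scale;
//     }
//     y_vals[0] = sum_negative;
//     for (int n = 0; n + 1 < N; n++)        // continuity at each breakpoint
//       assert(std::fabs(y_vals[n] + params_buf[n] * scale - y_vals[n + 1])
//              < 1e-5f);
//   }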
// this_scaled_param_grad and this_y_vals_grad pertain to the 'n' value (i.e.
// the n'th linear interval) corresponding to n == threadIdx.x % N. For
// example, if threadIdx.x == 0, this thread's gradient corresponds to the
// left-most linear interval.
// "this_params_grad" actually contains the derivative w.r.t. scaled params, i.e.
// params[n] * scale.
scalar_t this_scaled_param_grad = 0.0,
this_y_vals_grad = 0.0;
scalar_t inv_scale = params_buf[-3];
int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
@@ -408,8 +407,11 @@ void learned_nonlin_backward_kernel(
// The forward code did:
// output[b][c][t] = (x - n) * params_buf[n] + y_vals[n];
if (t < T)
if (t < T) {
// In a sense this expression should contain "* inv_scale" (from dx/dinput,
// since x is the scaled input) and "* scale" (from the slope, which is
// params_buf[n] * scale)... of course, their product equals 1, so we omit both.
input_grad[b][c][t] = this_output_grad * params_buf[n];
}
int this_block_start = threadIdx.x & ~(N-1), // == N * (threadIdx.x / N),
this_n = threadIdx.x & (N-1); // == threadIdx.x % N.
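// Both bit tricks above rely on N being a power of two.  For example, with
// N == 8 and threadIdx.x == 13:
//   13 & ~(8 - 1) == 13 & ~7 == 8   (== 8 * (13 / 8))
//   13 &  (8 - 1) == 13 &  7 == 5   (== 13 % 8)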
@@ -471,9 +473,9 @@ void learned_nonlin_backward_kernel(
src_thread = this_block_start + src_idx;
scalar_t output_grad = output_grad_buf[src_thread],
x_residual = x_residual_buf[src_thread];
// Backprop for: output = x_residual * params_buf[n] + y_vals[n].
// Backprop for: output = x_residual * (params_buf[n] * scale) + y_vals[n].
// Here, n == this_n; this is how we selected these `src_idx` values.
this_params_grad += output_grad * x_residual;
this_scaled_param_grad += output_grad * x_residual;
this_y_vals_grad += output_grad;
}
}
@@ -482,14 +484,14 @@ void learned_nonlin_backward_kernel(
__syncthreads(); // sync threads because we are about to re-use
// output_grad_buf for reduction.
this_params_grad = strided_reduce_sum(N, output_grad_buf, this_params_grad);
this_scaled_param_grad = strided_reduce_sum(N, output_grad_buf, this_scaled_param_grad);
this_y_vals_grad = strided_reduce_sum(N, output_grad_buf, this_y_vals_grad);
__syncthreads(); // sync threads because we are about to re-use
// output_grad_buf.
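// What `strided_reduce_sum` is assumed to compute (its definition is not
// visible in this hunk): reduce each thread's value over all threads whose
// index is congruent mod N, using the buffer as scratch, so that afterwards
// every thread holds the total for its piece n == threadIdx.x % N.  Serial
// equivalent for one thread t out of nthreads:
//   scalar_t sum = 0;
//   for (int i = t % N; i < nthreads; i += N)
//     sum += value_held_by_thread[i];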
// Re-use some buffers..
scalar_t *params_grad_buf = x_residual_buf, // [N]
scalar_t *scaled_params_grad_buf = x_residual_buf, // [N]
*y_vals_grad_buf = output_grad_buf; // [N]
if (threadIdx.x < N) {
@@ -497,7 +499,7 @@ void learned_nonlin_backward_kernel(
// the position in 'params'. To keep the backprop code similar to the CPU
// backprop code we restore that offset here, i.e. use the same layout
// as the params.
params_grad_buf[threadIdx.x + 1] = this_params_grad;
scaled_params_grad_buf[threadIdx.x] = this_scaled_param_grad;
y_vals_grad_buf[threadIdx.x] = this_y_vals_grad;
}
@@ -514,31 +516,47 @@ void learned_nonlin_backward_kernel(
if (threadIdx.x == 0) {
scalar_t sum_positive_grad = 0.0;
for (int i = K - 1; i >= 0; i--) {
// This is like the CPU code but with an offset of 1 for 'params_buf'
// versus 'params_a'.
params_grad_buf[1 + K + i] += sum_positive_grad * scale;
scale_grad += sum_positive_grad * params_buf[K + i];
// This is like the CPU code but with an offset of -1 for indexes into 'params_buf';
// also there is no scale because we are dealing with pre-scaled parameters.
scaled_params_grad_buf[K + i] += sum_positive_grad;
sum_positive_grad += y_vals_grad_buf[K + i];
}
params_grad_buf[0] += scale * scale_grad;
} else if (threadIdx.x == 64) {
scalar_t sum_negative_grad = y_vals_grad_buf[0];
for (int i = K - 1; i >= 0; i--) {
// This is like the CPU code but with an offset of 1 for 'params_buf'
// versus 'params_a'.
params_grad_buf[K - i] -= sum_negative_grad * scale;
scale_grad -= sum_negative_grad * params_buf[K - 1 - i];
// This is like the CPU code but with an offset of -1 for indexes into 'params_buf';
// also there is no scale because we are dealing with pre-scaled parameters.
scaled_params_grad_buf[K - 1 - i] -= sum_negative_grad;
sum_negative_grad += y_vals_grad_buf[K - i];
}
}
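// Both loops above are the standard reversal of a prefix sum.  The forward
// pattern, for ascending i, was:  y[i] = s;  s += p[i];  so the backward
// pass runs in reverse while accumulating the gradient of s:
//   scalar_t s_grad = 0;
//   for (int i = K - 1; i >= 0; i--) {
//     p_grad[i] += s_grad;       // from  s += p[i]
//     s_grad += y_grad[i];       // from  y[i] = s
//   }
// The sum_negative version is the same with signs and indexes mirrored;
// sum_negative_grad starts from y_vals_grad_buf[0] because of the extra
// final write y_vals[0] = sum_negative in the forward pass.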
__syncthreads();
if (threadIdx.x == 64)
params_grad_buf[0] += scale * scale_grad;
if (threadIdx.x < N) {
// this_scaled_param_grad is the gradient w.r.t. params_buf[n] * scale
// which is equal to params[c][n + 1] * scale.
int n = threadIdx.x;
scalar_t this_scaled_param_grad = scaled_params_grad_buf[n],
this_scale_grad = this_scaled_param_grad * params_buf[n],
scale = params_buf[-2],
this_param_grad = scale * this_scaled_param_grad;
// re-use x_residual_buf as 'param_grad_buf'.
x_residual_buf[n + 1] = this_param_grad;
scalar_t scale_grad = tiled_warp_reduce_sum(N, y_vals_grad_buf, this_scale_grad);
if (threadIdx.x == 0)
x_residual_buf[0] = scale_grad * scale; // deriv w.r.t. l.
}
__syncthreads();
}
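// The chain rule applied in the block above, in sketch form.  With
// scale == exp(l) and scaled_param[n] == params[c][n + 1] * scale:
//   d_loss/d_params[c][n + 1] = scale * d_loss/d_scaled_param[n]
//   d_loss/d_scale            = sum_n params_buf[n] * d_loss/d_scaled_param[n]
//   d_loss/d_l                = (d_loss/d_scale) * scale
// which is why this_param_grad multiplies by `scale`, and the write of the
// gradient for l multiplies the reduced `scale_grad` by `scale` once more.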
if (threadIdx.x <= N) {
params_grad[blockIdx.y][c][threadIdx.x] = params_grad_buf[threadIdx.x];
// note, we are re-using x_residual_buf for the params_grad.
params_grad[blockIdx.y][c][threadIdx.x] = x_residual_buf[threadIdx.x];
}
}