Commit 6a77cb45 authored by Daniel Povey

Get CUDA forward working correctly.

parent 74897fd5
@@ -54,7 +54,8 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
   for (int b = 0; b < B; b++) {
     for (int c = 0; c < C; c++) {
-      scalar_t inv_scale = exp(-params_a[c][0]);
+      scalar_t scale = exp(params_a[c][0]),
+          inv_scale = 1.0 / scale;
       for (int t = 0; t < T; t++) {
         // `x` is the scaled input x plus an offset so that -K maps to 0.
         // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
@@ -68,7 +69,7 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
         // C++ rounds toward zero.
         int n = (int) x_trunc;
         // OK, at this point, 0 <= min < 2*K.
-        scalar_t y = (x - n) * params_a[c][n + 1] + y_vals_a[c][n];
+        scalar_t y = (x - n) * scale * params_a[c][n + 1] + y_vals_a[c][n];
         /* printf("x = %f, y = %f, n = %d; y = (%f - %d) * %f+ %f\n", x, y, n,
            x, n, params_a[c][n + 1], y_vals_a[c][n - 1]); */
         output_a[b][c][t] = y;
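For orientation, the forward computation these two hunks change can be condensed into a standalone scalar sketch. Only the formula comes from the diff (scale = exp(params[0]), x = input * inv_scale + K, clamp the piece index, then y = (x - n) * scale * params[n + 1] + y_vals[n]); the function name, argument layout and example numbers below are made up for illustration.

#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative scalar version of the forward pass in the hunks above (names
// invented for this sketch).  For one channel, params[0] = l (log of scale)
// and params[1..2K] are the 2K piecewise slopes; y_vals[n] is the offset of
// linear piece n (normally precomputed from the slopes so the curve is
// continuous; see the running-sum construction in the kernel further down).
static double learned_nonlin_scalar(double input,
                                    const std::vector<double> &params,
                                    const std::vector<double> &y_vals,
                                    int K) {
  double scale = std::exp(params[0]),
      inv_scale = 1.0 / scale;
  // `x` is the scaled input plus an offset so that -K maps to 0.
  double x = input * inv_scale + K;
  double x_trunc = x;
  if (x_trunc < 0) x_trunc = 0;
  else if (x_trunc >= 2 * K) x_trunc = 2 * K - 1;
  int n = (int) x_trunc;  // 0 <= n < 2*K
  // The changed line: the slope params[n + 1] is now multiplied by `scale`.
  return (x - n) * scale * params[n + 1] + y_vals[n];
}

int main() {
  int K = 2;
  std::vector<double> params = {0.0, 1.0, 0.5, 2.0, 0.25};  // l, then 2K slopes
  std::vector<double> y_vals = {-1.5, -0.5, 0.0, 2.0};      // offsets per piece
  printf("y = %f\n", learned_nonlin_scalar(0.3, params, y_vals, K));
  return 0;
}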
@@ -139,8 +140,10 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
   for (int b = 0; b < B; b++) {
     for (int c = 0; c < C; c++) {
-      scalar_t inv_scale = exp(-params_a[c][0]),
-          inv_scale_grad = 0.0;
+      scalar_t scale = exp(params_a[c][0]),
+          inv_scale = 1.0 / scale,
+          inv_scale_grad = 0.0,
+          scale_grad = 0.0;
       for (int t = 0; t < T; t++) {
         // `x` is the scaled input x plus an offset so that -K maps to 0.
         // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
@@ -157,16 +160,22 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
         int n = (int) x_trunc;
         // OK, at this point, 0 <= n < 2*K.
         // backprop for:
-        //   scalar_t y = (x - (scalar_t)n) * params_a[c][n + 1] + y_vals_a[c][n];
-        scalar_t x_grad = y_grad * params_a[c][n + 1];
-        params_grad_a[c][n + 1] += y_grad * (x - (scalar_t)n);
+        //   scalar_t x_residual_scaled = (x - (scalar_t)n) * scale
+        //   scalar_t y = x_residual_scaled * params_a[c][n + 1] + y_vals_a[c][n];
+        scalar_t x_residual_scaled = (x - n) * scale,
+            x_residual_scaled_grad = y_grad * params_a[c][n + 1],
+            x_grad = x_residual_scaled_grad * scale;
+        scale_grad += x_residual_scaled_grad * (x - (scalar_t)n);
+        params_grad_a[c][n + 1] += y_grad * x_residual_scaled;
         y_vals_grad_a[c][n] += y_grad;
         // backprop for: x = input * inv_scale + K,
         inv_scale_grad += x_grad * input;
         input_grad_a[b][c][t] = x_grad * inv_scale;
       }
-      // Do the backprop for: inv_scale = exp(-params_a[c][0])
-      params_grad_a[c][0] -= inv_scale * inv_scale_grad;
+      // Do the backprop for:
+      //   scale = exp(params_a[c][0]);
+      //   inv_scale = exp(-params_a[c][0]);
+      params_grad_a[c][0] += (scale * scale_grad - inv_scale * inv_scale_grad);
     }
   }
   // Now do the backprop for the loop above where we set y_vals_a.
...
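The new scale_grad / inv_scale_grad bookkeeping above follows from scale = exp(l) and inv_scale = exp(-l) with l = params_a[c][0]: d(scale)/dl = scale and d(inv_scale)/dl = -inv_scale, which is where the final `scale * scale_grad - inv_scale * inv_scale_grad` update comes from. Below is a small standalone sanity check of just that identity; it is toy code, not part of the project, and the function `loss` is an arbitrary stand-in.

#include <cmath>
#include <cstdio>

// If some loss depends on l only through scale = exp(l) and inv_scale = exp(-l),
// then dloss/dl = scale * dloss/dscale - inv_scale * dloss/dinv_scale, i.e. the
// params_grad_a[c][0] line in the hunk above.  `loss` here is made up.
static double loss(double scale, double inv_scale) {
  return 3.0 * scale * scale + 2.0 * inv_scale;
}

int main() {
  double l = 0.7;
  double scale = std::exp(l), inv_scale = std::exp(-l);
  // Analytic grads of `loss` w.r.t. scale and inv_scale.
  double scale_grad = 6.0 * scale, inv_scale_grad = 2.0;
  double l_grad = scale * scale_grad - inv_scale * inv_scale_grad;
  // Finite-difference reference.
  double eps = 1e-6;
  double fd = (loss(std::exp(l + eps), std::exp(-(l + eps))) -
               loss(std::exp(l - eps), std::exp(-(l - eps)))) / (2 * eps);
  printf("analytic dloss/dl = %f, finite-diff = %f\n", l_grad, fd);
  return 0;
}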
@@ -108,54 +108,53 @@ void learned_nonlin_kernel(
                                                  // spaces between here and
                                                  // `params_buf` for storing scale
                                                  // and inv_scale and l == params[c][0].
-      *params_buf = (scalar_t*) y_vals + 3 + N;  // [N].  Caution: contains params[c][1] through params[c][N].
+      *params_buf = (scalar_t*) y_vals + 3 + N;  // [N].  Contains params[c][1] * scale through params[c][N] * scale,
                                                  // params_buf[-1] contains params[c][0] == log of scale;
                                                  // params_buf[-2] and params_buf[-3] contain scale and inv_scale.
   // Load parameters
   if (threadIdx.x <= N)
     params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
   __syncthreads();
-  // The easiest way to understand this code is to compare it with the CPU code
-  // in learned_nonlin_cpu.cpp.
-  // TODO: replace this with easier-to-understand code.
-  if ((((int)threadIdx.x & ~(int)64)) == 0) {
-    // threadIdx.x == 0 or 64 (we choose 64 because it's >= the max known warp
-    // size).  These are in separate warps so we can allow them to do separate
-    // jobs.  This code takes linear time in K which is not at all ideal and
-    // could be improved if K is largish, but it shouldn't dominate the total
-    // time taken if we are processing a lot of data; and anyway, we doubt that
-    // K will be need to be more than 4 or 8 or so, so the potential savings are
-    // quite small.
+  if (threadIdx.x == 0) {
     scalar_t scale = exp(params_buf[-1]),
         inv_scale = 1.0 / scale;
-    params_buf[-2] = scale;  // both threads write these but it's OK, it's the
-                             // same value.
+    params_buf[-2] = scale;
     params_buf[-3] = inv_scale;
-    int sign,
-        Koffset;  // Koffset == K for threads handling sum_positive and K - 1
-                  // for threads handling sum_negative, see
-                  // learned_nonlin_cpu.cpp for reference code.  This would be K
-                  // + 1 and K respectively, except our params_buf has its index
-                  // shifted by one versus params.
-    if (threadIdx.x == 0) {  // sum_positive
-      sign = 1;
-      Koffset = K;
-    } else {  // threadIdx.x == 64.  sum_negative.
-      scale *= -1;  // this is a local variable..
-      sign = -1;
-      Koffset = K - 1;
-    }
-    scalar_t sum = 0.0;
+  }
+  __syncthreads();
+  if (threadIdx.x < N) {
+    scalar_t scale = params_buf[-2];
+    params_buf[threadIdx.x] = params_buf[threadIdx.x] * scale;
+  }
+  __syncthreads();
+  // The easiest way to understand this code is to compare it with the CPU code
+  // in learned_nonlin_cpu.cpp.
+  if (threadIdx.x == 0) {
+    scalar_t sum_positive = 0.0;
     for (int i = 0; i < K; i++) {
-      int isign = i * sign;
-      y_vals[K + isign] = sum * scale;
-      sum += params_buf[Koffset + isign];
+      y_vals[K + i] = sum_positive;
+      // versus the CPU code, the params_buf is indexed off by 1; and it already
+      // contains the factor "scale".
+      sum_positive += params_buf[K + i];
    }
-    if (threadIdx.x != 0)  // sum_negative
-      y_vals[0] = sum * scale;
+  } else if (threadIdx.x == 64) {
+    scalar_t sum_negative = 0.0;
+    for (int i = 0; i < K; i++) {
+      y_vals[K - i] = sum_negative;
+      // versus the CPU code, the params_buf is indexed off by 1; and it already
+      // contains the factor "scale".
+      sum_negative -= params_buf[K - 1 - i];
+    }
+    y_vals[0] = sum_negative;
   }
   __syncthreads();
   scalar_t inv_scale = params_buf[-3];
   int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
@@ -176,7 +175,8 @@ void learned_nonlin_kernel(
     else if (x_trunc >= N) x_trunc = N - 1;
     // C++ rounds toward zero.
     int n = (int) x_trunc;
-    // OK, at this point, 0 <= min < N.
+    // OK, at this point, 0 <= min < N.  Versus the CPU code, we removed the
+    // factor of 'scale' because params_buf already has that factor.
     output[b][c][t] = (x - n) * params_buf[n] + y_vals[n];
   }
 }
...
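In the rewritten kernel, every thread with threadIdx.x < N first folds `scale` into params_buf, then thread 0 builds the non-negative half of y_vals and thread 64 the negative half as simple running sums. The serial sketch below reproduces what those two threads compute together; the array names follow the kernel, but the host-side wrapper, sizes and example values are invented for illustration.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // One channel, K = 2 pieces on each side; params[0] = l, params[1..2K] = slopes.
  int K = 2, N = 2 * K;
  std::vector<double> params = {0.0, 1.0, 0.5, 2.0, 0.25};
  double scale = std::exp(params[0]);

  // Mirrors `params_buf[threadIdx.x] = params_buf[threadIdx.x] * scale`:
  // after this, params_buf[i] holds params[i + 1] * scale.
  std::vector<double> params_buf(N);
  for (int i = 0; i < N; i++)
    params_buf[i] = params[i + 1] * scale;

  std::vector<double> y_vals(N, 0.0);
  // What thread 0 does: running sum over the positive-side slopes.
  double sum_positive = 0.0;
  for (int i = 0; i < K; i++) {
    y_vals[K + i] = sum_positive;
    sum_positive += params_buf[K + i];
  }
  // What thread 64 does: running (negated) sum over the negative-side slopes.
  double sum_negative = 0.0;
  for (int i = 0; i < K; i++) {
    y_vals[K - i] = sum_negative;
    sum_negative -= params_buf[K - 1 - i];
  }
  y_vals[0] = sum_negative;

  for (int i = 0; i < N; i++)
    printf("y_vals[%d] = %f\n", i, y_vals[i]);
  return 0;
}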
@@ -76,7 +76,7 @@ def test_learned_nonlin_deriv():
         y2 = learned_nonlin(x + delta_x, params, dim = 1)
         observed_change = (y_deriv * (y2 - y)).sum()
         print(f"for input: pred_change = {pred_change}, observed_change={observed_change}")
-        if not torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05):
+        if not torch.allclose(pred_change, observed_change, rtol=2.0e-02, atol=1.0e-05):
             print(f"For changed input, output differs too much: params={params}, input={x}, mod_input={x+delta_x}, y={y}, y2={y2}, diff={y2-y}")
             assert 0
...
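The test whose tolerance is loosened here compares an analytic directional derivative against an observed finite difference. The same style of check in self-contained toy form (unrelated to learned_nonlin; the function f below is an arbitrary smooth stand-in):

#include <cmath>
#include <cstdio>

// Directional-derivative check: predicted change = grad(f) . delta versus
// observed change = f(x + delta) - f(x).  Everything here is made up.
static double f(double x0, double x1) { return std::sin(x0) * x1 + 0.5 * x1 * x1; }

int main() {
  double x0 = 0.3, x1 = -1.2;
  double g0 = std::cos(x0) * x1,   // df/dx0
         g1 = std::sin(x0) + x1;   // df/dx1
  double d0 = 1e-3, d1 = -2e-3;    // small perturbation
  double pred_change = g0 * d0 + g1 * d1;
  double observed_change = f(x0 + d0, x1 + d1) - f(x0, x1);
  printf("pred_change = %g, observed_change = %g\n", pred_change, observed_change);
  return 0;
}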