Commit 06e369c9 authored by Daniel Povey

Make the loop a bit simpler

parent 97f49591
@@ -36,16 +36,17 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
       scalar_t sum_negative = 0.0,
           sum_positive = 0.0,
           scale = exp(params_a[c][0]);
       for (int i = 0; i < K; i++) {
-        y_vals_a[c][K + i] = sum_positive * scale;
-        y_vals_a[c][K - i] = sum_negative * scale;
-        sum_positive += params_a[c][1 + K + i];
-        sum_negative -= params_a[c][K - i];
+        y_vals_a[c][K + i] = sum_positive;
+        y_vals_a[c][K - i] = sum_negative;
+        sum_positive += params_a[c][1 + K + i] * scale;
+        sum_negative -= params_a[c][K - i] * scale;
       }
-      // the reference point for the lowest, half-infinite interval (the one
-      // starting at x=-(K-1) is still x=-(K-1); this value is repeated in y_vals.
-      y_vals_a[c][0] = y_vals_a[c][1];
+      // Let the reference point for y_vals_a[c][0] be -K, although the
+      // interval actually starts at -(K-1).  This reference point is
+      // arbitrary but using it makes our lives easier when processing the
+      // data.
+      y_vals_a[c][0] = sum_negative;
     }
   auto input_a = input.accessor<scalar_t, 3>(),
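
To see concretely what the new reference-point convention buys, here is a small standalone sketch (hypothetical names such as build_y_vals, the per-channel scale factor omitted for clarity; this is not the repository's code) that fills a y_vals table the way the new loop does and checks that every entry plus its interval's slope reproduces the next entry, i.e. that the piecewise-linear function is continuous across all interval boundaries, including the boundary of interval 0 that previously needed a special case:

#include <cstdio>
#include <vector>

// Hypothetical sketch: build the y_vals table for one channel from 2*K
// per-interval slopes, following the new convention: entry n holds the value
// at the left end of interval n, and entry 0 uses reference point -K even
// though clamping means interval 0 effectively begins at -(K-1).
std::vector<double> build_y_vals(const std::vector<double> &slopes, int K) {
  std::vector<double> y_vals(2 * K);
  double sum_negative = 0.0, sum_positive = 0.0;
  for (int i = 0; i < K; i++) {
    y_vals[K + i] = sum_positive;       // left end of interval K + i
    y_vals[K - i] = sum_negative;       // left end of interval K - i
    sum_positive += slopes[K + i];      // step one interval to the right
    sum_negative -= slopes[K - i - 1];  // step one interval to the left
  }
  y_vals[0] = sum_negative;  // reference point -K for the lowest interval
  return y_vals;
}

int main() {
  int K = 4;
  std::vector<double> slopes = {0.5, 1.0, 2.0, 1.5, 1.0, 0.5, 0.25, 2.0};
  std::vector<double> y = build_y_vals(slopes, K);
  // Continuity check: y[n] + slopes[n] should equal y[n + 1] for every n,
  // which would fail at n == 0 under the old convention y[0] = y[1].
  for (int n = 0; n + 1 < 2 * K; n++)
    std::printf("n=%d: y[n]+slopes[n]=%g, y[n+1]=%g\n",
                n, y[n] + slopes[n], y[n + 1]);
  return 0;
}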
@@ -66,11 +67,8 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
         else if (x_trunc >= N) x_trunc = N - 1;
         // C++ rounds toward zero.
         int n = (int) x_trunc;
-        // reference point for the lowest linear region is -(K-1), not -K; this is
-        // why we have to treat n == 0 separately.
-        scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
         // OK, at this point, 0 <= n < 2*K.
-        scalar_t y = (x - x_rounded) * params_a[c][n + 1] + y_vals_a[c][n];
+        scalar_t y = (x - n) * params_a[c][n + 1] + y_vals_a[c][n];
         /* printf("x = %f, y = %f, n = %d; y = (%f - %d) * %f + %f\n", x, y, n,
            x, n, params_a[c][n + 1], y_vals_a[c][n - 1]); */
         output_a[b][c][t] = y;
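
After that change the per-element evaluation is a single clamp / truncate / interpolate step with no n == 0 special case. A minimal sketch of that path under the same simplifying assumptions (hypothetical names, the per-channel scale entering only through inv_scale; this is not the repository's code):

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical per-element forward step: map the input into interval
// coordinates x = input * inv_scale + K, clamp the truncated value to
// [0, 2K), and interpolate from the left end of interval n.  With y_vals[0]
// anchored at -K this formula is valid for every n, including n == 0.
double piecewise_forward(double input, double inv_scale, int K,
                         const std::vector<double> &slopes,   // size 2*K
                         const std::vector<double> &y_vals) { // size 2*K
  int N = 2 * K;
  double x = input * inv_scale + K;
  double x_trunc = std::max(0.0, std::min(x, (double)(N - 1)));
  int n = (int)x_trunc;  // truncation toward zero, as in the C++ cast above
  return (x - n) * slopes[n] + y_vals[n];
}

int main() {
  int K = 2;
  std::vector<double> slopes = {1.0, 2.0, 0.5, 3.0};
  // y_vals as the previous sketch would build it for these slopes; note that
  // y_vals[0] sits one full slope step below y_vals[1].
  std::vector<double> y_vals = {-3.0, -2.0, 0.0, 0.5};
  const double inputs[] = {-2.5, -1.5, -0.3, 0.4, 1.7};
  for (double input : inputs)
    std::printf("f(%5.2f) = %g\n",
                input, piecewise_forward(input, 1.0, K, slopes, y_vals));
  return 0;
}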
@@ -130,8 +128,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
         sum_negative -= params_a[c][K - i] * scale;
       }
       // the reference point for the lowest, half-infinite interval (the one
-      // starting at x=-(K-1) is still x=-(K-1); this value is repeated in y_vals.
-      y_vals_a[c][0] = y_vals_a[c][1];
+      // starting at x=-(K-1) is x=-K; this is arbitrary but makes the
+      // computation more regular.
+      y_vals_a[c][0] = sum_negative;
     }
   auto input_a = input.accessor<scalar_t, 3>(),
@@ -156,13 +155,11 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
         else if (x_trunc >= N) x_trunc = N - 1;
         // C++ rounds toward zero.
         int n = (int) x_trunc;
-        scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
         // OK, at this point, 0 <= n < 2*K.
         // backprop for:
         // scalar_t y = (x - (scalar_t)n) * params_a[c][n + 1] + y_vals_a[c][n];
         scalar_t x_grad = y_grad * params_a[c][n + 1];
-        params_grad_a[c][n + 1] += y_grad * x_rounded;
+        params_grad_a[c][n + 1] += y_grad * (x - (scalar_t)n);
         y_vals_grad_a[c][n] += y_grad;
         // backprop for: x = input * inv_scale + K,
         inv_scale_grad += x_grad * input;
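
For readers tracing this hunk, the gradients follow directly from the forward expression quoted in the comment, y = (x - n) * params_a[c][n + 1] + y_vals_a[c][n] with x = input * inv_scale + K: the partial derivative of y with respect to params_a[c][n + 1] is (x - n), which is what the updated params_grad_a line accumulates; the partial with respect to y_vals_a[c][n] is 1, giving the y_vals_grad_a update; and since x depends on inv_scale through input * inv_scale, inv_scale_grad accumulates x_grad * input.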
@@ -174,27 +171,22 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
       }
       // Now do the backprop for the loop above where we set y_vals_a.
       for (int c = 0; c < C; c++) {
-        // backprop for: y_vals_a[c][0] = y_vals_a[c][1];
-        y_vals_grad_a[c][1] += y_vals_grad_a[c][0];
         scalar_t scale = exp(params_a[c][0]),
-            inv_scale = 1.0 / scale,
             scale_grad = 0.0,
-            sum_negative_grad = 0.0,
+            sum_negative_grad = y_vals_grad_a[c][0],  // backprop for: y_vals_a[c][0] = sum_negative
             sum_positive_grad = 0.0;
         for (int i = K - 1; i >= 0; i--) {
-          // backprop for: sum_negative -= params_a[c][K - i];
-          params_grad_a[c][K - i] -= sum_negative_grad;
+          // backprop for: sum_negative -= params_a[c][K - i] * scale;
+          params_grad_a[c][K - i] -= sum_negative_grad * scale;
           // backprop for: sum_positive += params_a[c][1 + K + i] * scale;
-          params_grad_a[c][1 + K + i] += sum_positive_grad;
-          // backprop for: y_vals_a[c][K - i] = sum_negative * scale;
-          sum_negative_grad += y_vals_grad_a[c][K - i] * scale;
-          // The next code line is equivalent to:
-          // scale_grad += y_vals_grad_a[c][K - i] * sum_negative, substituting:
-          // sum_negative == y_vals_a[c][K - i] / scale
-          scale_grad += y_vals_grad_a[c][K - i] * y_vals_a[c][K - i] * inv_scale;
-          // backprop for: y_vals_a[c][K + i] = sum_positive * scale;
-          sum_positive_grad += y_vals_grad_a[c][K + i] * scale;
-          scale_grad += y_vals_grad_a[c][K + i] * y_vals_a[c][K + i] * inv_scale;
+          params_grad_a[c][1 + K + i] += sum_positive_grad * scale;
+          // .. and the contributions to scale_grad for the 2 expressions above..
+          scale_grad += (sum_positive_grad * params_a[c][1 + K + i] -
+                         sum_negative_grad * params_a[c][K - i]);
+          // backprop for: y_vals_a[c][K - i] = sum_negative
+          sum_negative_grad += y_vals_grad_a[c][K - i];
+          // backprop for: y_vals_a[c][K + i] = sum_positive
+          sum_positive_grad += y_vals_grad_a[c][K + i];
         }
         // Backprop for: scale = exp(params_a[c][0]),
         params_grad_a[c][0] += scale * scale_grad;
...
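
The rewritten loop above is reverse-mode differentiation through the prefix-sum loop that builds y_vals. A self-contained sketch of the same pattern for a single running sum may make the ordering clearer (hypothetical names, checked against a finite difference; this is not the repository's code): visiting the loop in reverse, the update sum += p[i] * scale is undone first, contributing sum_grad * scale to p_grad[i] and sum_grad * p[i] to scale_grad, and only then does y_grad[i] flow into sum_grad, which mirrors the order of the statements in the new code.

#include <cstdio>
#include <vector>

int main() {
  std::vector<double> p = {0.5, 1.5, 2.0, 0.25};
  double scale = 1.3;
  int K = (int)p.size();

  // Forward pass: y[i] is the running sum *before* p[i] * scale is added,
  // just as y_vals_a[c][K + i] is written before sum_positive is updated.
  std::vector<double> y(K);
  double sum = 0.0;
  for (int i = 0; i < K; i++) {
    y[i] = sum;
    sum += p[i] * scale;
  }

  // Pretend the loss is L = sum_i y_grad[i] * y[i].
  std::vector<double> y_grad = {1.0, -2.0, 0.5, 3.0};

  // Reverse pass, analogous to the loop in learned_nonlin_backward_cpu.
  std::vector<double> p_grad(K, 0.0);
  double scale_grad = 0.0, sum_grad = 0.0;
  for (int i = K - 1; i >= 0; i--) {
    // backprop for: sum += p[i] * scale;
    p_grad[i] += sum_grad * scale;
    scale_grad += sum_grad * p[i];
    // backprop for: y[i] = sum;
    sum_grad += y_grad[i];
  }

  // Finite-difference check of scale_grad.
  double eps = 1e-6, L_plus = 0.0, L_minus = 0.0;
  for (int sign = 0; sign < 2; sign++) {
    double s = 0.0, L = 0.0, scale2 = scale + (sign == 0 ? eps : -eps);
    for (int i = 0; i < K; i++) { L += y_grad[i] * s; s += p[i] * scale2; }
    (sign == 0 ? L_plus : L_minus) = L;
  }
  std::printf("analytic scale_grad = %g, numeric = %g\n",
              scale_grad, (L_plus - L_minus) / (2 * eps));
  return 0;
}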
@@ -148,7 +148,8 @@ void learned_nonlin_kernel(
         y_vals[K + isign] = sum * scale;
         sum += params_buf[Koffset + isign];
       }
-      y_vals[0] = y_vals[1];  // Both threads do this but it's OK.
+      if (threadIdx.x != 0)  // sum_negative
+        y_vals[0] = sum * scale;
     }
     __syncthreads();
     scalar_t inv_scale = params_buf[-3];
@@ -171,9 +172,8 @@ void learned_nonlin_kernel(
       else if (x_trunc >= N) x_trunc = N - 1;
       // C++ rounds toward zero.
       int n = (int) x_trunc;
-      scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
       // OK, at this point, 0 <= n < 2*K.
-      scalar_t y = (x - x_rounded) * params_buf[n] + y_vals[n];
+      scalar_t y = (x - n) * params_buf[n] + y_vals[n];
       output[b][c][t] = y;
     }
   }
...