Commit 7ce3c947 authored by Daniel Povey

Revise CPU backward code, simpler now.

parent fa53fa33
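
For context before the diff: a minimal standalone sketch (illustration only, not part of this commit) of the forward computation that the revised backward code differentiates. It follows my reading of the code below: per channel, params[0] is a log-scale (scale = exp(params[0])), params[1..2K] are the slopes of 2K linear pieces, and y_vals[n] is the offset of piece n.

// Standalone sketch of the forward computation (illustration only, not part of
// this commit).  Assumed interpretation: per channel, params[0] is a log-scale
// and params[1..2K] are the slopes of 2K linear pieces; y_vals[n] is the offset
// of piece n.
#include <cmath>
#include <vector>

// Offsets y_vals[0..2K-1], following the same recurrence as the loops in the
// diff below.
std::vector<double> compute_y_vals(const std::vector<double> &params, int K) {
  double scale = std::exp(params[0]), sum_positive = 0.0, sum_negative = 0.0;
  std::vector<double> y_vals(2 * K);
  for (int i = 0; i < K; i++) {
    double pos_scaled_param = params[1 + K + i] * scale,
           neg_scaled_param = params[K - i] * scale;
    y_vals[K + i] = sum_positive - pos_scaled_param * i;
    sum_positive += pos_scaled_param;
    sum_negative -= neg_scaled_param;
    y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
  }
  return y_vals;
}

// Forward for one scalar, mirroring the body of learned_nonlin_cpu in the
// diff below.
double learned_nonlin_forward(double input, const std::vector<double> &params,
                              const std::vector<double> &y_vals, int K) {
  const int N = 2 * K;
  double x = input * std::exp(-params[0]) + K;  // inv_scale = exp(-params[0])
  if (x < 0) x = 0;
  else if (x >= N) x = N - 1;
  int n = (int) x;  // piece index; C++ truncation gives 0 <= n < 2K
  return input * params[n + 1] + y_vals[n];
}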
@@ -44,8 +44,6 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
           sum_negative -= neg_scaled_param;
           y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
         }
-        //scalar_t neg_scaled_param = params_a[c][1] * scale;
-        //y_vals_a[c][0] = sum_negative + neg_scaled_param * K;
       }
       auto input_a = input.accessor<scalar_t, 3>(),
@@ -68,10 +66,7 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
         // C++ rounds toward zero.
         int n = (int) x;
         // OK, at this point, 0 <= min < 2*K.
-        scalar_t y = input * params_a[c][n + 1] + y_vals_a[c][n];
-        /* printf("x = %f, y = %f, n = %d; y = (%f - %d) * %f+ %f\n", x, y, n,
-           x, n, params_a[c][n + 1], y_vals_a[c][n - 1]); */
-        output_a[b][c][t] = y;
+        output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
       }
     }
   }}));
@@ -122,17 +117,14 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
           sum_positive = 0.0,
           scale = exp(params_a[c][0]);
       for (int i = 0; i < K; i++) {
-        y_vals_a[c][K + i] = sum_positive;
-        y_vals_a[c][K - i] = sum_negative;
-        sum_positive += params_a[c][1 + K + i] * scale;
-        sum_negative -= params_a[c][K - i] * scale;
+        scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
+            neg_scaled_param = params_a[c][K - i] * scale;
+        y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
+        sum_positive += pos_scaled_param;
+        sum_negative -= neg_scaled_param;
+        y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
       }
-      // the reference point for the lowest, half-infinite interval (the one
-      // starting at x=-(K-1) is x=-K; this is arbitrary but makes the
-      // computation more regular.
-      y_vals_a[c][0] = sum_negative;
     }
     auto input_a = input.accessor<scalar_t, 3>(),
         output_grad_a = output_grad.accessor<scalar_t, 3>(),
         input_grad_a = input_grad.accessor<scalar_t, 3>();
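
As far as I can tell, the offsets produced by this recurrence make adjacent linear pieces agree at the interval boundaries (input = m * scale for integer m), which is what lets the forward and backward passes use input * params_a[c][n + 1] + y_vals_a[c][n] directly instead of a per-interval residual. A quick numerical spot check, reusing the sketch functions from the top of this page and made-up parameter values:

// Continuity check at the interior piece boundaries (illustration only; the
// parameter values are made up).  Uses compute_y_vals / learned_nonlin_forward
// from the sketch near the top of this page.
#include <cstdio>

void check_continuity() {
  const int K = 4;
  std::vector<double> params = {0.25, 0.3, -0.2, 0.5, 1.0, 0.7, -0.1, 0.4, 0.9};
  std::vector<double> y_vals = compute_y_vals(params, K);
  double scale = std::exp(params[0]), eps = 1e-6;
  for (int m = -(K - 1); m <= K - 1; m++) {   // interior boundaries
    double left  = learned_nonlin_forward(m * scale - eps, params, y_vals, K),
           right = learned_nonlin_forward(m * scale + eps, params, y_vals, K);
    // left and right should agree to within about eps * |slope difference|.
    std::printf("boundary %2d: left %.6f right %.6f\n", m, left, right);
  }
}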
@@ -144,57 +136,47 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
           inv_scale_grad = 0.0,
           scale_grad = 0.0;
       for (int t = 0; t < T; t++) {
-        // `x` is the scaled input x plus an offset so that -K maps to 0.
-        // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
-        // so in a sense -K and +K are not special, but we include those
-        // extra values as an easy way to handle the semi-infinite regions
-        // that are < -(K-1) and > (K-1)
         scalar_t input = input_a[b][c][t],
             x = input * inv_scale + K,
-            y_grad = output_grad_a[b][c][t],
-            x_trunc = x;
-        if (x_trunc < 0) x_trunc = 0;
-        else if (x_trunc >= N) x_trunc = N - 1;
+            output_grad = output_grad_a[b][c][t];
+        if (x < 0) x = 0;
+        else if (x >= N) x = N - 1;
         // C++ rounds toward zero.
-        int n = (int) x_trunc;
+        int n = (int) x;
         // OK, at this point, 0 <= n < 2*K.
         // backprop for:
-        //   scalar_t x_residual_scaled = (x - (scalar_t)n) * scale
-        //   scalar_t y = x_residual_scaled * params_a[c][n + 1] + y_vals_a[c][n];
-        scalar_t x_residual_scaled = (x - n) * scale,
-            x_residual_scaled_grad = y_grad * params_a[c][n + 1],
-            x_grad = x_residual_scaled_grad * scale;
-        scale_grad += x_residual_scaled_grad * (x - (scalar_t)n);
-        params_grad_a[c][n + 1] += y_grad * x_residual_scaled;
-        y_vals_grad_a[c][n] += y_grad;
-        // backprop for: x = input * inv_scale + K,
-        inv_scale_grad += x_grad * input;
-        input_grad_a[b][c][t] = x_grad * inv_scale;
+        //   output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
+        params_grad_a[c][n + 1] += output_grad * input;
+        y_vals_grad_a[c][n] += output_grad;
+        input_grad_a[b][c][t] = output_grad * params_a[c][n + 1];
       }
-      // Do the backprop for:
-      //   scale = exp(params_a[c][0]);
-      //   inv_scale = exp(-params_a[c][0]);
-      params_grad_a[c][0] += (scale * scale_grad - inv_scale * inv_scale_grad);
     }
   }
   // Now do the backprop for the loop above where we set y_vals_a.
   for (int c = 0; c < C; c++) {
     scalar_t scale = exp(params_a[c][0]),
         scale_grad = 0.0,
-        sum_negative_grad = y_vals_grad_a[c][0],  // backprop for: y_vals_a[c][0] = sum_negative
+        sum_negative_grad = 0.0,
         sum_positive_grad = 0.0;
     for (int i = K - 1; i >= 0; i--) {
-      // backprop for: sum_negative -= params_a[c][K - i] * scale;
-      params_grad_a[c][K - i] -= sum_negative_grad * scale;
-      // backprop for: sum_positive += params_a[c][1 + K + i] * scale;
-      params_grad_a[c][1 + K + i] += sum_positive_grad * scale;
-      // .. and the contributions to scale_grad for the 2 expressions above..
-      scale_grad += (sum_positive_grad * params_a[c][1 + K + i] -
-                     sum_negative_grad * params_a[c][K - i]);
-      // backprop for: y_vals_a[c][K - i] = sum_negative
-      sum_negative_grad += y_vals_grad_a[c][K - i];
-      // backprop for: y_vals_a[c][K + i] = sum_positive
-      sum_positive_grad += y_vals_grad_a[c][K + i];
+      // Backprop for: y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
+      scalar_t y_grad_neg = y_vals_grad_a[c][K - i - 1];
+      sum_negative_grad += y_grad_neg;
+      scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
+      // Backprop for: sum_negative -= neg_scaled_param;
+      neg_scaled_param_grad -= sum_negative_grad;
+      // Backprop for: sum_positive += pos_scaled_param;
+      scalar_t pos_scaled_param_grad = sum_positive_grad;
+      // Backprop for: y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
+      scalar_t y_grad_pos = y_vals_grad_a[c][K + i];
+      pos_scaled_param_grad -= i * y_grad_pos;
+      sum_positive_grad += y_grad_pos;
+      // Backprop for: pos_scaled_param = params_a[c][1 + K + i] * scale,
+      // and:          neg_scaled_param = params_a[c][K - i] * scale;
+      params_grad_a[c][1 + K + i] += pos_scaled_param_grad * scale;
+      params_grad_a[c][K - i] += neg_scaled_param_grad * scale;
+      scale_grad += (pos_scaled_param_grad * params_a[c][1 + K + i] +
+                     neg_scaled_param_grad * params_a[c][K - i]);
     }
     // Backprop for: scale = exp(params_a[c][0]),
     params_grad_a[c][0] += scale * scale_grad;
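
Finally, a finite-difference spot check of the revised per-element input gradient (input_grad_a[b][c][t] = output_grad * params_a[c][n + 1]): away from the piece boundaries, a central difference of the forward sketch should reproduce params[n + 1]. This again reuses the sketch functions above, with made-up parameter values:

// Finite-difference spot check of d(output)/d(input) for one element
// (illustration only; uses compute_y_vals / learned_nonlin_forward from the
// sketch near the top of this page, and made-up parameter values).
#include <cstdio>

int main() {
  const int K = 4;
  std::vector<double> params = {0.0, 0.3, -0.2, 0.5, 1.0, 0.7, -0.1, 0.4, 0.9};
  std::vector<double> y_vals = compute_y_vals(params, K);
  double input = 1.3, eps = 1e-5;   // well inside a piece, away from boundaries
  double numeric =
      (learned_nonlin_forward(input + eps, params, y_vals, K) -
       learned_nonlin_forward(input - eps, params, y_vals, K)) / (2 * eps);
  int n = (int)(input * std::exp(-params[0]) + K);  // piece index for this input
  double analytic = params[n + 1];  // what the revised backward uses
  std::printf("numeric %.6f vs analytic %.6f\n", numeric, analytic);
  return 0;
}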