"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "4b20ac52d868fcb8c9522df07e18f998f59bae6d"
Commit fa53fa33 authored by Daniel Povey's avatar Daniel Povey
Browse files

Refactor the forward CPU code for greater simplicity

parent f4081496
...@@ -37,16 +37,15 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input, ...@@ -37,16 +37,15 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
sum_positive = 0.0, sum_positive = 0.0,
scale = exp(params_a[c][0]); scale = exp(params_a[c][0]);
for (int i = 0; i < K; i++) { for (int i = 0; i < K; i++) {
y_vals_a[c][K + i] = sum_positive; scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
y_vals_a[c][K - i] = sum_negative; neg_scaled_param = params_a[c][K - i] * scale;
sum_positive += params_a[c][1 + K + i] * scale; y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
sum_negative -= params_a[c][K - i] * scale; sum_positive += pos_scaled_param;
sum_negative -= neg_scaled_param;
y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
} }
// Let the reference point for y_vals_a[c][0] be -K, although the //scalar_t neg_scaled_param = params_a[c][1] * scale;
// interval actually starts at -(K-1). This reference point is //y_vals_a[c][0] = sum_negative + neg_scaled_param * K;
// arbitrary but using it makes our lives easier when processing the
// data.
y_vals_a[c][0] = sum_negative;
} }
auto input_a = input.accessor<scalar_t, 3>(), auto input_a = input.accessor<scalar_t, 3>(),
...@@ -62,14 +61,14 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input, ...@@ -62,14 +61,14 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
// so in a sense -K and +K are not special, but we include those // so in a sense -K and +K are not special, but we include those
// extra values as an easy way to handle the semi-infinite regions // extra values as an easy way to handle the semi-infinite regions
// that are < -(K-1) and > (K-1) // that are < -(K-1) and > (K-1)
scalar_t x = input_a[b][c][t] * inv_scale + K, scalar_t input = input_a[b][c][t],
x_trunc = x; x = input * inv_scale + K;
if (x_trunc < 0) x_trunc = 0; if (x < 0) x = 0;
else if (x_trunc >= N) x_trunc = N - 1; else if (x >= N) x = N - 1;
// C++ rounds toward zero. // C++ rounds toward zero.
int n = (int) x_trunc; int n = (int) x;
// OK, at this point, 0 <= min < 2*K. // OK, at this point, 0 <= min < 2*K.
scalar_t y = (x - n) * scale * params_a[c][n + 1] + y_vals_a[c][n]; scalar_t y = input * params_a[c][n + 1] + y_vals_a[c][n];
/* printf("x = %f, y = %f, n = %d; y = (%f - %d) * %f+ %f\n", x, y, n, /* printf("x = %f, y = %f, n = %d; y = (%f - %d) * %f+ %f\n", x, y, n,
x, n, params_a[c][n + 1], y_vals_a[c][n - 1]); */ x, n, params_a[c][n + 1], y_vals_a[c][n - 1]); */
output_a[b][c][t] = y; output_a[b][c][t] = y;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment