Commit 97f49591 authored by Daniel Povey

Refactoring to use integer rounding; not 100 percent sure this is working

parent 53c52678
@@ -60,18 +60,19 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
// so in a sense -K and +K are not special, but we include those
// extra values as an easy way to handle the semi-infinite regions
// that are < -(K-1) and > (K-1)
- scalar_t x = input_a[b][c][t] * inv_scale + K;
- int min = 0, diff = K;
- while (diff > 0) {
- int mid = min + diff;
- if (x >= mid)
- min = mid;
- diff = diff >> 1;
- }
+ scalar_t x = input_a[b][c][t] * inv_scale + K,
+ x_trunc = x;
+ if (x_trunc < 0) x_trunc = 0;
+ else if (x_trunc >= N) x_trunc = N - 1;
+ // C++ rounds toward zero.
+ int n = (int) x_trunc;
+ // reference point for the lowest linear region is -(K-1), not -K; this is
+ // why we have to treat n == 0 separately.
+ scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
// OK, at this point, 0 <= min < 2*K.
- scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
- // printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f+ %f\n", x, y, min,
- // x, min, params_a[c][min + 1], y_vals_a[c][min - 1]);
+ scalar_t y = (x - x_rounded) * params_a[c][n + 1] + y_vals_a[c][n];
+ /* printf("x = %f, y = %f, n = %d; y = (%f - %f) * %f + %f\n", x, y, n,
+ x, x_rounded, params_a[c][n + 1], y_vals_a[c][n]); */
output_a[b][c][t] = y;
}
}
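For reference, the new code path replaces the per-element binary search over the 2*K linear pieces with a single clamp-and-truncate. A hedged PyTorch sketch of that forward computation is below; it is not part of the commit, and the names `slopes`, `y_vals`, `inv_scale` and the (B, C, T) / (C, 2*K) shapes are assumptions standing in for `params_a[c][n + 1]`, `y_vals_a[c][n]` and the per-channel scale. Only the clamp/truncate lookup itself mirrors the C++ above.

import torch

# Hedged reference sketch of the truncation-based lookup above (names and
# shapes are assumptions, not the extension's API).
def piecewise_linear_forward(x, inv_scale, slopes, y_vals, K):
    # x: (B, C, T); inv_scale: (C,); slopes, y_vals: (C, 2*K)
    N = 2 * K
    xs = x * inv_scale.view(1, -1, 1) + K            # shift so the pieces cover [0, N)
    n = xs.clamp(min=0, max=N - 1).to(torch.int64)   # cast truncates toward zero, like (int) in C++
    # The reference point of the lowest region is -(K-1), i.e. shifted value 1,
    # not 0, which is why n == 0 is treated specially.
    x_ref = torch.where(n == 0, torch.ones_like(xs), n.to(xs.dtype))
    slope = torch.gather(slopes.unsqueeze(0).expand(x.shape[0], -1, -1), 2, n)
    offset = torch.gather(y_vals.unsqueeze(0).expand(x.shape[0], -1, -1), 2, n)
    return (xs - x_ref) * slope + offset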
@@ -149,20 +150,20 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
// that are < -(K-1) and > (K-1)
scalar_t input = input_a[b][c][t],
x = input * inv_scale + K,
- y_grad = output_grad_a[b][c][t];
- int min = 0, diff = K;
- while (diff > 0) {
- int mid = min + diff;
- if (x >= mid)
- min = mid;
- diff = diff >> 1;
- }
- // OK, at this point, 0 <= min < 2*K.
+ y_grad = output_grad_a[b][c][t],
+ x_trunc = x;
+ if (x_trunc < 0) x_trunc = 0;
+ else if (x_trunc >= N) x_trunc = N - 1;
+ // C++ rounds toward zero.
+ int n = (int) x_trunc;
+ scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
+ // OK, at this point, 0 <= n < 2*K.
// backprop for:
- // scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
- scalar_t x_grad = y_grad * params_a[c][min + 1];
- params_grad_a[c][min + 1] += y_grad * (x - (scalar_t)min);
- y_vals_grad_a[c][min] += y_grad;
+ // scalar_t y = (x - x_rounded) * params_a[c][n + 1] + y_vals_a[c][n];
+ scalar_t x_grad = y_grad * params_a[c][n + 1];
+ params_grad_a[c][n + 1] += y_grad * (x - x_rounded);
+ y_vals_grad_a[c][n] += y_grad;
// backprop for: x = input * inv_scale + K,
inv_scale_grad += x_grad * input;
input_grad_a[b][c][t] = x_grad * inv_scale;
......
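The gradients in the new backward code follow directly from the forward expression y = (x - x_rounded) * params_a[c][n + 1] + y_vals_a[c][n] with x = input * inv_scale + K: the slope gradient picks up a factor (x - x_rounded), the offset gradient is just y_grad, and the input and inv_scale gradients come from the chain rule through x. A hedged per-element Python sketch of that backward step (the *_c names are hypothetical per-channel arrays, not the extension's API):

# Hedged per-element sketch of the backward step above; slopes_c and the
# *_grad_c accumulators are hypothetical per-channel 1-D arrays.
def backward_one_element(input_val, y_grad, inv_scale_c, slopes_c,
                         slopes_grad_c, y_vals_grad_c, K):
    N = 2 * K
    x = input_val * inv_scale_c + K
    n = int(min(max(x, 0.0), N - 1))              # truncation toward zero, as in the C++
    x_rounded = 1.0 if n == 0 else float(n)
    # forward was: y = (x - x_rounded) * slopes_c[n] + y_vals_c[n]
    x_grad = y_grad * slopes_c[n]                 # dy/dx
    slopes_grad_c[n] += y_grad * (x - x_rounded)  # dy/d(slope) = x - x_rounded
    y_vals_grad_c[n] += y_grad                    # dy/d(offset) = 1
    # backprop through x = input * inv_scale + K
    inv_scale_grad = x_grad * input_val
    input_grad = x_grad * inv_scale_c
    return input_grad, inv_scale_grad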
@@ -165,16 +165,15 @@ void learned_nonlin_kernel(
// images_per_thread_block > 1 if T * images_per_thread_block <=
// THREADS_PER_BLOCK.
for (int t = t_start; t < T; t += THREADS_PER_BLOCK) {
- scalar_t x = input[b][c][t] * inv_scale + K;
- int min = 0, diff = K;
- while (diff > 0) {
- int mid = min + diff;
- if (x >= mid)
- min = mid;
- diff = diff >> 1;
- }
+ scalar_t x = input[b][c][t] * inv_scale + K,
+ x_trunc = x;
+ if (x_trunc < 0) x_trunc = 0;
+ else if (x_trunc >= N) x_trunc = N - 1;
+ // C++ rounds toward zero.
+ int n = (int) x_trunc;
+ scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
// OK, at this point, 0 <= min < 2*K.
- scalar_t y = (x - (scalar_t)min) * params_buf[min] + y_vals[min];
+ scalar_t y = (x - x_rounded) * params_buf[n] + y_vals[n];
output[b][c][t] = y;
}
}
......
@@ -64,18 +64,21 @@ def test_learned_nonlin_deriv():
y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu'))
print("Checking CUDA is same")
if not torch.allclose(y, y2, atol=1.0e-06):
print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}, max-diff = {(y2-y).abs().max()}")
assert(0)
- y_deriv = torch.rand_like(y)
+ y_deriv = torch.randn_like(y)
y.backward(gradient=y_deriv)
delta = 1.0e-04
delta_x = torch.randn_like(x) * delta
pred_change = (x.grad * delta_x).sum()
- observed_change = (y_deriv * (learned_nonlin(x + delta_x, params, dim = 1) - y)).sum()
+ y2 = learned_nonlin(x + delta_x, params, dim = 1)
+ observed_change = (y_deriv * (y2 - y)).sum()
print(f"for input: pred_change = {pred_change}, observed_change={observed_change}")
- assert torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05)
+ if not torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05):
+ print(f"For changed input, output differs too much: params={params}, input={x}, mod_input={x+delta_x}, y={y}, y2={y2}, diff={y2-y}")
+ assert 0
delta_params = torch.randn_like(params) * delta
pred_change = (params.grad * delta_params).sum()
......
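The test hunk relies on a standard first-order (directional-derivative) check: for a small random perturbation delta_x, (x.grad * delta_x).sum() should approximate (y_deriv * (f(x + delta_x) - f(x))).sum(), because backward with gradient=y_deriv leaves x.grad = J^T y_deriv for the Jacobian J of f at x. A hedged generic version of that pattern (the helper name is made up, not part of the test file):

import torch

def finite_difference_check(f, x, delta=1.0e-04, rtol=1.0e-02, atol=1.0e-05):
    # Generic sketch of the directional-derivative check used in the test above.
    x = x.detach().clone().requires_grad_(True)
    y = f(x)
    y_deriv = torch.randn_like(y)
    y.backward(gradient=y_deriv)
    delta_x = torch.randn_like(x) * delta
    pred_change = (x.grad * delta_x).sum()                    # first-order prediction
    observed_change = (y_deriv * (f(x + delta_x) - y)).sum()  # measured change
    return torch.allclose(pred_change, observed_change, rtol=rtol, atol=atol)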