Commit 97f49591 authored by Daniel Povey
Browse files

Refactoring using integer rounding, not 100 percent sure this is working

parent 53c52678
...@@ -60,18 +60,19 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input, ...@@ -60,18 +60,19 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
// so in a sense -K and +K are not special, but we include those // so in a sense -K and +K are not special, but we include those
// extra values as an easy way to handle the semi-infinite regions // extra values as an easy way to handle the semi-infinite regions
// that are < -(K-1) and > (K-1) // that are < -(K-1) and > (K-1)
scalar_t x = input_a[b][c][t] * inv_scale + K; scalar_t x = input_a[b][c][t] * inv_scale + K,
int min = 0, diff = K; x_trunc = x;
while (diff > 0) { if (x_trunc < 0) x_trunc = 0;
int mid = min + diff; else if (x_trunc >= N) x_trunc = N - 1;
if (x >= mid) // C++ rounds toward zero.
min = mid; int n = (int) x_trunc;
diff = diff >> 1; // reference point for the lowest linear region is -(K-1), not -K; this is
} // why we have to treat n == 0 separately.
scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
// OK, at this point, 0 <= min < 2*K. // OK, at this point, 0 <= min < 2*K.
scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min]; scalar_t y = (x - x_rounded) * params_a[c][n + 1] + y_vals_a[c][n];
// printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f+ %f\n", x, y, min, /* printf("x = %f, y = %f, n = %d; y = (%f - %d) * %f+ %f\n", x, y, n,
// x, min, params_a[c][min + 1], y_vals_a[c][min - 1]); x, n, params_a[c][n + 1], y_vals_a[c][n - 1]); */
output_a[b][c][t] = y; output_a[b][c][t] = y;
} }
} }
...@@ -149,20 +150,20 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input, ...@@ -149,20 +150,20 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
// that are < -(K-1) and > (K-1) // that are < -(K-1) and > (K-1)
scalar_t input = input_a[b][c][t], scalar_t input = input_a[b][c][t],
x = input * inv_scale + K, x = input * inv_scale + K,
y_grad = output_grad_a[b][c][t]; y_grad = output_grad_a[b][c][t],
int min = 0, diff = K; x_trunc = x;
while (diff > 0) { if (x_trunc < 0) x_trunc = 0;
int mid = min + diff; else if (x_trunc >= N) x_trunc = N - 1;
if (x >= mid) // C++ rounds toward zero.
min = mid; int n = (int) x_trunc;
diff = diff >> 1; scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
}
// OK, at this point, 0 <= min < 2*K. // OK, at this point, 0 <= n < 2*K.
// backprop for: // backprop for:
// scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min]; // scalar_t y = (x - (scalar_t)n) * params_a[c][n + 1] + y_vals_a[c][n];
scalar_t x_grad = y_grad * params_a[c][min + 1]; scalar_t x_grad = y_grad * params_a[c][n + 1];
params_grad_a[c][min + 1] += y_grad * (x - (scalar_t)min); params_grad_a[c][n + 1] += y_grad * x_rounded;
y_vals_grad_a[c][min] += y_grad; y_vals_grad_a[c][n] += y_grad;
// backprop for: x = input * inv_scale + K, // backprop for: x = input * inv_scale + K,
inv_scale_grad += x_grad * input; inv_scale_grad += x_grad * input;
input_grad_a[b][c][t] = x_grad * inv_scale; input_grad_a[b][c][t] = x_grad * inv_scale;
......
...@@ -165,16 +165,15 @@ void learned_nonlin_kernel( ...@@ -165,16 +165,15 @@ void learned_nonlin_kernel(
// images_per_thread_block > 1 if T * images_per_thread_block <= // images_per_thread_block > 1 if T * images_per_thread_block <=
// THREADS_PER_BLOCK. // THREADS_PER_BLOCK.
for (int t = t_start; t < T; t += THREADS_PER_BLOCK) { for (int t = t_start; t < T; t += THREADS_PER_BLOCK) {
scalar_t x = input[b][c][t] * inv_scale + K; scalar_t x = input[b][c][t] * inv_scale + K,
int min = 0, diff = K; x_trunc = x;
while (diff > 0) { if (x_trunc < 0) x_trunc = 0;
int mid = min + diff; else if (x_trunc >= N) x_trunc = N - 1;
if (x >= mid) // C++ rounds toward zero.
min = mid; int n = (int) x_trunc;
diff = diff >> 1; scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
}
// OK, at this point, 0 <= min < 2*K. // OK, at this point, 0 <= min < 2*K.
scalar_t y = (x - (scalar_t)min) * params_buf[min] + y_vals[min]; scalar_t y = (x - x_rounded) * params_buf[n] + y_vals[n];
output[b][c][t] = y; output[b][c][t] = y;
} }
} }
......
...@@ -64,18 +64,21 @@ def test_learned_nonlin_deriv(): ...@@ -64,18 +64,21 @@ def test_learned_nonlin_deriv():
y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu')) y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu'))
print("Checking CUDA is same") print("Checking CUDA is same")
if not torch.allclose(y, y2, atol=1.0e-06): if not torch.allclose(y, y2, atol=1.0e-06):
print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}") print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}, max-diff = {(y2-y).abs().max()}")
assert(0) assert(0)
y_deriv = torch.rand_like(y) y_deriv = torch.randn_like(y)
y.backward(gradient=y_deriv) y.backward(gradient=y_deriv)
delta = 1.0e-04 delta = 1.0e-04
delta_x = torch.randn_like(x) * delta delta_x = torch.randn_like(x) * delta
pred_change = (x.grad * delta_x).sum() pred_change = (x.grad * delta_x).sum()
observed_change = (y_deriv * (learned_nonlin(x + delta_x, params, dim = 1) - y)).sum() y2 = learned_nonlin(x + delta_x, params, dim = 1)
observed_change = (y_deriv * (y2 - y)).sum()
print(f"for input: pred_change = {pred_change}, observed_change={observed_change}") print(f"for input: pred_change = {pred_change}, observed_change={observed_change}")
assert torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05) if not torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05):
print(f"For changed input, output differs too much: params={params}, input={x}, mod_input={x+delta_x}, y={y}, y2={y2}, diff={y2-y}")
assert 0
delta_params = torch.randn_like(params) * delta delta_params = torch.randn_like(params) * delta
pred_change = (params.grad * delta_params).sum() pred_change = (params.grad * delta_params).sum()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment