Commit a4466946 authored by Daniel Povey

Deriv working in one case at least..

parent c3e61bea
@@ -361,38 +361,32 @@ void learned_nonlin_backward_kernel(
   // will be set to zero for excess threads, and thus won't contribute to
   // this_params_grad or this_y_vals_grad.
   for (int t_offset = 0; t_offset < T; t_offset += THREADS_PER_BLOCK) {
     // The following is equivalent to:
     //   int t = (threadIdx.x % T_inc) + t_offset;
     // given that T_inc is a power of 2 and t_offset >= THREADS_PER_BLOCK >= T_inc.
     int t = (threadIdx.x & (T_inc - 1)) | t_offset;
-    scalar_t this_output_grad = 0.0;
-    if (t < T)
-      this_output_grad = output_grad[b][c][t];
-    // The reason we use t % T here rather than only invoking this in some
-    // threads, is so that the un-needed threads will have a similar
-    // distribution over 'n' to the needed threads, which will hopefully avoid
-    // excessive work for some particular 'n' value if too many x values had
-    // the same 'n'.  It might be better to set n to an invalid value for
-    // out-of-range threads, but as it is, if we are to properly handle
-    // N==16 we don't have enough bits available in `src_indexes` to do this.
-    scalar_t this_input = input[b][c][t % T] * inv_scale + K;
-    input_buf[threadIdx.x] = this_input;
-    output_grad_buf[threadIdx.x] = this_output_grad;
-    scalar_t x = this_input;
+    scalar_t this_input = 0.0, this_output_grad;
+    if (t < T) {
+      this_output_grad = output_grad[b][c][t];
+      this_input = input[b][c][t];
+      input_buf[threadIdx.x] = this_input;
+      output_grad_buf[threadIdx.x] = this_output_grad;
+    }
+    scalar_t x = this_input * inv_scale + K;
     if (x < 0) x = 0;
     else if (x >= N) x = N - 1;
-    // C++ rounds toward zero.
-    int n = (int)x;
-    n_buf[threadIdx.x] = (char)n;  // 0 <= n < N
     // The forward code did:
     //   output[b][c][t] = this_input * params_buf[n] + y_vals[n];
     // We get the derivative for params and y_vals later.
-    if (t < T)
+    if (t < T) {
+      int n = (int)x;  // C++ rounds toward zero.
+      n_buf[threadIdx.x] = (char)n;
       input_grad[b][c][t] = this_output_grad * params_buf[n];
+    } else {
+      n_buf[threadIdx.x] = 255;
+    }
     int this_block_start = threadIdx.x & ~(N-1),  // == N * (threadIdx.x / N),
                                                   // since N is power of 2
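A minimal host-side sketch of the per-element logic in the hunk above, not part of the commit: each thread is mapped to a time index t with a bit trick, the input is converted to a clamped bucket index n, and params_buf[n] acts as the local slope, so the input gradient is output_grad * params_buf[n]. The function names (bucket_forward, bucket_input_grad) and the toy params / y_vals values are illustrative only, and it assumes, as the new backward code implies, that the scaling by inv_scale and the offset K only affect which bucket is selected.

#include <cassert>
#include <cstdio>

// Host-side sketch; scalar_t is taken to be float, and params / y_vals stand
// in for the shared buffers params_buf / y_vals in the kernel.
float bucket_forward(float input, float inv_scale, int K, int N,
                     const float *params, const float *y_vals) {
  float x = input * inv_scale + K;   // bucket coordinate
  if (x < 0) x = 0;                  // clamp to [0, N-1]
  else if (x >= N) x = N - 1;
  int n = (int)x;                    // C++ rounds toward zero, so 0 <= n < N
  // Roughly the forward computation described in the comment above:
  return input * params[n] + y_vals[n];
}

// For a fixed bucket n the output is linear in the input with slope params[n],
// which is why the backward hunk writes
//   input_grad[b][c][t] = this_output_grad * params_buf[n];
float bucket_input_grad(float output_grad, float input, float inv_scale,
                        int K, int N, const float *params) {
  float x = input * inv_scale + K;
  if (x < 0) x = 0;
  else if (x >= N) x = N - 1;
  return output_grad * params[(int)x];
}

int main() {
  // The index trick at the top of the loop:
  //   (i & (T_inc - 1)) | t_offset  ==  (i % T_inc) + t_offset
  // holds because T_inc is a power of 2 and t_offset is a multiple of
  // THREADS_PER_BLOCK, itself a power of 2 that is >= T_inc, so the low bits
  // that the OR touches are all zero in t_offset.
  const int T_inc = 8, THREADS_PER_BLOCK = 32;
  for (int t_offset = 0; t_offset < 4 * THREADS_PER_BLOCK; t_offset += THREADS_PER_BLOCK)
    for (int i = 0; i < THREADS_PER_BLOCK; i++)
      assert(((i & (T_inc - 1)) | t_offset) == (i % T_inc) + t_offset);

  const int N = 4, K = 2;
  float params[N] = {0.5f, 1.0f, 2.0f, 0.25f};
  float y_vals[N] = {0.0f, 0.1f, 0.2f, 0.3f};
  float y = bucket_forward(-0.7f, /*inv_scale=*/1.0f, K, N, params, y_vals);
  float g = bucket_input_grad(1.0f, -0.7f, 1.0f, K, N, params);
  printf("y = %f, dy/dinput = %f\n", y, g);
  return 0;
}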
@@ -465,9 +459,8 @@ void learned_nonlin_backward_kernel(
       }
       // TODO: remove the next lines
-      assert(n_buf[threadIdx.x] == 0);
+      assert(n_buf[threadIdx.x] == 0 || (unsigned char)n_buf[threadIdx.x] == 255);
       output_grad_buf[threadIdx.x] = 0.0;
     }
   }
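The widened assertion reflects the change in the first hunk: n_buf now holds (char)255 for threads with t >= T. Since plain char is signed on most targets, that sentinel reads back as -1, so the check has to cast to unsigned char before comparing with 255. A minimal illustration, assuming the element type of n_buf is char, as the casts in the kernel suggest:

#include <cassert>
#include <cstdio>

int main() {
  // What the backward kernel now stores for an out-of-range thread:
  char sentinel = (char)255;
  // On platforms where plain char is signed (the common case) this holds -1,
  // so a direct comparison with 255 would be false; casting back to
  // unsigned char recovers 255, which is what the updated assert does.
  printf("as plain char: %d, as unsigned char: %d\n",
         (int)sentinel, (int)(unsigned char)sentinel);
  assert((unsigned char)sentinel == 255);
  return 0;
}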
@@ -90,7 +90,7 @@ def test_learned_nonlin_deriv():
         y2 = learned_nonlin(x + delta_x, params, dim = 1)
         observed_change = (y_deriv * (y2 - y)).sum()
         print(f"for input: pred_change = {pred_change}, observed_change={observed_change}")
-        if not torch.allclose(pred_change, observed_change, rtol=2.0e-02, atol=1.0e-05):
+        if not torch.allclose(pred_change, observed_change, rtol=2.0e-02, atol=3.0e-05):
             print(f"For changed input, output differs too much: params={params}, input={x}, mod_input={x+delta_x}, y={y}, y2={y2}, diff={y2-y}")
             assert 0
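The commit only loosens atol from 1.0e-05 to 3.0e-05; the structure of the test is unchanged: the gradient from the backward pass predicts how the (y_deriv-weighted) output should move under a small input perturbation, and that prediction is compared against the actual difference of two forward passes. A scalar sketch of the same finite-difference check, with an illustrative function but the same tolerances:

#include <cstdio>
#include <cmath>

// Scalar analogue of the check in test_learned_nonlin_deriv(): if g is the
// gradient returned by the backward pass, then for a small perturbation
// delta_x the predicted change g * delta_x should match the observed change
// f(x + delta_x) - f(x) to within loose tolerances.  The function f here is
// illustrative, not the learned nonlinearity itself.
static double f(double x) { return 0.5 * x * x + 0.25 * x; }
static double f_grad(double x) { return x + 0.25; }

int main() {
  double x = 0.8, delta_x = 1e-3;
  double pred_change = f_grad(x) * delta_x;        // analogue of (x.grad * delta_x).sum()
  double observed_change = f(x + delta_x) - f(x);  // analogue of (y_deriv * (y2 - y)).sum()
  printf("pred_change = %g, observed_change = %g\n", pred_change, observed_change);
  // Mirrors torch.allclose(pred, observed, rtol=2e-2, atol=3e-5):
  //   |pred - observed| <= atol + rtol * |observed|
  double rtol = 2.0e-2, atol = 3.0e-5;
  return std::fabs(pred_change - observed_change)
             <= atol + rtol * std::fabs(observed_change) ? 0 : 1;
}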