Commit 2ccbb505 authored by Daniel Povey
Browse files

Get simplified version of CUDA forward working

parent 7ce3c947
...@@ -131,10 +131,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input, ...@@ -131,10 +131,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
for (int b = 0; b < B; b++) { for (int b = 0; b < B; b++) {
for (int c = 0; c < C; c++) { for (int c = 0; c < C; c++) {
scalar_t scale = exp(params_a[c][0]), scalar_t inv_scale = exp(-params_a[c][0]);
inv_scale = 1.0 / scale,
inv_scale_grad = 0.0,
scale_grad = 0.0;
for (int t = 0; t < T; t++) { for (int t = 0; t < T; t++) {
scalar_t input = input_a[b][c][t], scalar_t input = input_a[b][c][t],
x = input * inv_scale + K, x = input * inv_scale + K,
......
...@@ -108,50 +108,42 @@ void learned_nonlin_kernel( ...@@ -108,50 +108,42 @@ void learned_nonlin_kernel(
// spaces between here and // spaces between here and
// `params_buf` for storing scale // `params_buf` for storing scale
// and inv_scale and l == params[c][0]. // and inv_scale and l == params[c][0].
*params_buf = (scalar_t*) y_vals + 3 + N; // [N]. Contains params[c][1] * scale through params[c][N] * scale, *params_buf = (scalar_t*) y_vals + 3 + N; // [N]. params_buf[n] contains params[c][n+1].
// params_buf[-1] contains params[c][0] == log of scale; // params_buf[-1] contains params[c][0] == log of scale;
// params_buf[-2] and params_buf[-3] contain scale and inv_scale. // params_buf[-2] contains scale, params_buf[-3]
// contains inv_scale.
// Load parameters // Load parameters
if (threadIdx.x <= N) if (threadIdx.x <= N)
params_buf[threadIdx.x - 1] = params[c][threadIdx.x]; params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
__syncthreads(); __syncthreads();
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
scalar_t scale = exp(params_buf[-1]), scalar_t scale = exp(params_buf[-1]);
inv_scale = 1.0 / scale;
params_buf[-2] = scale; params_buf[-2] = scale;
params_buf[-3] = inv_scale; params_buf[-3] = 1.0 / scale;
} }
__syncthreads(); __syncthreads();
if (threadIdx.x < N) {
scalar_t scale = params_buf[-2];
params_buf[threadIdx.x] = params_buf[threadIdx.x] * scale;
}
__syncthreads();
// The easiest way to understand this code is to compare it with the CPU code // The easiest way to understand this code is to compare it with the CPU code
// in learned_nonlin_cpu.cpp. // in learned_nonlin_cpu.cpp.
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
scalar_t sum_positive = 0.0; scalar_t scale = params_buf[-2],
sum_positive = 0.0;
for (int i = 0; i < K; i++) { for (int i = 0; i < K; i++) {
y_vals[K + i] = sum_positive; // params_buf is indexed with an index one less than params.
// versus the CPU code, the params_buf is indexed off by 1; and it already scalar_t pos_scaled_param = params_buf[K + i] * scale;
// contains the factor "scale". y_vals[K + i] = sum_positive - pos_scaled_param * i;
sum_positive += params_buf[K + i]; sum_positive += pos_scaled_param;
} }
} else if (threadIdx.x == 64) { } else if (threadIdx.x == 64) {
scalar_t sum_negative = 0.0; scalar_t scale = params_buf[-2],
sum_negative = 0.0;
for (int i = 0; i < K; i++) { for (int i = 0; i < K; i++) {
y_vals[K - i] = sum_negative; scalar_t neg_scaled_param = params_buf[K - 1 - i] * scale;
// versus the CPU code, the params_buf is indexed off by 1; and it already sum_negative -= neg_scaled_param;
// contains the factor "scale". y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
sum_negative -= params_buf[K - 1 - i];
} }
y_vals[0] = sum_negative;
} }
__syncthreads(); __syncthreads();
...@@ -169,15 +161,15 @@ void learned_nonlin_kernel( ...@@ -169,15 +161,15 @@ void learned_nonlin_kernel(
// images_per_thread_block > 1 if T * images_per_thread_block <= // images_per_thread_block > 1 if T * images_per_thread_block <=
// THREADS_PER_BLOCK. // THREADS_PER_BLOCK.
for (int t = t_start; t < T; t += THREADS_PER_BLOCK) { for (int t = t_start; t < T; t += THREADS_PER_BLOCK) {
scalar_t x = input[b][c][t] * inv_scale + K, scalar_t this_input = input[b][c][t],
x_trunc = x; x = this_input * inv_scale + K;
if (x_trunc < 0) x_trunc = 0; if (x < 0) x = 0;
else if (x_trunc >= N) x_trunc = N - 1; else if (x >= N) x = N - 1;
// C++ rounds toward zero. // C++ rounds toward zero.
int n = (int) x_trunc; int n = (int) x;
// OK, at this point, 0 <= min < N. Versus the CPU code, we removed the // OK, at this point, 0 <= min < N. Versus the CPU code, we removed the
// factor of 'scale' because params_buf already has that factor. // factor of 'scale' because params_buf already has that factor.
output[b][c][t] = (x - n) * params_buf[n] + y_vals[n]; output[b][c][t] = this_input * params_buf[n] + y_vals[n];
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment