Commit 1dc0d559 authored by Daniel Povey

Make the forward code for CPU/CUDA give the same result

parent 8f096574
......@@ -61,6 +61,14 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
of the linear parts of the piecewise linear function.
The discontinuities of the function are at:
exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
output: The transformed input, shape (B, C, T)
images_per_thread_block: The number of images processed by each thread
block. The calling code must guarantee that this is a power
of 2, and that EITHER:
THREADS_PER_BLOCK / images_per_thread_block >= T
OR
images_per_thread_block == 1
.. this is used for a small optimization.
This kernel is allocated with `extern_buf` containing enough memory
to store 2*N + 3 values of type scalar_t.
......@@ -74,8 +82,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
When we invoke this kernel, we'll invoke it as:
learned_nonlin_forward<<<gridDim, blockDim, bytesShared, stream>>>
where bytesShared is the number of bytes needed in `extern_buf`:
bytesShared = sizeof(shared_t) * numel, where
numel = 2 * (kH * kW) + max(blockDim.x, (opatchH + kH - 1) * (patchW + kW - 1))
bytesShared = sizeof(shared_t) * (2N + 3)
*/
extern __shared__ int extern_buf[];
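For orientation, the following is a rough host-side C++ sketch of what this kernel computes, reconstructed from the code in this diff; the function name, flat-array layout and use of std::vector are illustrative assumptions rather than the actual learned_nonlin_cpu.cpp reference.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Sketch of the forward computation (not the real CPU reference code).
// input: [B][C][T] flattened; params: [C][N + 1] flattened; output: [B][C][T].
// params[c][0] is the log of the scale; params[c][1..N] are the slopes.
void learned_nonlin_ref(const std::vector<float> &input,
                        const std::vector<float> &params,
                        std::vector<float> &output,
                        int B, int C, int T, int N) {
  int K = N / 2;  // N and K are powers of 2, K >= 1.
  for (int c = 0; c < C; c++) {
    const float *p = &params[c * (N + 1)];
    float scale = std::exp(p[0]), inv_scale = 1.0f / scale;
    const float *slopes = p + 1;         // slopes of the N linear pieces
    std::vector<float> y_vals(N, 0.0f);  // value at the left end of each piece
    float sum = 0.0f;
    for (int i = 0; i < K; i++) {        // positive half: cumulative slope sums
      y_vals[K + i] = sum * scale;
      sum += slopes[K + i];
    }
    sum = 0.0f;
    for (int i = 0; i < K; i++) {        // negative half, with the sign flipped
      y_vals[K - i] = sum * -scale;
      sum += slopes[K - 1 - i];
    }
    y_vals[0] = y_vals[1];
    for (int b = 0; b < B; b++) {
      for (int t = 0; t < T; t++) {
        float x = input[(b * C + c) * T + t] * inv_scale + K;
        // Same result as the kernel's power-of-two binary search below.
        int j = std::min(std::max((int)std::floor(x), 0), 2 * K - 1);
        output[(b * C + c) * T + t] = (x - j) * slopes[j] + y_vals[j];
      }
    }
  }
}
```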
......@@ -84,8 +91,8 @@ __global__
void learned_nonlin_kernel(
torch::PackedTensorAccessor32<scalar_t, 3> input, // B, C, T, i.e. batch, channels, time
torch::PackedTensorAccessor32<scalar_t, 2> params, // C, N + 1
torch::PackedTensorAccessor32<scalar_t, 3> output) { // B, C, T
torch::PackedTensorAccessor32<scalar_t, 3> output,
int images_per_thread_block) { // B, C, T
const int B = input.size(0),
C = input.size(1),
......@@ -93,34 +100,85 @@ void learned_nonlin_kernel(
N = params.size(1) - 1,
K = N / 2; // Note: N and K are powers of 2, with K >= 1.
const int c = blockIdx.x, // c is channel index
b = blockIdx.y; // b is batch index; we'll iterate over b
const int c = blockIdx.x; // c is channel index
scalar_t *y_vals_buf = (scalar_t*) extern_buf, // [N]
*params_buf = (scalar_t*) y_vals_buf + N; // [N]
scalar_t *y_vals = (scalar_t*) extern_buf, // [N], actually there are three
// extra slots between here and
// `params_buf`, for storing the log
// scale, scale and inv_scale.
*params_buf = (scalar_t*) y_vals + 3 + N; // [N]. Caution: contains params[c][1] through params[c][N].
// params_buf[-1] contains params[c][0] == log of scale;
// params_buf[-2] and params_buf[-3] contain scale and inv_scale.
// Load parameters
for (int n = threadIdx.x; n < N + 1; n += THREADS_PER_BLOCK) {
for (int n = threadIdx.x; n <= N; n += THREADS_PER_BLOCK) {
params_buf[n - 1] = params[c][n];
}
if (threadIdx.x == 0) {
scalar_t scale = exp(params_buf[-1]),
inv_scale = 1.0 / scale;
params_buf[-1] = scale;
params_buf[-2] = inv_scale;
} else if (threadIdx.x & ~96 == 0) {
// threadIdx.x == 32 or 64. These, and 0, are in separate warps so we can
__syncthreads();
// The easiest way to understand this code is to compare it with the CPU code
// in learned_nonlin_cpu.cpp.
if ((((int)threadIdx.x & ~(int)32)) == 0) {
// threadIdx.x == 0 or 32. These are in separate warps so we can
// allow them to do separate jobs. This code takes linear time in K which
// is not at all ideal and could be improved if K is largish, but it shouldn't
// dominate the total time taken if we are processing a lot of data;
// and anyway, we doubt that K will need to be more than 4 or 8 or so,
// so the potential savings are quite small.
scalar_t scale = exp(params_buf[-1]),
inv_scale = 1.0 / scale;
params_buf[-2] = scale; // both threads write these but it's OK, it's the
// same value.
params_buf[-3] = inv_scale;
int sign,
Koffset; // Koffset == K for threads handling sum_positive and K - 1
// for threads handling sum_negative, see
// learned_nonlin_cpu.cpp for reference code. This would be K
// + 1 and K respectively, except our params_buf has its index
// shifted by one versus params.
if (threadIdx.x == 0) { // sum_positive
sign = 1;
Koffset = K;
} else { // threadIdx.x == 32. sum_negative.
scale *= -1; // this is a local variable..
sign = -1;
Koffset = K - 1;
}
scalar_t sum = 0.0;
for (int i = 0; i < K; i++) {
int isign = i * sign;
y_vals[K + isign] = sum * scale;
printf("c = %d, y_vals[%d] = %f\n", c, K + isign, sum * scale);
sum += params_buf[Koffset + isign];
}
y_vals[0] = y_vals[1]; // Both threads do this but it's OK.
}
__syncthreads();
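As a hand-worked illustration of the two warps' work above (my own example, not from the source): take K = 2 and write p_n for params[c][n], so params_buf[0..3] holds p1..p4 and scale = exp(p0).

```cpp
// Result of the setup phase for K = 2 (illustrative):
//   thread 0  (sum_positive): y_vals[2] = 0;   y_vals[3] =  p3 * scale
//   thread 32 (sum_negative): y_vals[2] = 0;   y_vals[1] = -p2 * scale
//   both threads then set:    y_vals[0] = y_vals[1] = -p2 * scale
```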
scalar_t scale = params_buf[-1],
inv_scale = params_buf[-2];
scalar_t inv_scale = params_buf[-3];
int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
image_offset = threadIdx.x / T_inc,
t_start = threadIdx.x % T_inc;
for (int b = blockIdx.y * images_per_thread_block + image_offset;
b < B; b += gridDim.y * images_per_thread_block) {
// We do "t += THREADS_PER_BLOCK" instead of t += (THREADS_PER_BLOCK /
// images_per_thread_block) as a small optimization because the only case we
// really need to loop is when images_per_thread_block == 1: we only let
// images_per_thread_block > 1 if T * images_per_thread_block <=
// THREADS_PER_BLOCK.
for (int t = t_start; t < T; t += THREADS_PER_BLOCK) {
scalar_t x = input[b][c][t] * inv_scale + K;
int min = 0, diff = K;
while (diff > 0) {
int mid = min + diff;
if (x >= mid)
min = mid;
diff = diff >> 1;
}
// OK, at this point, 0 <= min < 2*K.
scalar_t y = (x - (scalar_t)min) * params_buf[min] + y_vals[min];
output[b][c][t] = y;
}
}
}
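The search loop in the kernel above can be summarized by this host-side sketch (`find_bin` is an illustrative name, not from the source); for K a power of 2 it returns the largest j in [0, 2K) with j <= x, i.e. clamp(floor(x), 0, 2K - 1), which selects the linear piece used for the output.

```cpp
// Equivalent of the kernel's search, assuming K is a power of 2.
static int find_bin(float x, int K) {
  int min = 0;
  for (int diff = K; diff > 0; diff >>= 1) {
    int mid = min + diff;
    if (x >= mid)
      min = mid;
  }
  return min;  // 0 <= min < 2*K
}
```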
......@@ -197,8 +255,6 @@ void learned_nonlin_kernel(
bytesShared = sizeof(shared_t) * numel, where
numel = 4 * (kH * kW) + 3 * (ppatchH * ppatchW) + blockDim.x
*/
template <typename scalar_t>
__global__
void learned_nonlin_kernel_backward(
......@@ -502,35 +558,51 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
auto scalar_t = input.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
torch::Tensor output = torch::empty({B, C, T}, opts);
// TODO: make this empty
torch::Tensor output = torch::ones({B, C, T}, opts);
if (C * B * T == 0)
return output;
// The number of thread blocks is at least C (the number of channels), but
// if the number of channels is small we may split further on the batch.
int images_per_thread_block = 1;
while (images_per_thread_block * 2 * T <= THREADS_PER_BLOCK)
images_per_thread_block *= 2;
int batches_per_block = 1;
if (C * batches_per_block < 128) {
// Aim for at least 128 thread blocks..
batches_per_block = 128 / C;
if (batches_per_block > B)
batches_per_block = B;
}
int grid_dim_y = 1;
// If the number of channels is quite small (<128) we can launch more thread
// groups, splitting on the batch index.
while (C * grid_dim_y < 128)
grid_dim_y *= 2;
// B_reduced is the max number of thread-groups per channel that would have
// any work to do. If grid_dim_y is more than this, we reduce it to avoid
// launching kernels with nothing to do.
int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
if (grid_dim_y > B_reduced)
grid_dim_y = B_reduced;
int shared_mem_numel = 2 * N + 3;
int shared_mem_numel = 2 * N,
num_blocks_batch = (B + batches_per_block - 1) / B;
std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
<< ", images_per_thread_block = " << images_per_thread_block
<< ", grid_dim_y = " << grid_dim_y
<< "\n";
TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= T ||
images_per_thread_block == 1,
"Code error");
dim3 gridDim(C, num_blocks_batch, 1);
dim3 gridDim(C, grid_dim_y, 1);
// blockDim is scalar, just THREADS_PER_BLOCK.
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_kernel", ([&] {
learned_nonlin_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
input.packed_accessor32<scalar_t, 3>(),
params.packed_accessor32<scalar_t, 2>(),
output.packed_accessor32<scalar_t, 3>());
output.packed_accessor32<scalar_t, 3>(),
images_per_thread_block);
}));
return output;
}
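For reference, the grid sizing logic above amounts to the following host-side sketch (an illustrative helper, not code from the source; threads_per_block stands in for the THREADS_PER_BLOCK constant, and T is assumed positive since empty tensors return early).

```cpp
// Sketch of how the kernel launch is sized (illustrative only).
struct LaunchConfig { int images_per_thread_block, grid_dim_y; };

static LaunchConfig choose_launch_config(int B, int C, int T,
                                         int threads_per_block) {
  int images_per_thread_block = 1;
  while (images_per_thread_block * 2 * T <= threads_per_block)
    images_per_thread_block *= 2;
  int grid_dim_y = 1;
  while (C * grid_dim_y < 128)  // aim for at least 128 thread blocks
    grid_dim_y *= 2;
  // Don't launch more thread groups per channel than have work to do.
  int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
  if (grid_dim_y > B_reduced)
    grid_dim_y = B_reduced;
  // By construction of the first loop, whenever images_per_thread_block > 1 we
  // have threads_per_block / images_per_thread_block >= T, which is exactly
  // the invariant the TORCH_CHECK above asserts.
  return LaunchConfig{images_per_thread_block, grid_dim_y};
}
```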
......
......@@ -20,8 +20,18 @@ def test_learned_nonlin_basic():
print("params = ", params)
print("x.shape = ", x.shape)
y = learned_nonlin(x, params, dim = 1)
print("y = ", y)
if torch.cuda.is_available():
# test that the CUDA forward is the same as the CPU forward.
device = torch.device('cuda:0')
y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu'))
print("Checking CUDA is same")
if not torch.allclose(y, y2, atol=1.0e-06):
print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
assert(0);
y.sum().backward()
print("x.grad = ", x.grad)
......@@ -47,6 +57,16 @@ def test_learned_nonlin_deriv():
print(f"B,C,T,K = {B},{C},{T},{K}")
y = learned_nonlin(x, params, dim = 1)
if torch.cuda.is_available():
# test that the CUDA forward is the same as the CPU forward.
device = torch.device('cuda:0')
y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu'))
print("Checking CUDA is same")
if not torch.allclose(y, y2, atol=1.0e-06):
print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
assert(0)
y_deriv = torch.rand_like(y)
y.backward(gradient=y_deriv)
......