Commit 1dc0d559 authored by Daniel Povey

Make the forward code for CPU/CUDA give the same result

parent 8f096574
@@ -61,6 +61,14 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
       of the linear parts of the piecewise linear function.
       The discontinuities of the function are at:
          exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
+      output:  The transformed input, shape (B, C, T)
+      images_per_thread_block:  The number of images processed by each thread
+             block.  The calling code must guarantee that this is a power
+             of 2, and that EITHER:
+                 THREADS_PER_BLOCK / images_per_thread_block >= T
+             OR
+                 images_per_thread_block == 1
+             .. this is used for a small optimization.
   This kernel is allocated with `extern_buf` containing enough memory
   to store 2*N values of type scalar_t.
@@ -74,8 +82,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
   When we invoke this kernel, we'll invoke it as:
         learned_nonlin_forward<<<gridDim, blockDim, bytesShared, stream>>>
   where bytesShared is the number of bytes needed in `extern_buf`:
-        bytesShared = sizeof(shared_t) * numel, where
-          numel = 2 * (kH * kW) + max(blockDim.x, (opatchH + kH - 1) * (patchW + kW - 1))
+        bytesShared = sizeof(shared_t) * (2N + 3)
*/
extern __shared__ int extern_buf[];
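For concreteness, here is a standalone sketch (not part of the commit, and assuming scalar_t is float) of what this sizing works out to; the 2N + 3 scalars are the N y_vals entries, three slots for inv_scale / scale / log-scale, and the N params entries, as laid out in the kernel body below.

    #include <cstdio>

    // Standalone sketch of the extern_buf sizing for learned_nonlin_kernel:
    // N y_vals entries + 3 slots (inv_scale, scale, log-scale) + N params entries.
    int main() {
      int N = 8;                                    // example value; N is a power of 2
      size_t numel = 2 * (size_t)N + 3;             // matches bytesShared = sizeof(shared_t) * (2N + 3)
      size_t bytes_shared = sizeof(float) * numel;  // assuming scalar_t == float
      std::printf("N = %d -> %zu scalars, %zu bytes of shared memory\n", N, numel, bytes_shared);
      return 0;
    }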
@@ -84,8 +91,8 @@ __global__
void learned_nonlin_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3> input,   // B, C, T, i.e. batch, channels, time
    torch::PackedTensorAccessor32<scalar_t, 2> params,  // C, N + 1
-    torch::PackedTensorAccessor32<scalar_t, 3> output) {  // B, C, T
+    torch::PackedTensorAccessor32<scalar_t, 3> output,    // B, C, T
+    int images_per_thread_block) {
  const int B = input.size(0),
            C = input.size(1),
@@ -93,34 +100,85 @@ void learned_nonlin_kernel(
            N = params.size(1) - 1,
            K = N / 2;  // Note: N and K are powers of 2, with K >= 1.
-  const int c = blockIdx.x,  // c is channel index
-      b = blockIdx.y;        // b is batch index; we'll iterate over b
+  const int c = blockIdx.x;  // c is channel index
-  scalar_t *y_vals_buf = (scalar_t*) extern_buf,  // [N]
-      *params_buf = (scalar_t*) y_vals_buf + N;   // [N]
+  scalar_t *y_vals = (scalar_t*) extern_buf,     // [N], actually there are two
+                                                 // spaces between here and
+                                                 // `params_buf` for storing scale
+                                                 // and inv_scale.
+      *params_buf = (scalar_t*) y_vals + 3 + N;  // [N].  Caution: contains params[c][1] through params[c][N].
+  // params_buf[-1] contains params[c][0] == log of scale;
+  // params_buf[-2] and params_buf[-3] contain scale and inv_scale.
  // Load parameters
-  for (int n = threadIdx.x; n < N + 1; n += THREADS_PER_BLOCK) {
+  for (int n = threadIdx.x; n <= N; n += THREADS_PER_BLOCK) {
    params_buf[n - 1] = params[c][n];
  }
+  __syncthreads();
-  if (threadIdx.x == 0) {
-    scalar_t scale = exp(params_buf[-1]),
-        inv_scale = 1.0 / scale;
-    params_buf[-1] = scale;
-    params_buf[-2] = inv_scale;
-  } else if (threadIdx.x & ~96 == 0) {
-    // threadIdx.x == 32 or 64.  These, and 0, are in separate warps so we can
+  // The easiest way to understand this code is to compare it with the CPU code
+  // in learned_nonlin_cpu.cpp.
+  if ((((int)threadIdx.x & ~(int)32)) == 0) {
+    // threadIdx.x == 0 or 32.  These are in separate warps so we can
    // allow them to do separate jobs.  This code takes linear time in K which
    // is not at all ideal and could be improved if K is largish, but it shouldn't
    // dominate the total time taken if we are processing a lot of data;
    // and anyway, we doubt that K will need to be more than 4 or 8 or so,
    // so the potential savings are quite small.
+    scalar_t scale = exp(params_buf[-1]),
+        inv_scale = 1.0 / scale;
+    params_buf[-2] = scale;      // both threads write these but it's OK, it's the
+                                 // same value.
+    params_buf[-3] = inv_scale;
+    int sign,
+        Koffset;  // Koffset == K for threads handling sum_positive and K - 1
+                  // for threads handling sum_negative, see
+                  // learned_nonlin_cpu.cpp for reference code.  This would be K
+                  // + 1 and K respectively, except our params_buf has its index
+                  // shifted by one versus params.
+    if (threadIdx.x == 0) {  // sum_positive
+      sign = 1;
+      Koffset = K;
+    } else {                 // threadIdx.x == 32.  sum_negative.
+      scale *= -1;           // this is a local variable..
+      sign = -1;
+      Koffset = K - 1;
+    }
+    scalar_t sum = 0.0;
+    for (int i = 0; i < K; i++) {
+      int isign = i * sign;
+      y_vals[K + isign] = sum * scale;
+      printf("c = %d, y_vals[%d] = %f\n", c, K + isign, sum * scale);
+      sum += params_buf[Koffset + isign];
+    }
+    y_vals[0] = y_vals[1];   // Both threads do this but it's OK.
  }
  __syncthreads();
-  scalar_t scale = params_buf[-1],
-      inv_scale = params_buf[-2];
+  scalar_t inv_scale = params_buf[-3];
+  int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
+      image_offset = threadIdx.x / T_inc,
+      t_start = threadIdx.x % T_inc;
+  for (int b = blockIdx.y * images_per_thread_block + image_offset;
+       b < B; b += gridDim.y * images_per_thread_block) {
+    // We do "t += THREADS_PER_BLOCK" instead of t += (THREADS_PER_BLOCK /
+    // images_per_thread_block) as a small optimization because the only case we
+    // really need to loop is when images_per_thread_block == 1: we only let
+    // images_per_thread_block > 1 if T * images_per_thread_block <=
+    // THREADS_PER_BLOCK.
+    for (int t = t_start; t < T; t += THREADS_PER_BLOCK) {
+      scalar_t x = input[b][c][t] * inv_scale + K;
+      int min = 0, diff = K;
+      while (diff > 0) {
+        int mid = min + diff;
+        if (x >= mid)
+          min = mid;
+        diff = diff >> 1;
+      }
+      // OK, at this point, 0 <= min < 2*K.
+      scalar_t y = (x - (scalar_t)min) * params_buf[min] + y_vals[min];
+      output[b][c][t] = y;
+    }
+  }
}
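To make the kernel above easier to follow, here is a plain C++ sketch (not part of the commit; single channel, illustrative names) of the same forward computation: build the y_vals offsets with the two cumulative sums, then for each input value locate the linear piece by binary search and evaluate it. This just restates the kernel in scalar form; learned_nonlin_cpu.cpp remains the authoritative reference.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Single-channel, host-side sketch of the forward computation in the kernel
    // above.  params has N + 1 entries: params[0] is the log of the scale,
    // params[1..N] are the slopes of the N linear pieces; K = N / 2.
    std::vector<float> reference_forward(const std::vector<float> &params,
                                         const std::vector<float> &input) {
      int N = (int)params.size() - 1, K = N / 2;
      float scale = std::exp(params[0]), inv_scale = 1.0f / scale;

      // y_vals[i] is the offset of linear piece i, built by the same two
      // cumulative sums the kernel calls sum_positive and sum_negative.
      std::vector<float> y_vals(N, 0.0f);
      float sum = 0.0f;
      for (int i = 0; i < K; i++) {      // sum_positive: pieces K .. 2K-1
        y_vals[K + i] = sum * scale;
        sum += params[1 + K + i];
      }
      sum = 0.0f;
      for (int i = 0; i < K; i++) {      // sum_negative: pieces K, K-1, ..., 1
        y_vals[K - i] = sum * -scale;
        sum += params[K - i];            // params_buf[K - 1 - i] in the kernel
      }
      y_vals[0] = y_vals[1];

      std::vector<float> output(input.size());
      for (size_t t = 0; t < input.size(); t++) {
        float x = input[t] * inv_scale + K;
        // Binary search for the piece index, exactly as in the kernel.
        int min = 0, diff = K;
        while (diff > 0) {
          int mid = min + diff;
          if (x >= mid) min = mid;
          diff >>= 1;
        }
        output[t] = (x - (float)min) * params[1 + min] + y_vals[min];
      }
      return output;
    }

    int main() {
      std::vector<float> params = {0.0f, 1.0f, 1.0f, 0.5f, 0.25f};  // N = 4, K = 2
      std::vector<float> input = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
      for (float y : reference_forward(params, input)) std::printf("%g ", y);
      std::printf("\n");
      return 0;
    }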
@@ -197,8 +255,6 @@ void learned_nonlin_kernel(
      bytesShared = sizeof(shared_t) * numel, where
        numel = 4 * (kH * kW) + 3 * (ppatchH * ppatchW) + blockDim.x
*/
template <typename scalar_t>
__global__
void learned_nonlin_kernel_backward(
@@ -502,35 +558,51 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
  auto scalar_t = input.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
-  torch::Tensor output = torch::empty({B, C, T}, opts);
+  // TODO: make this empty
+  torch::Tensor output = torch::ones({B, C, T}, opts);
  if (C * B * T == 0)
    return output;
-  // The number of thread blocks is at least C (the number of channels), but
-  // if the number of channels is small we may split further on the batch.
-  int batches_per_block = 1;
-  if (C * batches_per_block < 128) {
-    // Aim for at least 128 thread blocks..
-    batches_per_block = 128 / C;
-    if (batches_per_block > B)
-      batches_per_block = B;
-  }
-  int shared_mem_numel = 2 * N,
-      num_blocks_batch = (B + batches_per_block - 1) / B;
-  dim3 gridDim(C, num_blocks_batch, 1);
+  int images_per_thread_block = 1;
+  while (images_per_thread_block * 2 * T <= THREADS_PER_BLOCK)
+    images_per_thread_block *= 2;
+  int grid_dim_y = 1;
+  // If the number of channels is quite small (<128) we can launch more thread
+  // groups, splitting on the batch index.
+  while (C * grid_dim_y < 128)
+    grid_dim_y *= 2;
+  // B_reduced is the max number of thread-groups per channel that would have
+  // any work to do.  If grid_dim_y is more than this, we reduce it to avoid
+  // launching kernels with nothing to do.
+  int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
+  if (grid_dim_y > B_reduced)
+    grid_dim_y = B_reduced;
+  int shared_mem_numel = 2 * N + 3;
+  std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
+            << ", images_per_thread_block = " << images_per_thread_block
+            << ", grid_dim_y = " << grid_dim_y
+            << "\n";
+  TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= T ||
+              images_per_thread_block == 1,
+              "Code error");
+  dim3 gridDim(C, grid_dim_y, 1);
  // blockDim is scalar, just THREADS_PER_BLOCK.
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_kernel", ([&] {
        learned_nonlin_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
            input.packed_accessor32<scalar_t, 3>(),
            params.packed_accessor32<scalar_t, 2>(),
-            output.packed_accessor32<scalar_t, 3>());
+            output.packed_accessor32<scalar_t, 3>(),
+            images_per_thread_block);
      }));
  return output;
}
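For a rough idea of what launch configuration this produces, here is a standalone sketch (not part of the commit) of the same sizing arithmetic, with an assumed THREADS_PER_BLOCK of 256 and example tensor sizes.

    #include <algorithm>
    #include <cstdio>

    // Sketch of the launch-configuration arithmetic in learned_nonlin_cuda above.
    int main() {
      const int THREADS_PER_BLOCK = 256;  // assumed value, for illustration only
      int B = 5, C = 10, T = 20;          // example batch, channels, time

      int images_per_thread_block = 1;
      while (images_per_thread_block * 2 * T <= THREADS_PER_BLOCK)
        images_per_thread_block *= 2;     // largest power of 2 with ipb * T <= THREADS_PER_BLOCK / 2

      int grid_dim_y = 1;
      while (C * grid_dim_y < 128)        // aim for at least 128 thread blocks
        grid_dim_y *= 2;

      // Don't launch more groups per channel than have any batch items to process.
      int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
      grid_dim_y = std::min(grid_dim_y, B_reduced);

      std::printf("images_per_thread_block = %d, grid_dim_y = %d, gridDim = (%d, %d, 1)\n",
                  images_per_thread_block, grid_dim_y, C, grid_dim_y);
      return 0;
    }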
...
@@ -20,8 +20,18 @@ def test_learned_nonlin_basic():
    print("params = ", params)
    print("x.shape = ", x.shape)
    y = learned_nonlin(x, params, dim = 1)
    print("y = ", y)
+    if torch.cuda.is_available():
+        # test that the CUDA forward is the same as the CPU forward.
+        device = torch.device('cuda:0')
+        y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu'))
+        print("Checking CUDA is same")
+        if not torch.allclose(y, y2, atol=1.0e-06):
+            print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
+            assert(0);
    y.sum().backward()
    print("x.grad = ", x.grad)
@@ -47,6 +57,16 @@ def test_learned_nonlin_deriv():
    print(f"B,C,T,K = {B},{C},{T},{K}")
    y = learned_nonlin(x, params, dim = 1)
+    if torch.cuda.is_available():
+        # test that the CUDA forward is the same as the CPU forward.
+        device = torch.device('cuda:0')
+        y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu'))
+        print("Checking CUDA is same")
+        if not torch.allclose(y, y2, atol=1.0e-06):
+            print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
+            assert(0)
    y_deriv = torch.rand_like(y)
    y.backward(gradient=y_deriv)
...