Commit 8f096574 authored by Daniel Povey

Optimization of CPU code; start drafting forward code for CUDA

parent d6081b04
@@ -34,10 +34,12 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
         y_vals_a = y_vals.accessor<scalar_t, 2>();
     for (int c = 0; c < C; c++) {
       scalar_t sum_negative = 0.0,
-          sum_positive = 0.0;
+          sum_positive = 0.0,
+          scale = exp(params_a[c][0]);
       for (int i = 0; i < K; i++) {
-        y_vals_a[c][K + i] = sum_positive;
-        y_vals_a[c][K - i] = sum_negative;
+        y_vals_a[c][K + i] = sum_positive * scale;
+        y_vals_a[c][K - i] = sum_negative * scale;
         sum_positive += params_a[c][1 + K + i];
         sum_negative -= params_a[c][K - i];
       }
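
The hunk above now folds the per-channel scale exp(params[c][0]) into the reference y-values as they are accumulated. A standalone sketch of that construction (a hypothetical helper for illustration only; `build_y_vals` and the use of std::vector are not part of the diff):

#include <cmath>
#include <vector>

// Build one channel's 2*K reference y-values from params (size N + 1, N == 2*K),
// mirroring the loop above: cumulative sums of the slopes, times exp(params[0]).
std::vector<double> build_y_vals(const std::vector<double> &params) {
  int N = (int)params.size() - 1, K = N / 2;
  std::vector<double> y_vals(2 * K, 0.0);
  double sum_positive = 0.0, sum_negative = 0.0,
         scale = std::exp(params[0]);
  for (int i = 0; i < K; i++) {
    y_vals[K + i] = sum_positive * scale;  // reference value at breakpoint +i
    y_vals[K - i] = sum_negative * scale;  // reference value at breakpoint -i
    sum_positive += params[1 + K + i];     // slope of the next interval to the right
    sum_negative -= params[K - i];         // slope of the next interval to the left
  }
  // (In the real code, y_vals[0] is later set to y_vals[1]; see the backward hunks.)
  return y_vals;
}
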
@@ -51,9 +53,7 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
     for (int b = 0; b < B; b++) {
       for (int c = 0; c < C; c++) {
-        scalar_t l = params_a[c][0],
-            scale = exp(l),
-            inv_scale = 1.0 / scale;
+        scalar_t inv_scale = exp(-params_a[c][0]);
         for (int t = 0; t < T; t++) {
           // `x` is the scaled input x plus an offset so that -K maps to 0.
           // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
@@ -72,7 +72,7 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
           scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
           // printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f + %f\n", x, y, min,
           //  x, min, params_a[c][min + 1], y_vals_a[c][min - 1]);
-          output_a[b][c][t] = y * scale;
+          output_a[b][c][t] = y;
         }
       }
     }}));
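
For reference, a scalar sketch of the per-element forward computation in this hunk, including the interval search whose comparison line is not visible in the diff context (the `if` inside the while loop is an assumption here, and `piecewise_forward` is an illustrative helper, not the repo's API). It assumes y_vals was built with the scale already folded in, which is why the output is no longer multiplied by `scale`:

#include <cmath>
#include <vector>

double piecewise_forward(double input, const std::vector<double> &params,
                         const std::vector<double> &y_vals) {
  int N = (int)params.size() - 1, K = N / 2;
  double inv_scale = std::exp(-params[0]);
  // Shift and scale so the breakpoints land at integers 1 .. 2*K - 1.
  double x = input * inv_scale + K;
  // Binary search for the interval index `min`; K is a power of 2, so halving
  // `diff` decides one bit of the index per iteration.  Afterwards 0 <= min < 2*K.
  int min = 0, diff = K;
  while (diff > 0) {
    int mid = min + diff;
    if (x >= (double)mid)  // assumed form of the test elided from the visible context
      min = mid;
    diff >>= 1;
  }
  // Linear interpolation within interval `min`.
  return (x - (double)min) * params[min + 1] + y_vals[min];
}
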
@@ -120,12 +120,13 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
         y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
     for (int c = 0; c < C; c++) {
       scalar_t sum_negative = 0.0,
-          sum_positive = 0.0;
+          sum_positive = 0.0,
+          scale = exp(params_a[c][0]);
       for (int i = 0; i < K; i++) {
         y_vals_a[c][K + i] = sum_positive;
         y_vals_a[c][K - i] = sum_negative;
-        sum_positive += params_a[c][1 + K + i];
-        sum_negative -= params_a[c][K - i];
+        sum_positive += params_a[c][1 + K + i] * scale;
+        sum_negative -= params_a[c][K - i] * scale;
       }
       // the reference point for the lowest, half-infinite interval (the one
       // starting at x=-(K-1) is still x=-(K-1); this value is repeated in y_vals.
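
Unlike the forward hunk, this recomputation multiplies each slope increment by `scale` instead of scaling the accumulated sums; by linearity both give the same y-values. A tiny check of that equivalence (hypothetical, not in the diff):

#include <cassert>
#include <cmath>

void check_scale_placement_equivalence() {
  const int K = 2;
  const double params[] = {0.3, 0.5, -1.0, 2.0, 0.25};  // params[0] is l; N + 1 == 2*K + 1 entries
  double scale = std::exp(params[0]);
  double a = 0.0, b = 0.0;  // a: scale applied to the sum; b: scale applied per increment
  for (int i = 0; i < K; i++) {
    a += params[1 + K + i];
    b += params[1 + K + i] * scale;
    assert(std::fabs(a * scale - b) < 1e-12);  // (p1 + p2) * s == p1 * s + p2 * s
  }
}
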
@@ -138,10 +139,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
     for (int b = 0; b < B; b++) {
       for (int c = 0; c < C; c++) {
-        scalar_t l = params_a[c][0],
-            scale = exp(l),
-            inv_scale = 1.0 / scale,
-            scale_grad = 0.0,
+        scalar_t inv_scale = exp(-params_a[c][0]),
             inv_scale_grad = 0.0;
         for (int t = 0; t < T; t++) {
           // `x` is the scaled input x plus an offset so that -K maps to 0.
@@ -151,7 +149,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
           // that are < -(K-1) and > (K-1)
           scalar_t input = input_a[b][c][t],
               x = input * inv_scale + K,
-              output_grad = output_grad_a[b][c][t];
+              y_grad = output_grad_a[b][c][t];
           int min = 0, diff = K;
           while (diff > 0) {
             int mid = min + diff;
@@ -160,10 +158,6 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
             diff = diff >> 1;
           }
           // OK, at this point, 0 <= min < 2*K.
-          scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
-          // backprop for: output_a[b][c][t] = y * scale;
-          scale_grad += y * output_grad;
-          scalar_t y_grad = scale * output_grad;
          // backprop for:
          //  scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
          scalar_t x_grad = y_grad * params_a[c][min + 1];
@@ -173,29 +167,36 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
           inv_scale_grad += x_grad * input;
           input_grad_a[b][c][t] = x_grad * inv_scale;
         }
-        // Do the backprop to l as if we had done:
-        //   scale = exp(l); inv_scale = exp(-l);
-        scalar_t l_grad = scale * scale_grad - inv_scale * inv_scale_grad;
-        params_grad_a[c][0] += l_grad;
+        // Do the backprop for: inv_scale = exp(-params_a[c][0])
+        params_grad_a[c][0] -= inv_scale * inv_scale_grad;
       }
     }
     // Now do the backprop for the loop above where we set y_vals_a.
     for (int c = 0; c < C; c++) {
       // backprop for: y_vals_a[c][0] = y_vals_a[c][1];
       y_vals_grad_a[c][1] += y_vals_grad_a[c][0];
-      scalar_t sum_negative_grad = 0.0,
+      scalar_t scale = exp(params_a[c][0]),
+          inv_scale = 1.0 / scale,
+          scale_grad = 0.0,
+          sum_negative_grad = 0.0,
           sum_positive_grad = 0.0;
       for (int i = K - 1; i >= 0; i--) {
         // backprop for: sum_negative -= params_a[c][K - i];
         params_grad_a[c][K - i] -= sum_negative_grad;
-        // backprop for: sum_positive += params_a[c][1 + K + i];
+        // backprop for: sum_positive += params_a[c][1 + K + i] * scale;
         params_grad_a[c][1 + K + i] += sum_positive_grad;
-        // backprop for: y_vals_a[c][K - i] = sum_negative;
-        sum_negative_grad += y_vals_grad_a[c][K - i];
-        // backprop for: y_vals_a[c][K + i] = sum_positive;
-        sum_positive_grad += y_vals_grad_a[c][K + i];
+        // backprop for: y_vals_a[c][K - i] = sum_negative * scale;
+        sum_negative_grad += y_vals_grad_a[c][K - i] * scale;
+        // The next code line is equivalent to:
+        //   scale_grad += y_vals_grad_a[c][K - i] * sum_negative, substituting:
+        //   sum_negative == y_vals_a[c][K - i] / scale
+        scale_grad += y_vals_grad_a[c][K - i] * y_vals_a[c][K - i] * inv_scale;
+        // backprop for: y_vals_a[c][K + i] = sum_positive * scale;
+        sum_positive_grad += y_vals_grad_a[c][K + i] * scale;
+        scale_grad += y_vals_grad_a[c][K + i] * y_vals_a[c][K + i] * inv_scale;
       }
+      // Backprop for: scale = exp(params_a[c][0]),
+      params_grad_a[c][0] += scale * scale_grad;
     }
   }));
   return std::vector<torch::Tensor>({input_grad, params_grad});
...
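
One way to sanity-check backward code like the above is a central-difference comparison. Below is a minimal, self-contained example of that kind of check on a toy function (this is not the repo's test code, and the toy forward is only a stand-in for the real nonlinearity):

#include <cmath>
#include <cstdio>

int main() {
  // Toy forward: y = slope * (x * exp(-l)), standing in for the scaled piecewise function.
  auto forward = [](double x, double l, double slope) {
    return slope * (x * std::exp(-l));
  };
  double x = 1.7, l = 0.4, slope = 2.5, eps = 1e-5;
  // Analytic derivative w.r.t. l:  dy/dl = -slope * x * exp(-l).
  double analytic = -slope * x * std::exp(-l);
  double numeric = (forward(x, l + eps, slope) - forward(x, l - eps, slope)) / (2 * eps);
  std::printf("analytic %.8f vs numeric %.8f\n", analytic, numeric);
  return 0;
}
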
@@ -4,18 +4,17 @@
 // forward of learned_nonlin.  """... """ comment of `learned_nonlin`
 // in learned_nonlin.py documents the behavior of this function.
 torch::Tensor learned_nonlin_cuda(torch::Tensor input,
-                                  torch::Tensor pos_add,
-                                  torch::Tensor pos_mul);
-// backward of learned_nonlin; returns (grad_input, grad_pos_add, grad_pos_mul).
+                                  torch::Tensor params);
+// backward of learned_nonlin; returns (grad_input, grad_params).
 std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
-                                                        torch::Tensor pos_add,
-                                                        torch::Tensor pos_mul,
-                                                        torch::Tensor grad_output);
+                                                        torch::Tensor params,
+                                                        torch::Tensor grad_output);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("learned_nonlin_cuda", &learned_nonlin_cuda, "Integrated convolution forward function (CUDA)");
-  m.def("learned_nonlin_backward_cuda", &learned_nonlin_backward_cuda, "Integrated convolution backward function (CUDA)");
+  m.def("learned_nonlin_cuda", &learned_nonlin_cuda, "Learned nonlinearity forward function (CUDA)");
+  m.def("learned_nonlin_backward_cuda", &learned_nonlin_backward_cuda, "Learned nonlinearity backward function (CUDA)");
 }
@@ -3,6 +3,8 @@
 #include <cooperative_groups.h>
 
+#define THREADS_PER_BLOCK 256
+
 /*
@@ -39,33 +41,36 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
   // return the full sums of their tiles.
 }
 
 /*
   Forward of learned_nonlin.  Each thread group handles a single channel (equal
-  to blockIdx.x), and loops over patches of the output and over the image n
-  within the batch (different thread groups may be responsible for different
-  subsets of patches and/or images, see docs of gridDim below).
+  to blockIdx.x); the gridDim is (C, nb) where 1 <= nb <= B (nb relates to the batch).
 
   Template args:
       scalar_t: the floating-point type, e.g. float, double, maybe half.
 
   Args:
-      input:  input image, shape (N, 2*C, H, W)
-      pos_add:  positional encoding, additive part, shape (C, kH, kW)
-      pos_mul:  positional encoding, multiplicative part, shape (C, kH, kW)
-      output:  output image, shape (N, 2*C, H, W)
-      Note: kH and kW must both be odd so that it's clear how to zero-pad.
-  The thread-block should have one dimension (x); blockDim.x should equal
-  some small power of 2 (threads_per_opixel) times the output-patch size which is
-  opatchH * opatchW (the output-patch height and width).  We expect
-  threads_per_opixel to be 1, 2, or 4; we use a linear summation to sum up the
-  different threads' partial sums, and if threads_per_opixel gets larger we'd
-  need to make this a logarithmic reduction.
+      input:  input image, shape (B, C, T) where B is batch size, C is
+          the number of channels and T is the time axis.  (For more-than-1d
+          convolution setups, T would really be more than 1 axis, reshaped).
+      params:  of shape (C, N+1) where N is the number of linear regions in the
+          piecewise linear function; params[c][0] is l, which is
+          a log scale parameter that dictates how far apart
+          the discontinuities in the piecewise linear function are,
+          and params[c][n+1] for 0 <= n < N are the derivatives
+          of the linear parts of the piecewise linear function.
+          The discontinuities of the function are at:
+             exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
+  This kernel is allocated with `extern_buf` containing enough memory
+  to store 2*N values of type scalar_t.
+  The blockDim must equal (THREADS_PER_BLOCK, 1, 1)
 
   The requirements on the grid dimension are:
        gridDim.x == num-channels C (required)
-       gridDim.y <= num-patches per image (recommended)
-       gridDim.z <= batch-size N (recommended)
+       1 <= gridDim.y <= B, where B is the batch size
+       gridDim.z == 1
 
   When we invoke this kernel, we'll invoke it as:
          learned_nonlin_forward<<<gridDim, blockDim, bytesShared, stream>>>
  where bytesShared is the number of bytes needed in `extern_buf`:
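
A small illustration (not part of the diff) of the breakpoint layout the doc-comment above describes, i.e. discontinuities at exp(l) * [ -(N/2 - 1), ..., (N/2 - 1) ]:

#include <cmath>
#include <cstdio>

int main() {
  const int N = 8;          // number of linear regions (a power of 2)
  const double l = -0.25;   // params[c][0], the log-scale parameter
  double scale = std::exp(l);
  for (int j = -(N / 2 - 1); j <= N / 2 - 1; j++)
    std::printf("%g ", scale * j);   // prints the N - 1 discontinuity locations
  std::printf("\n");
  return 0;
}
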
@@ -77,150 +82,46 @@ extern __shared__ int extern_buf[];
 template <typename scalar_t>
 __global__
 void learned_nonlin_kernel(
-    torch::PackedTensorAccessor32<scalar_t, 4> input,  // N, 2*C, H, W
-    torch::PackedTensorAccessor32<scalar_t, 3> pos_add,  // C, kH, kW
-    torch::PackedTensorAccessor32<scalar_t, 3> pos_mul,  // C, kH, kW
-    torch::PackedTensorAccessor32<scalar_t, 4> output,  // N, C, H, W
-    int opatchH,  // output-patch height,
-    int opatchW  // output-patch width
-    ) {
-  const int H = input.size(2),
-      W = input.size(3),
-      kH = pos_add.size(1),
-      kW = pos_add.size(2),
-      npatchH = (H + opatchH - 1) / opatchH,  // num patches in vertical dim
-      npatchW = (W + opatchW - 1) / opatchW,  // num patches in horizontal dim
-      npatch = npatchH * npatchW;  // total number of patches per image
-  // Channel index.
-  const int c = blockIdx.x;
-  // We don't need to check the range of `c` because we set gridDim.x to the
-  // exact number of channels.
-  const int ipatchH = opatchH + kH - 1,
-      ipatchW = opatchW + kW - 1,
-      ipatch_size = ipatchH * ipatchW,
-      opatch_size = opatchH * opatchW;
-  // `extern_buf` is general-purpose shared memory, which we'll divide between
-  // pos_add, pos_mul and src_img_buf to be shared between the src image size
-  // (ipatch_size) and the number of threads (blockDim.x)
-  // these are pointers to __shared__ memory; the compiler should
-  // be able to figure this out.
-  scalar_t
-      *pos_add_buf = (scalar_t*)extern_buf,    // pos_add positional-encoding / kernel parameters,
-                                               // indexed [kh*kW + kw] where kh and kw are vertical
-                                               // and horizontal positions in the kernel.
-      *pos_mul_buf = pos_add_buf + (kH * kW),  // pos_mul positional-encoding / kernel parameters,
-                                               // indexed [kh*kW + kw] where kh and kw are vertical
-                                               // and horizontal positions in the kernel.
-      *src_img_buf = pos_mul_buf + (kH * kW);  // version of input image that relates to source position,
-                                               // of size [ipatch_size], indexed [h*ipatchW + w]...
-                                               // note, the 'h' and 'w' indexes are into the zero-padded input
-                                               // image.
-  int threads_per_opixel = blockDim.x / opatch_size;
-  assert(blockDim.x == opatch_size * threads_per_opixel);
-  // pos_in_patch will be interpreted as h_in_patch * opatchW + w_in_patch.
-  int pos_in_patch = threadIdx.x / threads_per_opixel;
-  // Load parts of the kernel parameters pos_add and pos_mul into shared memory,
-  // in pos_add_buf and pos_mul_buf
-  for (int i = threadIdx.x; i < kH * kW; i += blockDim.x) {
-    int kh = i / kW,
-        kw = i % kW;
-    pos_add_buf[i] = pos_add[c][kh][kw];
-    pos_mul_buf[i] = pos_mul[c][kh][kw];
+    torch::PackedTensorAccessor32<scalar_t, 3> input,  // B, C, T, i.e. batch, channels, time
+    torch::PackedTensorAccessor32<scalar_t, 2> params,  // C, N + 1
+    torch::PackedTensorAccessor32<scalar_t, 3> output) {  // B, C, T
+  const int B = input.size(0),
+      C = input.size(1),
+      T = input.size(2),
+      N = params.size(1) - 1,
+      K = N / 2;  // Note: N and K are powers of 2, with K >= 1.
+  const int c = blockIdx.x,  // c is channel index
+      b = blockIdx.y;  // b is batch index; we'll iterate over b
+  scalar_t *y_vals_buf = (scalar_t*) extern_buf,  // [N]
+      *params_buf = (scalar_t*) y_vals_buf + N;  // [N]
+  // Load parameters
+  for (int n = threadIdx.x; n < N + 1; n += THREADS_PER_BLOCK) {
+    params_buf[n - 1] = params[c][n];
   }
-  // n is the index within the batch.  Loop to make sure we cover all images in
-  // the batch.  input.size(0) is the batch size N.  All threads in the thread-block
-  // loop the same number of times.
-  for (int n = blockIdx.z; n < input.size(0); n += gridDim.z) {
-    // Loop over the patch within the output image.  All threads in the
-    // thread-block loop the same number of times.
-    for (int patch_idx = blockIdx.y; patch_idx < npatch; patch_idx += gridDim.y) {
-      // (patch_h_offset, patch_w_offset) are the (vertical, horizontal) indexes
-      // of the lowest-numbered pixel in the patch of output that this thread
-      // block is responsible for.
-      int patch_h_offset = (patch_idx / npatchW) * opatchH,
-          patch_w_offset = (patch_idx % npatchW) * opatchW;
-      // This __syncthreads() is only necessary if we have already looped at
-      // least once over n or patch_idx: it's in case other threads are still
-      // using the `src_img_buf` buffer for something else.
-      __syncthreads();
-      // Load the 'src' part of the input patch; the size of this is the size of
-      // the output patch plus a border of sizes kH//2, kW//2 on each side.
-      for (int i = threadIdx.x; i < ipatch_size; i += blockDim.x) {
-        int h_in_kernel = i / ipatchW,
-            w_in_kernel = i % ipatchW;
-        int src_h = patch_h_offset + h_in_kernel - (kH / 2),  // kH / 2 is offset due to padding
-            src_w = patch_w_offset + w_in_kernel - (kW / 2);
-        scalar_t src_val = scalar_t(0);
-        if ((unsigned int)src_h < (unsigned int)H &&  // h >= 0 && h < H
-            (unsigned int)src_w < (unsigned int)W)    // w >= 0 && w < W
-          src_val = input[n][c][src_h][src_w];
-        src_img_buf[i] = src_val;
-      }
-      // make sure all threads have written to `src_img_buf`
-      __syncthreads();
-      // 'h' and 'w' are the positions within the output image, that this tile
-      // of size threads_per_opixel is responsible for.
-      int h = patch_h_offset + pos_in_patch / opatchW,
-          w = patch_w_offset + pos_in_patch % opatchW;
-      // The "destination" pixel; this is an input.  It gets added to each
-      // src pixel, prior to the relu, in the loop below.
-      scalar_t dest_val = scalar_t(0);
-      if (h < H && w < W) {
-        // Several threads (within the same tile, which implies the same warp)
-        // may load the same value here, but I believe the device's memory
-        // subsystem handles this well enough that we can just ignore the issue
-        // rather than try to optimize it.
-        // https://forums.developer.nvidia.com/t/accessing-same-global-memory-address-within-warps/66574
-        int C = input.size(1) / 2;
-        dest_val = input[n][c + C][h][w];  // else 0.
-      }
-      // `sum` is the partial sum that this thread computes; we'll sum this over
-      // the `threads_per_opixel` threads in the tile to get the output pixel
-      // value.
-      scalar_t sum = 0.0;
-      for (int pos_in_kernel = threadIdx.x % threads_per_opixel;
-           pos_in_kernel < (kH * kW);
-           pos_in_kernel += threads_per_opixel) {
-        int h_in_kernel = pos_in_kernel / kW,
-            w_in_kernel = pos_in_kernel % kW;
-        // Note: this is actually more like cross-correlation, as we don't
-        // have a negative sign on the h and w indexes in the kernel.
-        // Also note: we already took care of padding and the associated
-        // offsets of -(kH / 2) and -(kW / 2).
-        int h_in_src_patch = (pos_in_patch / opatchW) + h_in_kernel,
-            w_in_src_patch = (pos_in_patch % opatchW) + w_in_kernel;
-        scalar_t src_val = src_img_buf[h_in_src_patch * ipatchW + w_in_src_patch],
-            pos_add_val = pos_add_buf[pos_in_kernel];
-        scalar_t relu = (src_val + dest_val + pos_add_val);
-        if (relu > 0.0)
-          sum += relu * pos_mul_buf[pos_in_kernel];
-      }
-      // Sync threads because src_img_buf is also used above.
-      __syncthreads();
-      // Aggregate `sum` over threads
-      sum = tiled_warp_reduce_sum(threads_per_opixel, src_img_buf, sum);
-      if (threadIdx.x % threads_per_opixel == 0 && h < H && w < W) {
-        output[n][c][h][w] = sum;
-      }
-    }
+  if (threadIdx.x == 0) {
+    scalar_t scale = exp(params_buf[-1]),
+        inv_scale = 1.0 / scale;
+    params_buf[-1] = scale;
+    params_buf[-2] = inv_scale;
+  } else if (threadIdx.x & ~96 == 0) {
+    // threadIdx.x == 32 or 64.  These, and 0, are in separate warps so we can
+    // allow them to do separate jobs.  This code takes linear time in K which
+    // is not at all ideal and could be improved if K is largish, but it shouldn't
+    // dominate the total time taken if we are processing a lot of data;
+    // and anyway, we doubt that K will need to be more than 4 or 8 or so,
+    // so the potential savings are quite small.
   }
+  __syncthreads();
+  scalar_t scale = params_buf[-1],
+      inv_scale = params_buf[-2];
 }
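
The draft kernel above stops after staging the parameters in shared memory. Below is a rough sketch of how the remaining per-element work might continue, consistent with the CPU forward code; this is NOT the committed implementation, and `forward_elements_sketch` plus the assumption that `y_vals_buf` has already been filled are hypothetical:

template <typename scalar_t>
__device__ void forward_elements_sketch(
    torch::PackedTensorAccessor32<scalar_t, 3> input,   // B, C, T
    torch::PackedTensorAccessor32<scalar_t, 3> output,  // B, C, T
    const scalar_t *params_buf,   // params_buf[-1] = scale, params_buf[-2] = inv_scale
    const scalar_t *y_vals_buf,   // 2*K reference values, scale already applied
    int b, int c, int K) {
  const int T = input.size(2);
  scalar_t inv_scale = params_buf[-2];
  for (int t = threadIdx.x; t < T; t += THREADS_PER_BLOCK) {
    scalar_t x = input[b][c][t] * inv_scale + K;
    int min = 0, diff = K;
    while (diff > 0) {            // same interval search as the CPU version
      int mid = min + diff;
      if (x >= (scalar_t)mid)
        min = mid;
      diff >>= 1;
    }
    // params_buf[min] holds params[c][min + 1], i.e. the slope of interval `min`.
    output[b][c][t] = (x - (scalar_t)min) * params_buf[min] + y_vals_buf[min];
  }
}
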
@@ -578,117 +479,58 @@ void learned_nonlin_kernel_backward(
-torch::Tensor learned_nonlin_cuda(torch::Tensor input,
-                                  torch::Tensor pos_add,
-                                  torch::Tensor pos_mul) {
-  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
-  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
-  TORCH_CHECK(pos_mul.dim() == 3, "pos_add must be 3-dimensional.");
+torch::Tensor learned_nonlin_cuda(torch::Tensor input,
+                                  torch::Tensor params) {
+  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
+  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
+  TORCH_CHECK(params.size(1) >= 3 &&
+              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
+              "params.size(1) has invalid value, must be a power of 2 plus 1.");
+  TORCH_CHECK(params.size(0) == input.size(1),
+              "params vs input channels mismatch");
   TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
-  const int N = input.size(0),
-      C = input.size(1) / 2,
-      H = input.size(2),
-      W = input.size(3),
-      kH = pos_add.size(1),
-      kW = pos_add.size(2);
-  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1);
-  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
-  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
-              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
-              "Input sizes mismatch.");
-  TORCH_CHECK(pos_add.device() == input.device() &&
-              pos_mul.device() == pos_add.device(),
-              "Input devices mismatch");
-  auto scalar_t = input.scalar_type();
-  TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
-              pos_mul.scalar_type() == scalar_t,
-              "Input dtypes mismatch");
-  torch::Tensor output = torch::empty({N, C, H, W},
-                                      torch::TensorOptions().dtype(scalar_t).device(input.device()));
-  // Work out the configuration to call the kernel with..
-  int patchH = std::min(H, kH),  // output patch height
-      patchW = std::min(W, kW);  // output patch width
-  // We don't want the height or width of the patch to be less than the kernel
-  // width, or the padding will make the input-patch size more than twice the
-  // output-patch size.
-  // We aim for the output-patch size to be more than 128; this is not something
-  // very exact, but it roughly corresponds to us wanting to have up to 4 threads
-  // per output pixel, and the limitation of 512 threads per thread-block which
-  // we impose so that we can run on architectures with little shared memory.
-  while (patchW < W && patchH * (patchW + 1) <= 128)
-    patchW++;
-  while (patchH < H && (patchH + 1) * patchW <= 128)
-    patchH++;
-  // We are assuming that the thread-block size can be as large as 512; this
-  // works even on old CUDA architectures.
-  int threads_per_opixel;
-  if (patchH * patchW * 4 <= 512 && (kH * kW) > 16)
-    threads_per_opixel = 4;
-  else if (patchH * patchW * 2 <= 512 && (kH * kW) > 8)
-    threads_per_opixel = 2;
-  else
-    threads_per_opixel = 1;
-  int input_patchH = patchH + kH - 1,
-      input_patchW = patchW + kW - 1,
-      input_patch_size = input_patchH * input_patchW;
-  int threads_per_block = patchH * patchW * threads_per_opixel;
-  int buffer_numel = 2 * (kH * kW) + std::max<int>(threads_per_block,
-                                                   input_patch_size);
-  int num_patches_H = (H + patchH - 1) / patchH,
-      num_patches_W = (W + patchW - 1) / patchW,
-      num_patches = num_patches_H * num_patches_W;
-  // gridDim.x == C.
-  int num_blocks_patch = 1,  // gridDim.y.
-      num_blocks_batch = 1;  // gridDim.z
-  while (C * num_blocks_patch <= 256 &&
-         num_blocks_patch * 2 <= num_patches)
-    num_blocks_patch *= 2;
-  if (C * num_patches <= 512)
-    num_blocks_patch = num_patches;
-  while (C * num_blocks_patch * num_blocks_batch <= 512 &&
-         num_blocks_batch * 2 <= N)
-    num_blocks_batch *= 2;
-  if (C * num_blocks_patch * N <= 1024)
-    num_blocks_batch = N;
-  assert(num_blocks_patch <= num_patches && num_blocks_batch <= N);
-  static int debug_count = 50;
-  if (debug_count > 0) {
-    debug_count--;
-    std::cout << "N,C,H,W=" << N << "," << C << "," << H << "," << W
-              << "; kW,kH=" << kW << "," << kH
-              << "; patchH,patchW=" << patchH << ","
-              << patchW << ", num_blocks_patch="
-              << num_blocks_patch << ", num_blocks_batch="
-              << num_blocks_batch
-              << ", threads_per_opixel=" << threads_per_opixel
-              << ", threads_per_block=" << threads_per_block
-              << std::endl;
-  }
-  dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
-  // blockDim is scalar, just threads_per_block.
+  TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");
+  const int B = input.size(0),
+      C = input.size(1),
+      T = input.size(2),
+      N = params.size(1) - 1;
+  auto scalar_t = input.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
+  torch::Tensor output = torch::empty({B, C, T}, opts);
+  if (C * B * T == 0)
+    return output;
+  // The number of thread blocks is at least C (the number of channels), but
+  // if the number of channels is small we may split further on the batch.
+  int batches_per_block = 1;
+  if (C * batches_per_block < 128) {
+    // Aim for at least 128 thread blocks..
+    batches_per_block = 128 / C;
+    if (batches_per_block > B)
+      batches_per_block = B;
+  }
+  int shared_mem_numel = 2 * N,
+      num_blocks_batch = (B + batches_per_block - 1) / B;
+  dim3 gridDim(C, num_blocks_batch, 1);
+  // blockDim is scalar, just THREADS_PER_BLOCK.
   AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_kernel", ([&] {
-        learned_nonlin_kernel<scalar_t><<<gridDim, threads_per_block, sizeof(scalar_t) * buffer_numel, at::cuda::getCurrentCUDAStream()>>>(
-            input.packed_accessor32<scalar_t, 4>(),
-            pos_add.packed_accessor32<scalar_t, 3>(),
-            pos_mul.packed_accessor32<scalar_t, 3>(),
-            output.packed_accessor32<scalar_t, 4>(),
-            patchH,
-            patchW);
+        learned_nonlin_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
+            input.packed_accessor32<scalar_t, 3>(),
+            params.packed_accessor32<scalar_t, 2>(),
+            output.packed_accessor32<scalar_t, 3>());
      }));
   return output;
 }
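
A standalone illustration of the launch-sizing idea in the function above: one block per channel on gridDim.x and, when C is small, a batch split on gridDim.y aiming for roughly 128 blocks. This sketch takes ceil(B / batches_per_block) for the batch split, matching the stated intent, whereas the draft above divides by B at that step; the helper below is hypothetical, not the committed code (it assumes B, C >= 1):

#include <algorithm>
#include <cstdio>

void print_launch_config(int B, int C, int N) {
  int batches_per_block = 1;
  if (C < 128)   // aim for at least ~128 thread blocks
    batches_per_block = std::max(1, std::min(B, 128 / C));
  int num_blocks_batch = (B + batches_per_block - 1) / batches_per_block;
  int shared_mem_numel = 2 * N;  // y_vals_buf and params_buf, N scalars each
  std::printf("gridDim = (%d, %d, 1), blockDim = %d, shared scalars = %d\n",
              C, num_blocks_batch, 256 /* THREADS_PER_BLOCK */, shared_mem_numel);
}
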
@@ -696,9 +538,10 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
 std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
-                                                        torch::Tensor pos_add,
-                                                        torch::Tensor pos_mul,
+                                                        torch::Tensor params,
                                                         torch::Tensor grad_output) {
+  /*
   TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
   TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
   TORCH_CHECK(pos_mul.dim() == 3, "pos_add must be 3-dimensional.");
@@ -856,5 +699,5 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
   grad_pos_add = at::sum(grad_pos_add, {0});
   grad_pos_mul = at::sum(grad_pos_mul, {0});
-  return std::vector<torch::Tensor>({grad_input, grad_pos_add, grad_pos_mul});
+  return std::vector<torch::Tensor>({grad_input, grad_pos_add, grad_pos_mul}); */
 }