Unverified Commit 85b37e7b authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Enhancment] Optimize correlation op (#1814)

* optimize forward

* fast backward

* fix bugs of grad input2
parent cff3fecc
......@@ -29,8 +29,8 @@ using namespace torch;
#define TensorAcc5R PackedTensorAccessor32<scalar_t, 5, RestrictPtrTraits>
#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)
#define THREADS_FORWARD 32
#define THREADS_BACKWARD 16
#define WARP_SIZE 32
#define FULL_MASK 0xffffffff
template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel(
......@@ -42,8 +42,8 @@ __global__ void correlation_forward_cuda_kernel(
const int C = rInput1.size(3);
const int n = blockIdx.x;
const int h = blockIdx.y;
const int w = blockIdx.z;
const int h = blockIdx.y * blockDim.y + threadIdx.y;
const int w = blockIdx.z * blockDim.z + threadIdx.z;
const int thread = threadIdx.x;
const int start_i = -padH + h * dH;
......@@ -52,13 +52,11 @@ __global__ void correlation_forward_cuda_kernel(
const int patchRadH = dilation_patchH * (patchH - 1) / 2;
const int patchRadW = dilation_patchW * (patchW - 1) / 2;
__shared__ scalar_t prod_sum[THREADS_FORWARD];
for (int ph = 0; ph < patchH; ++ph) {
int ph_dilated = ph * dilation_patchH - patchRadH;
for (int pw = 0; pw < patchW; ++pw) {
int pw_dilated = pw * dilation_patchW - patchRadW;
prod_sum[thread] = 0;
scalar_t prod_sum = 0.0f;
for (int i = 0; i < kH; ++i) {
int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated;
......@@ -69,23 +67,20 @@ __global__ void correlation_forward_cuda_kernel(
int j2 = j1 + pw_dilated;
if
WITHIN_BOUNDS(j1, j2, iW, iW) {
for (int c = thread; c < C; c += THREADS_FORWARD) {
for (int c = thread; c < C; c += WARP_SIZE) {
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum[thread] += v1 * v2;
prod_sum += v1 * v2;
}
}
}
}
}
// accumulate
__syncthreads();
for (int offset = 16; offset > 0; offset /= 2)
prod_sum += __shfl_down_sync(FULL_MASK, float(prod_sum), offset);
if (thread == 0) {
scalar_t reduce_sum = 0;
for (int index = 0; index < THREADS_FORWARD; ++index) {
reduce_sum += prod_sum[index];
}
output[n][ph][pw][h][w] = reduce_sum;
output[n][ph][pw][h][w] = prod_sum;
}
}
}
......@@ -97,9 +92,10 @@ __global__ void correlation_backward_cuda_kernel_input1(
TensorAcc4R grad_input1, const int kH, const int kW, const int patchH,
const int patchW, const int padH, const int padW, const int dilationH,
const int dilationW, const int dilation_patchH, const int dilation_patchW,
const int dH, const int dW, const int batch) {
const int iH = input2.size(2);
const int iW = input2.size(3);
const int dH, const int dW) {
const int iH = input2.size(1);
const int iW = input2.size(2);
const int C = input2.size(3);
const int H = grad_output.size(3);
const int W = grad_output.size(4);
......@@ -107,54 +103,53 @@ __global__ void correlation_backward_cuda_kernel_input1(
const int patchRadH = (patchH - 1) / 2;
const int patchRadW = (patchW - 1) / 2;
const int n = batch;
const int c = blockIdx.x;
const int n = blockIdx.x;
const int h = blockIdx.y;
const int w = blockIdx.z;
const int ph_off = threadIdx.x;
const int pw_off = threadIdx.y;
const int h_2 = h + padH;
const int w_2 = w + padW;
const int min_h = h_2 - kH * dilationH;
const int min_w = w_2 - kW * dilationW;
__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
prod_sum[ph_off][pw_off] = 0;
for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[];
scalar_t *grad_cache = reinterpret_cast<scalar_t *>(grad_cache_char);
for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) {
const int ph = i / patchW;
const int pw = i % patchW;
int i1 = h + dilation_patchH * (ph - patchRadH);
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
int j1 = w + dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
scalar_t val = input2[n][c][i1][j1];
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
int i2 = (h_3) / dH;
if (i2 * dH != h_3) continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
int j2 = (w_3) / dW;
if (j2 * dW != w_3) continue;
if
WITHIN_BOUNDS(i2, j2, H, W) {
prod_sum[ph_off][pw_off] +=
grad_output[n][ph][pw][i2][j2] * val;
}
int j1 = w + dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
scalar_t grad_val = 0.0f;
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
int i2 = (h_3) / dH;
if (i2 * dH != h_3) continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
int j2 = (w_3) / dW;
if (j2 * dW != w_3) continue;
if (WITHIN_BOUNDS(i2, j2, H, W)) {
grad_val += grad_output[n][ph][pw][i2][j2];
}
}
}
grad_cache[i] = grad_val;
}
}
__syncthreads();
if (ph_off == 0 && pw_off == 0) {
scalar_t reduce_sum = 0;
for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
reduce_sum += prod_sum[ph][pw];
for (int c = threadIdx.x; c < C; c += blockDim.x) {
scalar_t grad_input_val = 0.0f;
for (int ph = 0; ph < patchH; ++ph) {
int i1 = h + dilation_patchH * (ph - patchRadH);
for (int pw = 0; pw < patchW; ++pw) {
int j1 = w + dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
grad_input_val += input2[n][i1][j1][c] * grad_cache[ph * patchW + pw];
}
}
}
grad_input1[n][c][h][w] = reduce_sum;
grad_input1[n][c][h][w] = grad_input_val;
}
}
......@@ -163,9 +158,10 @@ __global__ void correlation_backward_cuda_kernel_input2(
const TensorAcc5R grad_output, const TensorAcc4R input1,
TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW, int batch) {
const int iH = input1.size(2);
const int iW = input1.size(3);
int dilation_patchW, int dH, int dW) {
const int iH = input1.size(1);
const int iW = input1.size(2);
const int C = input1.size(3);
const int patchRadH = (patchH - 1) / 2;
const int patchRadW = (patchW - 1) / 2;
......@@ -176,56 +172,54 @@ __global__ void correlation_backward_cuda_kernel_input2(
const int dilatedKH = kH * dilationH;
const int dilatedKW = kW * dilationW;
const int n = batch;
const int c = blockIdx.x;
const int n = blockIdx.x;
const int h = blockIdx.y;
const int w = blockIdx.z;
const int ph_off = threadIdx.x;
const int pw_off = threadIdx.y;
__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
prod_sum[ph_off][pw_off] = 0;
for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[];
scalar_t *grad_cache = reinterpret_cast<scalar_t *>(grad_cache_char);
for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) {
const int ph = i / patchW;
const int pw = i % patchW;
int i1 = h - dilation_patchH * (ph - patchRadH);
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
int j1 = w - dilation_patchW * (pw - patchRadW);
if
WITHIN_BOUNDS(i1, j1, iH, iW) {
scalar_t val = input1[n][c][i1][j1];
const int h_2 = i1 + padH;
const int w_2 = j1 + padW;
const int min_h = h_2 - dilatedKH;
const int min_w = w_2 - dilatedKW;
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
int i2 = (h_3) / dH;
if (i2 * dH != h_3) continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
int j2 = (w_3) / dW;
if (j2 * dW != w_3) continue;
if
WITHIN_BOUNDS(i2, j2, H, W) {
prod_sum[ph_off][pw_off] +=
grad_output[n][ph][pw][i2][j2] * val;
}
}
int j1 = w - dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
scalar_t grad_val = 0.0f;
const int h_2 = i1 + padH;
const int w_2 = j1 + padW;
const int min_h = h_2 - dilatedKH;
const int min_w = w_2 - dilatedKW;
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
int i2 = (h_3) / dH;
if (i2 * dH != h_3) continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
int j2 = (w_3) / dW;
if (j2 * dW != w_3) continue;
if (WITHIN_BOUNDS(i2, j2, H, W)) {
grad_val += grad_output[n][ph][pw][i2][j2];
}
}
}
grad_cache[i] = grad_val;
}
}
__syncthreads();
if (ph_off == 0 && pw_off == 0) {
scalar_t reduce_sum = 0;
for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
reduce_sum += prod_sum[ph][pw];
for (int c = threadIdx.x; c < C; c += blockDim.x) {
scalar_t grad_input_val = 0.0f;
for (int ph = 0; ph < patchH; ++ph) {
int i1 = h - dilation_patchH * (ph - patchRadH);
for (int pw = 0; pw < patchW; ++pw) {
int j1 = w - dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
grad_input_val += input1[n][i1][j1][c] * grad_cache[ph * patchW + pw];
}
}
}
grad_input2[n][c][h][w] = reduce_sum;
grad_input2[n][c][h][w] = grad_input_val;
}
}
#endif
......@@ -24,8 +24,8 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
const int threads = THREADS_FORWARD;
const dim3 blocks(batch_size, oH, oW);
const dim3 threads(WARP_SIZE, 4, 4);
const dim3 blocks(batch_size, (oH + 3) >> 2, (oW + 3) >> 2);
at::cuda::CUDAGuard device_guard(input1.device());
......@@ -56,17 +56,20 @@ void CorrelationBackwardCUDAKernelLauncher(
const int iW = input1.size(3);
const int C = input1.size(1);
const dim3 blocks(C, iH, iW);
const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);
auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
const dim3 blocks(batch_size, iH, iW);
const dim3 threads(THREADS_PER_BLOCK);
at::cuda::CUDAGuard device_guard(input1.device());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input1.scalar_type(), "correlation_backward_cuda", ([&] {
const int grad_cache_size = patchH * patchW * sizeof(scalar_t);
TensorAcc4R input1_acc =
input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
trInput1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R input2_acc =
input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
trInput2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R grad_input1_acc =
grad_input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R grad_input2_acc =
......@@ -74,20 +77,18 @@ void CorrelationBackwardCUDAKernelLauncher(
TensorAcc5R grad_output_acc =
grad_output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();
for (int n = 0; n < batch_size; ++n) {
correlation_backward_cuda_kernel_input1<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH,
patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW, n);
}
correlation_backward_cuda_kernel_input1<scalar_t>
<<<blocks, threads, grad_cache_size,
at::cuda::getCurrentCUDAStream()>>>(
grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH,
patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);
for (int n = 0; n < batch_size; ++n) {
correlation_backward_cuda_kernel_input2<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH,
patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW, n);
}
correlation_backward_cuda_kernel_input2<scalar_t>
<<<blocks, threads, grad_cache_size,
at::cuda::getCurrentCUDAStream()>>>(
grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH,
patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);
}));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment