Unverified commit 85b37e7b, authored by q.yao, committed by GitHub

[Enhancement] Optimize correlation op (#1814)

* optimize the forward kernel

* speed up the backward kernels

* fix bugs in the gradient of input2
parent cff3fecc
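In brief, the patch makes three changes. The forward kernel now assigns one warp per output pixel (a 4x4 tile of pixels per block) and reduces over channels with warp shuffles instead of a shared-memory buffer. The two backward kernels compute the patch-wise sums of grad_output once per pixel, cache them in dynamically sized shared memory, and reuse them for every channel. Finally, the backward launcher permutes the inputs to channels-last and launches once over the whole batch instead of looping over samples on the host.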
@@ -29,8 +29,8 @@ using namespace torch;
 #define TensorAcc5R PackedTensorAccessor32<scalar_t, 5, RestrictPtrTraits>
 #define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)

-#define THREADS_FORWARD 32
-#define THREADS_BACKWARD 16
+#define WARP_SIZE 32
+#define FULL_MASK 0xffffffff

 template <typename scalar_t>
 __global__ void correlation_forward_cuda_kernel(
@@ -42,8 +42,8 @@ __global__ void correlation_forward_cuda_kernel(
   const int C = rInput1.size(3);

   const int n = blockIdx.x;
-  const int h = blockIdx.y;
-  const int w = blockIdx.z;
+  const int h = blockIdx.y * blockDim.y + threadIdx.y;
+  const int w = blockIdx.z * blockDim.z + threadIdx.z;
   const int thread = threadIdx.x;

   const int start_i = -padH + h * dH;
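For orientation, here is a minimal standalone sketch (hypothetical kernel, not part of the patch) of the new indexing scheme: each block covers a 4x4 tile of output pixels, the 32 lanes of `threadIdx.x` later cooperate on one pixel's channel sum, and edge tiles need a bounds guard since `oH` and `oW` need not be multiples of 4.

```cuda
#include <cstdio>

#define WARP_SIZE 32

// Hypothetical demo kernel: recover the (h, w) output pixel owned by this
// warp under the dim3(WARP_SIZE, 4, 4) block shape used by the new launcher.
__global__ void tile_index_demo(int oH, int oW) {
  const int h = blockIdx.y * blockDim.y + threadIdx.y;  // output row
  const int w = blockIdx.z * blockDim.z + threadIdx.z;  // output col
  if (h >= oH || w >= oW) return;  // guard partial tiles at the edges
  if (threadIdx.x == 0)            // report once per pixel, not per lane
    printf("block (%d,%d,%d) owns pixel (%d,%d)\n", blockIdx.x, blockIdx.y,
           blockIdx.z, h, w);
}

int main() {
  const int oH = 6, oW = 6;
  const dim3 threads(WARP_SIZE, 4, 4);
  const dim3 blocks(1, (oH + 3) >> 2, (oW + 3) >> 2);  // ceil(oH/4), ceil(oW/4)
  tile_index_demo<<<blocks, threads>>>(oH, oW);
  cudaDeviceSynchronize();
  return 0;
}
```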
@@ -52,13 +52,11 @@ __global__ void correlation_forward_cuda_kernel(
   const int patchRadH = dilation_patchH * (patchH - 1) / 2;
   const int patchRadW = dilation_patchW * (patchW - 1) / 2;

-  __shared__ scalar_t prod_sum[THREADS_FORWARD];
-
   for (int ph = 0; ph < patchH; ++ph) {
     int ph_dilated = ph * dilation_patchH - patchRadH;
     for (int pw = 0; pw < patchW; ++pw) {
       int pw_dilated = pw * dilation_patchW - patchRadW;
-      prod_sum[thread] = 0;
+      scalar_t prod_sum = 0.0f;
       for (int i = 0; i < kH; ++i) {
         int i1 = start_i + i * dilationH;
         int i2 = i1 + ph_dilated;
@@ -69,23 +67,20 @@ __global__ void correlation_forward_cuda_kernel(
             int j2 = j1 + pw_dilated;
             if (WITHIN_BOUNDS(j1, j2, iW, iW)) {
-              for (int c = thread; c < C; c += THREADS_FORWARD) {
+              for (int c = thread; c < C; c += WARP_SIZE) {
                 scalar_t v1 = rInput1[n][i1][j1][c];
                 scalar_t v2 = rInput2[n][i2][j2][c];
-                prod_sum[thread] += v1 * v2;
+                prod_sum += v1 * v2;
               }
             }
           }
         }
       }
       // accumulate
-      __syncthreads();
+      for (int offset = 16; offset > 0; offset /= 2)
+        prod_sum += __shfl_down_sync(FULL_MASK, float(prod_sum), offset);
       if (thread == 0) {
-        scalar_t reduce_sum = 0;
-        for (int index = 0; index < THREADS_FORWARD; ++index) {
-          reduce_sum += prod_sum[index];
-        }
-        output[n][ph][pw][h][w] = reduce_sum;
+        output[n][ph][pw][h][w] = prod_sum;
       }
     }
   }
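The deleted lines staged one partial sum per lane in shared memory and had thread 0 add them after a `__syncthreads()`. The replacement folds the 32 lanes of the warp together in registers with `__shfl_down_sync`: after the five halving steps lane 0 holds the full channel sum, so only `thread == 0` writes the output, and no shared memory or barrier is needed. (The `float(prod_sum)` cast routes `half` values through the float shuffle overload.) A minimal sketch of the same reduction, assuming one warp of floats:

```cuda
#include <cstdio>

#define FULL_MASK 0xffffffff

// After step `offset`, lane i holds val[i] + val[i + offset]; offsets
// 16, 8, 4, 2, 1 therefore fold all 32 lanes into lane 0.
__global__ void warp_reduce_demo() {
  float val = threadIdx.x;  // lane i contributes i, so the sum is 496
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(FULL_MASK, val, offset);
  if (threadIdx.x == 0) printf("warp sum = %.0f\n", val);  // 496
}

int main() {
  warp_reduce_demo<<<1, 32>>>();
  cudaDeviceSynchronize();
  return 0;
}
```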
@@ -97,9 +92,10 @@ __global__ void correlation_backward_cuda_kernel_input1(
     TensorAcc4R grad_input1, const int kH, const int kW, const int patchH,
     const int patchW, const int padH, const int padW, const int dilationH,
     const int dilationW, const int dilation_patchH, const int dilation_patchW,
-    const int dH, const int dW, const int batch) {
-  const int iH = input2.size(2);
-  const int iW = input2.size(3);
+    const int dH, const int dW) {
+  const int iH = input2.size(1);
+  const int iW = input2.size(2);
+  const int C = input2.size(3);

   const int H = grad_output.size(3);
   const int W = grad_output.size(4);
@@ -107,54 +103,53 @@ __global__ void correlation_backward_cuda_kernel_input1(
   const int patchRadH = (patchH - 1) / 2;
   const int patchRadW = (patchW - 1) / 2;

-  const int n = batch;
-  const int c = blockIdx.x;
+  const int n = blockIdx.x;
   const int h = blockIdx.y;
   const int w = blockIdx.z;
-  const int ph_off = threadIdx.x;
-  const int pw_off = threadIdx.y;

   const int h_2 = h + padH;
   const int w_2 = w + padW;
   const int min_h = h_2 - kH * dilationH;
   const int min_w = w_2 - kW * dilationW;

-  __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
-  prod_sum[ph_off][pw_off] = 0;
-
-  for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
-    int i1 = h + dilation_patchH * (ph - patchRadH);
-    for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
-      int j1 = w + dilation_patchW * (pw - patchRadW);
-      if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
-        scalar_t val = input2[n][c][i1][j1];
-        for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
-          int i2 = (h_3) / dH;
-          if (i2 * dH != h_3) continue;
-          for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
-            int j2 = (w_3) / dW;
-            if (j2 * dW != w_3) continue;
-            if (WITHIN_BOUNDS(i2, j2, H, W)) {
-              prod_sum[ph_off][pw_off] +=
-                  grad_output[n][ph][pw][i2][j2] * val;
-            }
-          }
-        }
-      }
-    }
-  }
-
-  __syncthreads();
-
-  if (ph_off == 0 && pw_off == 0) {
-    scalar_t reduce_sum = 0;
-    for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
-      for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
-        reduce_sum += prod_sum[ph][pw];
-      }
-    }
-    grad_input1[n][c][h][w] = reduce_sum;
-  }
-}
+  extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[];
+  scalar_t *grad_cache = reinterpret_cast<scalar_t *>(grad_cache_char);
+
+  for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) {
+    const int ph = i / patchW;
+    const int pw = i % patchW;
+    int i1 = h + dilation_patchH * (ph - patchRadH);
+    int j1 = w + dilation_patchW * (pw - patchRadW);
+
+    if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+      scalar_t grad_val = 0.0f;
+      for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
+        int i2 = (h_3) / dH;
+        if (i2 * dH != h_3) continue;
+        for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
+          int j2 = (w_3) / dW;
+          if (j2 * dW != w_3) continue;
+          if (WITHIN_BOUNDS(i2, j2, H, W)) {
+            grad_val += grad_output[n][ph][pw][i2][j2];
+          }
+        }
+      }
+      grad_cache[i] = grad_val;
+    }
+  }
+  __syncthreads();
+
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    scalar_t grad_input_val = 0.0f;
+    for (int ph = 0; ph < patchH; ++ph) {
+      int i1 = h + dilation_patchH * (ph - patchRadH);
+      for (int pw = 0; pw < patchW; ++pw) {
+        int j1 = w + dilation_patchW * (pw - patchRadW);
+        if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+          grad_input_val +=
+              input2[n][i1][j1][c] * grad_cache[ph * patchW + pw];
+        }
+      }
+    }
+    grad_input1[n][c][h][w] = grad_input_val;
+  }
+}
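The scratch buffer above holds `patchH * patchW` elements of `scalar_t`, a size known only at launch time, so the old fixed-size `__shared__` array becomes dynamic shared memory. CUDA allows a single `extern __shared__` symbol per kernel, and it cannot be declared directly with the dispatched `scalar_t` (the instantiations for float, double, and half would collide), hence the untyped byte buffer plus `reinterpret_cast`. A minimal sketch of that pattern (hypothetical kernel, not the patch's code):

```cuda
#include <cstdio>

// One untyped extern buffer, cast to the element type actually dispatched.
template <typename scalar_t>
__global__ void dyn_smem_demo(int n) {
  extern __shared__ __align__(sizeof(4)) unsigned char cache_char[];
  scalar_t *cache = reinterpret_cast<scalar_t *>(cache_char);
  if (threadIdx.x < n) cache[threadIdx.x] = scalar_t(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0)
    printf("cache[%d] = %f\n", n - 1, double(cache[n - 1]));
}

int main() {
  const int n = 9;  // e.g. patchH * patchW for a 3x3 patch
  // The buffer's byte size is the third kernel launch parameter.
  dyn_smem_demo<float><<<1, 32, n * sizeof(float)>>>(n);
  cudaDeviceSynchronize();
  return 0;
}
```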
@@ -163,9 +158,10 @@ __global__ void correlation_backward_cuda_kernel_input2(
     const TensorAcc5R grad_output, const TensorAcc4R input1,
     TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH,
     int padW, int dilationH, int dilationW, int dilation_patchH,
-    int dilation_patchW, int dH, int dW, int batch) {
-  const int iH = input1.size(2);
-  const int iW = input1.size(3);
+    int dilation_patchW, int dH, int dW) {
+  const int iH = input1.size(1);
+  const int iW = input1.size(2);
+  const int C = input1.size(3);

   const int patchRadH = (patchH - 1) / 2;
   const int patchRadW = (patchW - 1) / 2;
@@ -176,56 +172,54 @@ __global__ void correlation_backward_cuda_kernel_input2(
   const int dilatedKH = kH * dilationH;
   const int dilatedKW = kW * dilationW;

-  const int n = batch;
-  const int c = blockIdx.x;
+  const int n = blockIdx.x;
   const int h = blockIdx.y;
   const int w = blockIdx.z;
-  const int ph_off = threadIdx.x;
-  const int pw_off = threadIdx.y;
-
-  __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
-  prod_sum[ph_off][pw_off] = 0;
-
-  for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
-    int i1 = h - dilation_patchH * (ph - patchRadH);
-    for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
-      int j1 = w - dilation_patchW * (pw - patchRadW);
-      if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
-        scalar_t val = input1[n][c][i1][j1];
-
-        const int h_2 = i1 + padH;
-        const int w_2 = j1 + padW;
-        const int min_h = h_2 - dilatedKH;
-        const int min_w = w_2 - dilatedKW;
-
-        for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
-          int i2 = (h_3) / dH;
-          if (i2 * dH != h_3) continue;
-          for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
-            int j2 = (w_3) / dW;
-            if (j2 * dW != w_3) continue;
-            if (WITHIN_BOUNDS(i2, j2, H, W)) {
-              prod_sum[ph_off][pw_off] +=
-                  grad_output[n][ph][pw][i2][j2] * val;
-            }
-          }
-        }
-      }
-    }
-  }
-
-  __syncthreads();
-
-  if (ph_off == 0 && pw_off == 0) {
-    scalar_t reduce_sum = 0;
-    for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
-      for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
-        reduce_sum += prod_sum[ph][pw];
-      }
-    }
-    grad_input2[n][c][h][w] = reduce_sum;
-  }
-}
+  extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[];
+  scalar_t *grad_cache = reinterpret_cast<scalar_t *>(grad_cache_char);
+
+  for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) {
+    const int ph = i / patchW;
+    const int pw = i % patchW;
+    int i1 = h - dilation_patchH * (ph - patchRadH);
+    int j1 = w - dilation_patchW * (pw - patchRadW);
+    if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+      scalar_t grad_val = 0.0f;
+      const int h_2 = i1 + padH;
+      const int w_2 = j1 + padW;
+      const int min_h = h_2 - dilatedKH;
+      const int min_w = w_2 - dilatedKW;
+
+      for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
+        int i2 = (h_3) / dH;
+        if (i2 * dH != h_3) continue;
+        for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
+          int j2 = (w_3) / dW;
+          if (j2 * dW != w_3) continue;
+          if (WITHIN_BOUNDS(i2, j2, H, W)) {
+            grad_val += grad_output[n][ph][pw][i2][j2];
+          }
+        }
+      }
+      grad_cache[i] = grad_val;
+    }
+  }
+  __syncthreads();
+
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    scalar_t grad_input_val = 0.0f;
+    for (int ph = 0; ph < patchH; ++ph) {
+      int i1 = h - dilation_patchH * (ph - patchRadH);
+      for (int pw = 0; pw < patchW; ++pw) {
+        int j1 = w - dilation_patchW * (pw - patchRadW);
+        if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+          grad_input_val +=
+              input1[n][i1][j1][c] * grad_cache[ph * patchW + pw];
+        }
+      }
+    }
+    grad_input2[n][c][h][w] = grad_input_val;
+  }
+}
 #endif
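The remaining hunks touch the host-side launchers. Note that the backward path now also permutes both inputs to channels-last (NHWC) before taking accessors, matching the kernels' new `input[n][i1][j1][c]` indexing: threads that stride over the channel dimension then read consecutive addresses, i.e. coalesced loads. A small libtorch sketch of that layout change, with a hypothetical tensor `x`:

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  // NCHW tensor, permuted to NHWC and copied contiguously, so a stride-1
  // walk over the last dim visits all channels of one pixel.
  auto x = torch::randn({2, 8, 5, 5});             // N, C, H, W
  auto tr = x.permute({0, 2, 3, 1}).contiguous();  // N, H, W, C
  std::cout << tr.sizes() << "\n";                 // [2, 5, 5, 8]
  return 0;
}
```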
@@ -24,8 +24,8 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
   auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
   auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();

-  const int threads = THREADS_FORWARD;
-  const dim3 blocks(batch_size, oH, oW);
+  const dim3 threads(WARP_SIZE, 4, 4);
+  const dim3 blocks(batch_size, (oH + 3) >> 2, (oW + 3) >> 2);

   at::cuda::CUDAGuard device_guard(input1.device());
@@ -56,17 +56,20 @@ void CorrelationBackwardCUDAKernelLauncher(
   const int iW = input1.size(3);
   const int C = input1.size(1);

-  const dim3 blocks(C, iH, iW);
-  const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);
+  auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
+  auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
+  const dim3 blocks(batch_size, iH, iW);
+  const dim3 threads(THREADS_PER_BLOCK);

   at::cuda::CUDAGuard device_guard(input1.device());

   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input1.scalar_type(), "correlation_backward_cuda", ([&] {
+        const int grad_cache_size = patchH * patchW * sizeof(scalar_t);
         TensorAcc4R input1_acc =
-            input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
+            trInput1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
         TensorAcc4R input2_acc =
-            input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
+            trInput2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
         TensorAcc4R grad_input1_acc =
             grad_input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
         TensorAcc4R grad_input2_acc =
@@ -74,20 +77,18 @@ void CorrelationBackwardCUDAKernelLauncher(
         TensorAcc5R grad_output_acc =
             grad_output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();
-        for (int n = 0; n < batch_size; ++n) {
-          correlation_backward_cuda_kernel_input1<scalar_t>
-              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH,
-                  patchW, padH, padW, dilationH, dilationW, dilation_patchH,
-                  dilation_patchW, dH, dW, n);
-        }
+        correlation_backward_cuda_kernel_input1<scalar_t>
+            <<<blocks, threads, grad_cache_size,
+               at::cuda::getCurrentCUDAStream()>>>(
+                grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH,
+                patchW, padH, padW, dilationH, dilationW, dilation_patchH,
+                dilation_patchW, dH, dW);

-        for (int n = 0; n < batch_size; ++n) {
-          correlation_backward_cuda_kernel_input2<scalar_t>
-              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH,
-                  patchW, padH, padW, dilationH, dilationW, dilation_patchH,
-                  dilation_patchW, dH, dW, n);
-        }
+        correlation_backward_cuda_kernel_input2<scalar_t>
+            <<<blocks, threads, grad_cache_size,
+               at::cuda::getCurrentCUDAStream()>>>(
+                grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH,
+                patchW, padH, padW, dilationH, dilationW, dilation_patchH,
+                dilation_patchW, dH, dW);
       }));
 }
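One more consequence of the new launch shape: the old code launched each backward kernel `batch_size` times with a `blocks(C, iH, iW)` grid, redoing the patch-wise grad_output accumulation for every channel, while the single `blocks(batch_size, iH, iW)` launch computes those sums once per pixel, shares them through `grad_cache` (whose byte size, `grad_cache_size`, is passed as the third launch argument), and lets the block's threads stride over the channels instead.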