Unverified Commit ff189047 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Fix] Fix Correlation op (#2274)

* fix correlation

* fix lint
parent 7fd7058a
......@@ -36,7 +36,8 @@ template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel(
const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) {
int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW,
int oH, int oW) {
const int iH = rInput1.size(1);
const int iW = rInput1.size(2);
const int C = rInput1.size(3);
......@@ -44,6 +45,9 @@ __global__ void correlation_forward_cuda_kernel(
const int n = blockIdx.x;
const int h = blockIdx.y * blockDim.y + threadIdx.y;
const int w = blockIdx.z * blockDim.z + threadIdx.z;
if (h >= oH || w >= oW) return;
const int thread = threadIdx.x;
const int start_i = -padH + h * dH;
......@@ -60,21 +64,19 @@ __global__ void correlation_forward_cuda_kernel(
for (int i = 0; i < kH; ++i) {
int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated;
if
WITHIN_BOUNDS(i1, i2, iH, iH) {
for (int j = 0; j < kW; ++j) {
int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated;
if
WITHIN_BOUNDS(j1, j2, iW, iW) {
for (int c = thread; c < C; c += WARP_SIZE) {
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum += v1 * v2;
}
}
if (WITHIN_BOUNDS(i1, i2, iH, iH)) {
for (int j = 0; j < kW; ++j) {
int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated;
if (WITHIN_BOUNDS(j1, j2, iW, iW)) {
for (int c = thread; c < C; c += WARP_SIZE) {
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum += v1 * v2;
}
}
}
}
}
// accumulate
for (int offset = 16; offset > 0; offset /= 2)
......
......@@ -42,7 +42,7 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW,
padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);
dilation_patchW, dH, dW, oH, oW);
}));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment