"src/vscode:/vscode.git/clone" did not exist on "010bc4ea198eeaf379926316b214c22a1dab8d17"
Unverified Commit ff189047 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Fix] Fix Correlation op (#2274)

* fix correlation

* fix lint
parent 7fd7058a
...@@ -36,7 +36,8 @@ template <typename scalar_t> ...@@ -36,7 +36,8 @@ template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel( __global__ void correlation_forward_cuda_kernel(
const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output, const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH, int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) { int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW,
int oH, int oW) {
const int iH = rInput1.size(1); const int iH = rInput1.size(1);
const int iW = rInput1.size(2); const int iW = rInput1.size(2);
const int C = rInput1.size(3); const int C = rInput1.size(3);
...@@ -44,6 +45,9 @@ __global__ void correlation_forward_cuda_kernel( ...@@ -44,6 +45,9 @@ __global__ void correlation_forward_cuda_kernel(
const int n = blockIdx.x; const int n = blockIdx.x;
const int h = blockIdx.y * blockDim.y + threadIdx.y; const int h = blockIdx.y * blockDim.y + threadIdx.y;
const int w = blockIdx.z * blockDim.z + threadIdx.z; const int w = blockIdx.z * blockDim.z + threadIdx.z;
if (h >= oH || w >= oW) return;
const int thread = threadIdx.x; const int thread = threadIdx.x;
const int start_i = -padH + h * dH; const int start_i = -padH + h * dH;
...@@ -60,13 +64,11 @@ __global__ void correlation_forward_cuda_kernel( ...@@ -60,13 +64,11 @@ __global__ void correlation_forward_cuda_kernel(
for (int i = 0; i < kH; ++i) { for (int i = 0; i < kH; ++i) {
int i1 = start_i + i * dilationH; int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated; int i2 = i1 + ph_dilated;
if if (WITHIN_BOUNDS(i1, i2, iH, iH)) {
WITHIN_BOUNDS(i1, i2, iH, iH) {
for (int j = 0; j < kW; ++j) { for (int j = 0; j < kW; ++j) {
int j1 = start_j + j * dilationW; int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated; int j2 = j1 + pw_dilated;
if if (WITHIN_BOUNDS(j1, j2, iW, iW)) {
WITHIN_BOUNDS(j1, j2, iW, iW) {
for (int c = thread; c < C; c += WARP_SIZE) { for (int c = thread; c < C; c += WARP_SIZE) {
scalar_t v1 = rInput1[n][i1][j1][c]; scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c]; scalar_t v2 = rInput2[n][i2][j2][c];
......
...@@ -42,7 +42,7 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2, ...@@ -42,7 +42,7 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW, trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW,
padH, padW, dilationH, dilationW, dilation_patchH, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW); dilation_patchW, dH, dW, oH, oW);
})); }));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment