Commit b0646e0e authored by q.yao's avatar q.yao Committed by Zaida Zhou
Browse files

[Fix] Fix Correlation op (#2274)

* fix correlation

* fix lint
parent 6c89b717
...@@ -36,7 +36,8 @@ template <typename scalar_t> ...@@ -36,7 +36,8 @@ template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel( __global__ void correlation_forward_cuda_kernel(
const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output, const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH, int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) { int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW,
int oH, int oW) {
const int iH = rInput1.size(1); const int iH = rInput1.size(1);
const int iW = rInput1.size(2); const int iW = rInput1.size(2);
const int C = rInput1.size(3); const int C = rInput1.size(3);
...@@ -44,6 +45,9 @@ __global__ void correlation_forward_cuda_kernel( ...@@ -44,6 +45,9 @@ __global__ void correlation_forward_cuda_kernel(
const int n = blockIdx.x; const int n = blockIdx.x;
const int h = blockIdx.y * blockDim.y + threadIdx.y; const int h = blockIdx.y * blockDim.y + threadIdx.y;
const int w = blockIdx.z * blockDim.z + threadIdx.z; const int w = blockIdx.z * blockDim.z + threadIdx.z;
if (h >= oH || w >= oW) return;
const int thread = threadIdx.x; const int thread = threadIdx.x;
const int start_i = -padH + h * dH; const int start_i = -padH + h * dH;
......
...@@ -42,7 +42,7 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2, ...@@ -42,7 +42,7 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW, trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW,
padH, padW, dilationH, dilationW, dilation_patchH, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW); dilation_patchW, dH, dW, oH, oW);
})); }));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment