Unverified Commit 745aa737 authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

[Fix] Revise unit test of correlation (#1368)

* [Fix] Revise unit test of correlation

* rename

* lint

* lint

* lint

* lint
parent 9d4571e3
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <torch/types.h> #include <torch/types.h>
#include <vector>
#include <iostream> #include <iostream>
#include <vector>
using namespace torch; using namespace torch;
...@@ -28,17 +29,10 @@ using namespace torch; ...@@ -28,17 +29,10 @@ using namespace torch;
#define THREADS_BACKWARD 16 #define THREADS_BACKWARD 16
template <typename scalar_t> template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1, __global__ void correlation_forward_cuda_kernel(
const TensorAcc4R rInput2, const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
TensorAcc5R output, int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
int kH, int kW, int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) {
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW)
{
const int iH = rInput1.size(1); const int iH = rInput1.size(1);
const int iW = rInput1.size(2); const int iW = rInput1.size(2);
const int C = rInput1.size(3); const int C = rInput1.size(3);
...@@ -56,42 +50,35 @@ __global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1, ...@@ -56,42 +50,35 @@ __global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1,
__shared__ scalar_t prod_sum[THREADS_FORWARD]; __shared__ scalar_t prod_sum[THREADS_FORWARD];
for (int ph = 0; ph < patchH; ++ph) for (int ph = 0; ph < patchH; ++ph) {
{
int ph_dilated = ph * dilation_patchH - patchRadH; int ph_dilated = ph * dilation_patchH - patchRadH;
for (int pw = 0; pw < patchW; ++pw) for (int pw = 0; pw < patchW; ++pw) {
{
int pw_dilated = pw * dilation_patchW - patchRadW; int pw_dilated = pw * dilation_patchW - patchRadW;
prod_sum[thread] = 0; prod_sum[thread] = 0;
for (int i = 0; i < kH; ++i) for (int i = 0; i < kH; ++i) {
{
int i1 = start_i + i * dilationH; int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated; int i2 = i1 + ph_dilated;
if WITHIN_BOUNDS (i1, i2, iH, iH) if
{ WITHIN_BOUNDS(i1, i2, iH, iH) {
for (int j = 0; j < kW; ++j) for (int j = 0; j < kW; ++j) {
{ int j1 = start_j + j * dilationW;
int j1 = start_j + j * dilationW; int j2 = j1 + pw_dilated;
int j2 = j1 + pw_dilated; if
if WITHIN_BOUNDS (j1, j2, iW, iW) WITHIN_BOUNDS(j1, j2, iW, iW) {
{ for (int c = thread; c < C; c += THREADS_FORWARD) {
for (int c = thread; c < C; c += THREADS_FORWARD) scalar_t v1 = rInput1[n][i1][j1][c];
{ scalar_t v2 = rInput2[n][i2][j2][c];
scalar_t v1 = rInput1[n][i1][j1][c]; prod_sum[thread] += v1 * v2;
scalar_t v2 = rInput2[n][i2][j2][c]; }
prod_sum[thread] += v1 * v2; }
}
} }
} }
}
} }
// accumulate // accumulate
__syncthreads(); __syncthreads();
if (thread == 0) if (thread == 0) {
{
scalar_t reduce_sum = 0; scalar_t reduce_sum = 0;
for (int index = 0; index < THREADS_FORWARD; ++index) for (int index = 0; index < THREADS_FORWARD; ++index) {
{
reduce_sum += prod_sum[index]; reduce_sum += prod_sum[index];
} }
output[n][ph][pw][h][w] = reduce_sum; output[n][ph][pw][h][w] = reduce_sum;
...@@ -101,18 +88,12 @@ __global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1, ...@@ -101,18 +88,12 @@ __global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1,
} }
template <typename scalar_t> template <typename scalar_t>
__global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_output, __global__ void correlation_backward_cuda_kernel_input1(
const TensorAcc4R input2, const TensorAcc5R grad_output, const TensorAcc4R input2,
TensorAcc4R grad_input1, TensorAcc4R grad_input1, const int kH, const int kW, const int patchH,
const int kH, const int kW, const int patchW, const int padH, const int padW, const int dilationH,
const int patchH, const int patchW, const int dilationW, const int dilation_patchH, const int dilation_patchW,
const int padH, const int padW, const int dH, const int dW, const int batch) {
const int dilationH, const int dilationW,
const int dilation_patchH, const int dilation_patchW,
const int dH, const int dW,
const int batch)
{
const int iH = input2.size(2); const int iH = input2.size(2);
const int iW = input2.size(3); const int iW = input2.size(3);
...@@ -137,29 +118,23 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o ...@@ -137,29 +118,23 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o
__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD]; __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
prod_sum[ph_off][pw_off] = 0; prod_sum[ph_off][pw_off] = 0;
for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
{
int i1 = h + dilation_patchH * (ph - patchRadH); int i1 = h + dilation_patchH * (ph - patchRadH);
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
{
int j1 = w + dilation_patchW * (pw - patchRadW); int j1 = w + dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW)) if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
{
scalar_t val = input2[n][c][i1][j1]; scalar_t val = input2[n][c][i1][j1];
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
{
int i2 = (h_3) / dH; int i2 = (h_3) / dH;
if (i2 * dH != h_3) if (i2 * dH != h_3) continue;
continue; for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW)
{
int j2 = (w_3) / dW; int j2 = (w_3) / dW;
if (j2 * dW != w_3) if (j2 * dW != w_3) continue;
continue; if
if WITHIN_BOUNDS (i2, j2, H, W) WITHIN_BOUNDS(i2, j2, H, W) {
{ prod_sum[ph_off][pw_off] +=
prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val; grad_output[n][ph][pw][i2][j2] * val;
} }
} }
} }
} }
...@@ -168,13 +143,10 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o ...@@ -168,13 +143,10 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o
__syncthreads(); __syncthreads();
if (ph_off == 0 && pw_off == 0) if (ph_off == 0 && pw_off == 0) {
{
scalar_t reduce_sum = 0; scalar_t reduce_sum = 0;
for (int ph = 0; ph < THREADS_BACKWARD; ++ph) for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
{ for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
for (int pw = 0; pw < THREADS_BACKWARD; ++pw)
{
reduce_sum += prod_sum[ph][pw]; reduce_sum += prod_sum[ph][pw];
} }
} }
...@@ -183,17 +155,11 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o ...@@ -183,17 +155,11 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o
} }
template <typename scalar_t> template <typename scalar_t>
__global__ void correlation_backward_cuda_kernel_input2(const TensorAcc5R grad_output, __global__ void correlation_backward_cuda_kernel_input2(
const TensorAcc4R input1, const TensorAcc5R grad_output, const TensorAcc4R input1,
TensorAcc4R grad_input2, TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH,
int kH, int kW, int padW, int dilationH, int dilationW, int dilation_patchH,
int patchH, int patchW, int dilation_patchW, int dH, int dW, int batch) {
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW,
int batch)
{
const int iH = input1.size(2); const int iH = input1.size(2);
const int iW = input1.size(3); const int iW = input1.size(3);
...@@ -216,50 +182,42 @@ __global__ void correlation_backward_cuda_kernel_input2(const TensorAcc5R grad_o ...@@ -216,50 +182,42 @@ __global__ void correlation_backward_cuda_kernel_input2(const TensorAcc5R grad_o
__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD]; __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
prod_sum[ph_off][pw_off] = 0; prod_sum[ph_off][pw_off] = 0;
for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
{
int i1 = h - dilation_patchH * (ph - patchRadH); int i1 = h - dilation_patchH * (ph - patchRadH);
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
{
int j1 = w - dilation_patchW * (pw - patchRadW); int j1 = w - dilation_patchW * (pw - patchRadW);
if WITHIN_BOUNDS (i1, j1, iH, iW) if
{ WITHIN_BOUNDS(i1, j1, iH, iW) {
scalar_t val = input1[n][c][i1][j1]; scalar_t val = input1[n][c][i1][j1];
const int h_2 = i1 + padH; const int h_2 = i1 + padH;
const int w_2 = j1 + padW; const int w_2 = j1 + padW;
const int min_h = h_2 - dilatedKH; const int min_h = h_2 - dilatedKH;
const int min_w = w_2 - dilatedKW; const int min_w = w_2 - dilatedKW;
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
{ int i2 = (h_3) / dH;
int i2 = (h_3) / dH; if (i2 * dH != h_3) continue;
if (i2 * dH != h_3) for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
continue; int j2 = (w_3) / dW;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) if (j2 * dW != w_3) continue;
{ if
int j2 = (w_3) / dW; WITHIN_BOUNDS(i2, j2, H, W) {
if (j2 * dW != w_3) prod_sum[ph_off][pw_off] +=
continue; grad_output[n][ph][pw][i2][j2] * val;
if WITHIN_BOUNDS (i2, j2, H, W) }
{
prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
} }
} }
} }
}
} }
} }
__syncthreads(); __syncthreads();
if (ph_off == 0 && pw_off == 0) if (ph_off == 0 && pw_off == 0) {
{
scalar_t reduce_sum = 0; scalar_t reduce_sum = 0;
for (int ph = 0; ph < THREADS_BACKWARD; ++ph) for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
{ for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
for (int pw = 0; pw < THREADS_BACKWARD; ++pw)
{
reduce_sum += prod_sum[ph][pw]; reduce_sum += prod_sum[ph][pw];
} }
} }
......
// Copyright (c) OpenMMLab. All rights reserved. // Copyright (c) OpenMMLab. All rights reserved.
#include <iostream> #include <iostream>
#include "pytorch_cpp_helper.hpp" #include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA #ifdef MMCV_WITH_CUDA
void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2, void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
Tensor output, int kH, int kW, Tensor output, int kH, int kW,
int patchH, int patchW, int patchH, int patchW, int padH,
int padH, int padW, int padW, int dilationH,
int dilationH, int dilationW, int dilationW, int dilation_patchH,
int dilation_patchH, int dilation_patchW, int dH, int dW);
int dilation_patchW,
int dH, int dW);
void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1, void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1,
Tensor input2, Tensor grad_input1, Tensor input2, Tensor grad_input1,
Tensor grad_input2, int kH, int kW, Tensor grad_input2, int kH, int kW,
int patchH, int patchW, int patchH, int patchW, int padH,
int padH, int padW, int padW, int dilationH,
int dilationH, int dilationW, int dilationW, int dilation_patchH,
int dilation_patchH, int dilation_patchW, int dH, int dW);
int dilation_patchW,
int dH, int dW);
void correlation_cuda_forward(Tensor input1, Tensor input2, Tensor output, void correlation_cuda_forward(Tensor input1, Tensor input2, Tensor output,
int kH, int kW, int patchH, int patchW, int kH, int kW, int patchH, int patchW, int padH,
int padH, int padW, int dilationH, int dilationW, int padW, int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW, int dilation_patchH, int dilation_patchW, int dH,
int dH, int dW) int dW) {
{ CorrelationForwardCUDAKernelLauncher(
input1, input2, output, kH, kW, patchH, patchW, padH, padW, dilationH,
CorrelationForwardCUDAKernelLauncher(input1, input2, output, kH, kW, dilationW, dilation_patchH, dilation_patchW, dH, dW);
patchH, patchW, padH, padW, dilationH,
dilationW, dilation_patchH,
dilation_patchW, dH, dW);
} }
void correlation_cuda_backward(Tensor grad_output, void correlation_cuda_backward(Tensor grad_output, Tensor input1, Tensor input2,
Tensor input1, Tensor input2, Tensor grad_input1, Tensor grad_input2, int kH,
Tensor grad_input1, Tensor grad_input2, int kW, int patchH, int patchW, int padH,
int kH, int kW, int patchH, int patchW, int padW, int dilationH, int dilationW,
int padH, int padW, int dilation_patchH, int dilation_patchW, int dH,
int dilationH, int dilationW, int dW) {
int dilation_patchH, int dilation_patchW, CorrelationBackwardCUDAKernelLauncher(
int dH, int dW) grad_output, input1, input2, grad_input1, grad_input2, kH, kW, patchH,
{ patchW, padH, padW, dilationH, dilationW, dilation_patchH,
CorrelationBackwardCUDAKernelLauncher(grad_output, input1, input2, dilation_patchW, dH, dW);
grad_input1, grad_input2, kH, kW,
patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW);
} }
#endif #endif
void correlation_forward(Tensor input1, Tensor input2, Tensor output, void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH,
int kH, int kW, int patchH, int patchW, int kW, int patchH, int patchW, int padH, int padW,
int padH, int padW, int dilationH, int dilationW, int dilation_patchH,
int dilationH, int dilationW, int dilation_patchW, int dH, int dW) {
int dilation_patchH, int dilation_patchW, if (input1.device().is_cuda() and input2.device().is_cuda()) {
int dH, int dW)
{
if (input1.device().is_cuda() and input2.device().is_cuda())
{
#ifdef MMCV_WITH_CUDA #ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input1); CHECK_CUDA_INPUT(input1);
CHECK_CUDA_INPUT(input2); CHECK_CUDA_INPUT(input2);
correlation_cuda_forward(input1, input2, output, kH, kW, correlation_cuda_forward(input1, input2, output, kH, kW, patchH, patchW,
patchH, patchW, padH, padW, padH, padW, dilationH, dilationW, dilation_patchH,
dilationH, dilationW, dilation_patchW, dH, dW);
dilation_patchH, dilation_patchW,
dH, dW);
#else #else
AT_ERROR("Correlation is not compiled with GPU support"); AT_ERROR("Correlation is not compiled with GPU support");
#endif #endif
} } else {
else AT_ERROR("Correlation is not implemented on CPU");
{ }
AT_ERROR("Correlation is not implemented on CPU");
}
} }
void correlation_backward(Tensor grad_output, void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
Tensor input1, Tensor input2, Tensor grad_input1, Tensor grad_input2, int kH,
Tensor grad_input1, Tensor grad_input2, int kW, int patchH, int patchW, int padH, int padW,
int kH, int kW, int dilationH, int dilationW, int dilation_patchH,
int patchH, int patchW, int dilation_patchW, int dH, int dW) {
int padH, int padW, if (input1.device().is_cuda() and input2.device().is_cuda()) {
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW)
{
if (input1.device().is_cuda() and input2.device().is_cuda())
{
#ifdef MMCV_WITH_CUDA #ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_output); CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(input1); CHECK_CUDA_INPUT(input1);
CHECK_CUDA_INPUT(input2); CHECK_CUDA_INPUT(input2);
correlation_cuda_backward(grad_output, input1, input2, correlation_cuda_backward(grad_output, input1, input2, grad_input1,
grad_input1, grad_input2, kH, kW, grad_input2, kH, kW, patchH, patchW, padH, padW,
patchH, patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilationH, dilationW, dilation_patchW, dH, dW);
dilation_patchH, dilation_patchW,
dH, dW);
#else #else
AT_ERROR("Correlation is not compiled with GPU support"); AT_ERROR("Correlation is not compiled with GPU support");
#endif #endif
} } else {
else AT_ERROR("Correlation is not implemented on CPU");
{ }
AT_ERROR("Correlation is not implemented on CPU");
}
} }
...@@ -7,99 +7,87 @@ ...@@ -7,99 +7,87 @@
#include "pytorch_cuda_helper.hpp" #include "pytorch_cuda_helper.hpp"
void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2, void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
Tensor output, int kH, int kW, Tensor output, int kH, int kW,
int patchH, int patchW, int patchH, int patchW, int padH,
int padH, int padW, int padW, int dilationH,
int dilationH, int dilationW, int dilationW, int dilation_patchH,
int dilation_patchH, int dilation_patchW, int dH, int dW) {
int dilation_patchW, const int batch_size = input1.size(0);
int dH, int dW) const int iH = input1.size(2);
{ const int iW = input1.size(3);
const int dilatedKH = (kH - 1) * dilationH + 1;
const int batch_size = input1.size(0); const int dilatedKW = (kW - 1) * dilationW + 1;
const int iH = input1.size(2);
const int iW = input1.size(3); const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
const int dilatedKH = (kH - 1) * dilationH + 1; const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;
const int dilatedKW = (kW - 1) * dilationW + 1;
auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1; const int threads = THREADS_FORWARD;
const dim3 blocks(batch_size, oH, oW);
auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous(); at::cuda::CUDAGuard device_guard(input1.device());
auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
const int threads = THREADS_FORWARD; input1.scalar_type(), "correlation_forward_cuda", ([&] {
const dim3 blocks(batch_size, oH, oW); TensorAcc4R trInput1_acc =
trInput1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
at::cuda::CUDAGuard device_guard(input1.device()); TensorAcc4R trInput2_acc =
trInput2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(), TensorAcc5R output_acc =
"correlation_forward_cuda", output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();
([&]{
TensorAcc4R trInput1_acc = trInput1.packed_accessor32<scalar_t,4,RestrictPtrTraits>(); correlation_forward_cuda_kernel<scalar_t>
TensorAcc4R trInput2_acc = trInput2.packed_accessor32<scalar_t,4,RestrictPtrTraits>(); <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
TensorAcc5R output_acc = output.packed_accessor32<scalar_t,5,RestrictPtrTraits>(); trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW,
padH, padW, dilationH, dilationW, dilation_patchH,
correlation_forward_cuda_kernel<scalar_t><<<blocks, threads, 0, dilation_patchW, dH, dW);
at::cuda::getCurrentCUDAStream()>>>( }));
trInput1_acc, trInput2_acc, output_acc,
kH, kW, patchH, patchW, padH, padW, dilationH, dilationW,
dilation_patchH, dilation_patchW, dH, dW);
}));
} }
void CorrelationBackwardCUDAKernelLauncher(
void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1, Tensor grad_output, Tensor input1, Tensor input2, Tensor grad_input1,
Tensor input2, Tensor grad_input1, Tensor grad_input2, int kH, int kW, int patchH, int patchW, int padH,
Tensor grad_input2, int kH, int kW, int padW, int dilationH, int dilationW, int dilation_patchH,
int patchH, int patchW, int dilation_patchW, int dH, int dW) {
int padH, int padW, const int batch_size = input1.size(0);
int dilationH, int dilationW, const int iH = input1.size(2);
int dilation_patchH, const int iW = input1.size(3);
int dilation_patchW, const int C = input1.size(1);
int dH, int dW){
const int batch_size = input1.size(0); const dim3 blocks(C, iH, iW);
const int iH = input1.size(2); const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);
const int iW = input1.size(3);
const int C = input1.size(1); at::cuda::CUDAGuard device_guard(input1.device());
const dim3 blocks(C, iH, iW); AT_DISPATCH_FLOATING_TYPES_AND_HALF(
const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD); input1.scalar_type(), "correlation_backward_cuda", ([&] {
TensorAcc4R input1_acc =
at::cuda::CUDAGuard device_guard(input1.device()); input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R input2_acc =
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(), input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
"correlation_backward_cuda", TensorAcc4R grad_input1_acc =
([&]{ grad_input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R input1_acc = input1.packed_accessor32<scalar_t,4,RestrictPtrTraits>(); TensorAcc4R grad_input2_acc =
TensorAcc4R input2_acc = input2.packed_accessor32<scalar_t,4,RestrictPtrTraits>(); grad_input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R grad_input1_acc = grad_input1.packed_accessor32<scalar_t,4,RestrictPtrTraits>(); TensorAcc5R grad_output_acc =
TensorAcc4R grad_input2_acc = grad_input2.packed_accessor32<scalar_t,4,RestrictPtrTraits>(); grad_output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();
TensorAcc5R grad_output_acc = grad_output.packed_accessor32<scalar_t,5,RestrictPtrTraits>();
for (int n = 0; n < batch_size; ++n) {
for (int n = 0; n < batch_size; ++n){ correlation_backward_cuda_kernel_input1<scalar_t>
correlation_backward_cuda_kernel_input1<scalar_t><<<blocks, threads, 0, <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
at::cuda::getCurrentCUDAStream()>>>( grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH,
grad_output_acc, input2_acc, grad_input1_acc, patchW, padH, padW, dilationH, dilationW, dilation_patchH,
kH, kW, patchH, patchW, padH, padW, dilation_patchW, dH, dW, n);
dilationH, dilationW, }
dilation_patchH, dilation_patchW,
dH, dW, n); for (int n = 0; n < batch_size; ++n) {
} correlation_backward_cuda_kernel_input2<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
for (int n = 0; n < batch_size; ++n){ grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH,
correlation_backward_cuda_kernel_input2<scalar_t><<<blocks, threads, 0, patchW, padH, padW, dilationH, dilationW, dilation_patchH,
at::cuda::getCurrentCUDAStream()>>>( dilation_patchW, dH, dW, n);
grad_output_acc, input1_acc, grad_input2_acc, }
kH, kW, patchH, patchW, padH, padW, }));
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW, n);
}
}));
} }
...@@ -225,22 +225,16 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes, ...@@ -225,22 +225,16 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
const Tensor &argmax_idx, Tensor grad_input, const Tensor &argmax_idx, Tensor grad_input,
const int pool_size); const int pool_size);
void correlation_forward(Tensor input1, Tensor input2, Tensor output, void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH,
int kH, int kW, int patchH, int patchW, int kW, int patchH, int patchW, int padH, int padW,
int padH, int padW, int dilationH, int dilationW, int dilation_patchH,
int dilationH, int dilationW, int dilation_patchW, int dH, int dW);
int dilation_patchH, int dilation_patchW,
int dH, int dW); void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2, int kH,
void correlation_backward(Tensor grad_output, int kW, int patchH, int patchW, int padH, int padW,
Tensor input1, Tensor input2, int dilationH, int dilationW, int dilation_patchH,
Tensor grad_input1, Tensor grad_input2, int dilation_patchW, int dH, int dW);
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch import torch
from mmcv.ops import Correlation from mmcv.ops import Correlation
_input1 = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]] _input1 = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
_input2 = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]] _input2 = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]
_input2_2 = [[[[1., 2.], [3., 1.], [8., 5.]]]]
gt_out_shape = (1, 1, 1, 3, 3) gt_out_shape = (1, 1, 1, 3, 3)
_gt_out = [[[[[1., 4., 9.], [0., 1., 4.], [24., 25., 4.]]]]] _gt_out = [[[[[1., 4., 9.], [0., 1., 4.], [24., 25., 4.]]]]]
gt_input1_grad = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]] gt_input1_grad = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]
_ap_gt_out = [[[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]],
[[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]],
[[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]],
[[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]],
[[[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]],
[[5., 10., 15.], [15., 5., 10.], [40., 25., 10.]],
[[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]]]]
def assert_equal_tensor(tensor_a, tensor_b): def assert_equal_tensor(tensor_a, tensor_b):
...@@ -43,6 +35,8 @@ class TestCorrelation: ...@@ -43,6 +35,8 @@ class TestCorrelation:
assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu()) assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu())
assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu()) assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu())
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_correlation(self): def test_correlation(self):
self._test_correlation(torch.float) self._test_correlation(torch.float)
self._test_correlation(torch.double) self._test_correlation(torch.double)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment