OpenDAS / MMCV · Commit b92ea0b5 (unverified)

[Feature] Add Correlation CUDA op (#1361)

Authored Sep 23, 2021 by Miao Zheng; committed by GitHub on Sep 23, 2021.
Parent commit: f3dfc413
Showing 9 changed files with 736 additions and 1 deletion (+736, -1).
docs/understand_mmcv/ops.md                         +1    -0
docs_zh_CN/understand_mmcv/ops.md                   +1    -0
mmcv/ops/__init__.py                                +3    -1
mmcv/ops/correlation.py                             +173  -0
mmcv/ops/csrc/common/cuda/correlation_cuda.cuh      +269  -0
mmcv/ops/csrc/pytorch/correlation.cpp               +116  -0
mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu      +105  -0
mmcv/ops/csrc/pytorch/pybind.cpp                    +19   -0
tests/test_ops/test_corr.py                         +49   -0
docs/understand_mmcv/ops.md

@@ -22,3 +22,4 @@ We implement common CUDA ops used in detection, segmentation, etc.
 - SoftNMS
 - Synchronized BatchNorm
 - Weight standardization
+- Correlation
docs_zh_CN/understand_mmcv/ops.md

@@ -22,3 +22,4 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
 - SoftNMS
 - Synchronized BatchNorm
 - Weight standardization
+- Correlation
mmcv/ops/__init__.py

@@ -7,6 +7,7 @@ from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
 from .cc_attention import CrissCrossAttention
 from .contour_expand import contour_expand
 from .corner_pool import CornerPool
+from .correlation import Correlation
 from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
 from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
                               ModulatedDeformRoIPoolPack, deform_roi_pool)
@@ -53,5 +54,6 @@ __all__ = [
     'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
     'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
     'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
-    'MultiScaleDeformableAttention', 'BorderAlign', 'border_align'
+    'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
+    'Correlation'
 ]
mmcv/ops/correlation.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['correlation_forward', 'correlation_backward'])


class CorrelationFunction(Function):

    @staticmethod
    def forward(ctx,
                input1,
                input2,
                kernel_size=1,
                max_displacement=1,
                stride=1,
                padding=1,
                dilation=1,
                dilation_patch=1):

        ctx.save_for_backward(input1, input2)

        kH, kW = ctx.kernel_size = _pair(kernel_size)
        patch_size = max_displacement * 2 + 1
        ctx.patch_size = patch_size
        dH, dW = ctx.stride = _pair(stride)
        padH, padW = ctx.padding = _pair(padding)
        dilationH, dilationW = ctx.dilation = _pair(dilation)
        dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(
            dilation_patch)

        output_size = CorrelationFunction._output_size(ctx, input1)

        output = input1.new_zeros(output_size)

        ext_module.correlation_forward(input1, input2, output, kH, kW,
                                       patch_size, patch_size, padH, padW,
                                       dilationH, dilationW, dilation_patchH,
                                       dilation_patchW, dH, dW)

        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors

        kH, kW = ctx.kernel_size
        patch_size = ctx.patch_size
        padH, padW = ctx.padding
        dilationH, dilationW = ctx.dilation
        dilation_patchH, dilation_patchW = ctx.dilation_patch
        dH, dW = ctx.stride
        grad_input1 = torch.zeros_like(input1)
        grad_input2 = torch.zeros_like(input2)

        ext_module.correlation_backward(grad_output, input1, input2,
                                        grad_input1, grad_input2, kH, kW,
                                        patch_size, patch_size, padH, padW,
                                        dilationH, dilationW, dilation_patchH,
                                        dilation_patchW, dH, dW)
        return grad_input1, grad_input2, None, None, None, None, None, None

    @staticmethod
    def _output_size(ctx, input1):
        iH, iW = input1.size(2), input1.size(3)
        batch_size = input1.size(0)
        kH, kW = ctx.kernel_size
        patch_size = ctx.patch_size
        dH, dW = ctx.stride
        padH, padW = ctx.padding
        dilationH, dilationW = ctx.dilation
        dilatedKH = (kH - 1) * dilationH + 1
        dilatedKW = (kW - 1) * dilationW + 1

        oH = int((iH + 2 * padH - dilatedKH) / dH + 1)
        oW = int((iW + 2 * padW - dilatedKW) / dW + 1)

        output_size = (batch_size, patch_size, patch_size, oH, oW)
        return output_size


class Correlation(nn.Module):
    r"""Correlation operator.

    This correlation operator works for optical flow correlation computation.

    There are two batched tensors with shape :math:`(N, C, H, W)`,
    and the correlation output's shape is
    :math:`(N, \text{max\_displacement} \times 2 + 1,
    \text{max\_displacement} \times 2 + 1, H_{out}, W_{out})`

    where

    .. math::
        H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding} -
            \text{dilation} \times (\text{kernel\_size} - 1) - 1}
            {\text{stride}} + 1\right\rfloor

    .. math::
        W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding} -
            \text{dilation} \times (\text{kernel\_size} - 1) - 1}
            {\text{stride}} + 1\right\rfloor

    The correlation item :math:`(N_i, dy, dx)` is formed by taking the sliding
    window convolution between input1 and shifted input2,

    .. math::
        Corr(N_i, dx, dy) =
        \sum_{c=0}^{C-1}
        input1(N_i, c) \star
        \mathcal{S}(input2(N_i, c), dy, dx)

    where :math:`\star` is the valid 2d sliding window convolution operator,
    :math:`\mathcal{S}` means shifting the input features (out-of-bound
    positions are zero-filled), and :math:`dx, dy` are the shifting distances,
    :math:`dx, dy \in [-\text{max\_displacement} \times \text{dilation\_patch},
    \text{max\_displacement} \times \text{dilation\_patch}]`.

    Args:
        kernel_size (int): The size of the sliding window, i.e. the local
            neighborhood around the center points that is involved in the
            correlation computation. Defaults to 1.
        max_displacement (int): The radius for computing the correlation
            volume; the actual working space can be dilated by
            dilation_patch. Defaults to 1.
        stride (int): The stride of the sliding blocks in the input spatial
            dimensions. Defaults to 1.
        padding (int): Zero padding added to all four sides of input1.
            Defaults to 0.
        dilation (int): The spacing within the local neighborhood that is
            involved in the correlation. Defaults to 1.
        dilation_patch (int): The spacing between the positions at which the
            correlation is computed. Defaults to 1.
    """
    def __init__(self,
                 kernel_size: int = 1,
                 max_displacement: int = 1,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 dilation_patch: int = 1) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.dilation_patch = dilation_patch

    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
        return CorrelationFunction.apply(input1, input2, self.kernel_size,
                                         self.max_displacement, self.stride,
                                         self.padding, self.dilation,
                                         self.dilation_patch)

    def __repr__(self) -> str:
        s = self.__class__.__name__
        s += f'(kernel_size={self.kernel_size}, '
        s += f'max_displacement={self.max_displacement}, '
        s += f'stride={self.stride}, '
        s += f'padding={self.padding}, '
        s += f'dilation={self.dilation}, '
        s += f'dilation_patch={self.dilation_patch})'
        return s
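
For orientation, here is a minimal usage sketch of the new module (not part of the commit; the shapes follow the docstring above and a CUDA build of mmcv-full is assumed):

import torch
from mmcv.ops import Correlation

# Two batched feature maps of shape (N, C, H, W).
input1 = torch.randn(1, 8, 32, 32, device='cuda')
input2 = torch.randn(1, 8, 32, 32, device='cuda')

# max_displacement=3 gives a (2*3+1) x (2*3+1) correlation volume.
corr = Correlation(max_displacement=3)
out = corr(input1, input2)
print(out.shape)  # torch.Size([1, 7, 7, 32, 32]) with kernel_size=1, stride=1, padding=0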
mmcv/ops/csrc/common/cuda/correlation_cuda.cuh (new file, mode 100644)

// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu
// Original licence: Under MIT License

#ifndef CORRELATION_CUDA
#define CORRELATION_CUDA

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <vector>
#include <iostream>

using namespace torch;

#define TensorAcc4R PackedTensorAccessor32<scalar_t, 4, RestrictPtrTraits>
#define TensorAcc5R PackedTensorAccessor32<scalar_t, 5, RestrictPtrTraits>
#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)

#define THREADS_FORWARD 32
#define THREADS_BACKWARD 16

template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel(
    const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
    int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
    int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) {
  const int iH = rInput1.size(1);
  const int iW = rInput1.size(2);
  const int C = rInput1.size(3);

  const int n = blockIdx.x;
  const int h = blockIdx.y;
  const int w = blockIdx.z;
  const int thread = threadIdx.x;

  const int start_i = -padH + h * dH;
  const int start_j = -padW + w * dW;

  const int patchRadH = dilation_patchH * (patchH - 1) / 2;
  const int patchRadW = dilation_patchW * (patchW - 1) / 2;

  __shared__ scalar_t prod_sum[THREADS_FORWARD];

  for (int ph = 0; ph < patchH; ++ph) {
    int ph_dilated = ph * dilation_patchH - patchRadH;
    for (int pw = 0; pw < patchW; ++pw) {
      int pw_dilated = pw * dilation_patchW - patchRadW;
      prod_sum[thread] = 0;
      for (int i = 0; i < kH; ++i) {
        int i1 = start_i + i * dilationH;
        int i2 = i1 + ph_dilated;
        if WITHIN_BOUNDS (i1, i2, iH, iH) {
          for (int j = 0; j < kW; ++j) {
            int j1 = start_j + j * dilationW;
            int j2 = j1 + pw_dilated;
            if WITHIN_BOUNDS (j1, j2, iW, iW) {
              for (int c = thread; c < C; c += THREADS_FORWARD) {
                scalar_t v1 = rInput1[n][i1][j1][c];
                scalar_t v2 = rInput2[n][i2][j2][c];
                prod_sum[thread] += v1 * v2;
              }
            }
          }
        }
      }
      // accumulate
      __syncthreads();
      if (thread == 0) {
        scalar_t reduce_sum = 0;
        for (int index = 0; index < THREADS_FORWARD; ++index) {
          reduce_sum += prod_sum[index];
        }
        output[n][ph][pw][h][w] = reduce_sum;
      }
    }
  }
}

template <typename scalar_t>
__global__ void correlation_backward_cuda_kernel_input1(
    const TensorAcc5R grad_output, const TensorAcc4R input2,
    TensorAcc4R grad_input1, const int kH, const int kW, const int patchH,
    const int patchW, const int padH, const int padW, const int dilationH,
    const int dilationW, const int dilation_patchH, const int dilation_patchW,
    const int dH, const int dW, const int batch) {
  const int iH = input2.size(2);
  const int iW = input2.size(3);

  const int H = grad_output.size(3);
  const int W = grad_output.size(4);

  const int patchRadH = (patchH - 1) / 2;
  const int patchRadW = (patchW - 1) / 2;

  const int n = batch;
  const int c = blockIdx.x;
  const int h = blockIdx.y;
  const int w = blockIdx.z;
  const int ph_off = threadIdx.x;
  const int pw_off = threadIdx.y;

  const int h_2 = h + padH;
  const int w_2 = w + padW;
  const int min_h = h_2 - kH * dilationH;
  const int min_w = w_2 - kW * dilationW;

  __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
  prod_sum[ph_off][pw_off] = 0;

  for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
    int i1 = h + dilation_patchH * (ph - patchRadH);
    for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
      int j1 = w + dilation_patchW * (pw - patchRadW);
      if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
        scalar_t val = input2[n][c][i1][j1];
        for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
          int i2 = (h_3) / dH;
          if (i2 * dH != h_3) continue;
          for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
            int j2 = (w_3) / dW;
            if (j2 * dW != w_3) continue;
            if WITHIN_BOUNDS (i2, j2, H, W) {
              prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
            }
          }
        }
      }
    }
  }

  __syncthreads();

  if (ph_off == 0 && pw_off == 0) {
    scalar_t reduce_sum = 0;
    for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
      for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
        reduce_sum += prod_sum[ph][pw];
      }
    }
    grad_input1[n][c][h][w] = reduce_sum;
  }
}

template <typename scalar_t>
__global__ void correlation_backward_cuda_kernel_input2(
    const TensorAcc5R grad_output, const TensorAcc4R input1,
    TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH,
    int padW, int dilationH, int dilationW, int dilation_patchH,
    int dilation_patchW, int dH, int dW, int batch) {
  const int iH = input1.size(2);
  const int iW = input1.size(3);

  const int patchRadH = (patchH - 1) / 2;
  const int patchRadW = (patchW - 1) / 2;

  const int H = grad_output.size(3);
  const int W = grad_output.size(4);

  const int dilatedKH = kH * dilationH;
  const int dilatedKW = kW * dilationW;

  const int n = batch;
  const int c = blockIdx.x;
  const int h = blockIdx.y;
  const int w = blockIdx.z;
  const int ph_off = threadIdx.x;
  const int pw_off = threadIdx.y;

  __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
  prod_sum[ph_off][pw_off] = 0;

  for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
    int i1 = h - dilation_patchH * (ph - patchRadH);
    for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
      int j1 = w - dilation_patchW * (pw - patchRadW);
      if WITHIN_BOUNDS (i1, j1, iH, iW) {
        scalar_t val = input1[n][c][i1][j1];

        const int h_2 = i1 + padH;
        const int w_2 = j1 + padW;
        const int min_h = h_2 - dilatedKH;
        const int min_w = w_2 - dilatedKW;

        for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
          int i2 = (h_3) / dH;
          if (i2 * dH != h_3) continue;
          for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
            int j2 = (w_3) / dW;
            if (j2 * dW != w_3) continue;
            if WITHIN_BOUNDS (i2, j2, H, W) {
              prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
            }
          }
        }
      }
    }
  }

  __syncthreads();

  if (ph_off == 0 && pw_off == 0) {
    scalar_t reduce_sum = 0;
    for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
      for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
        reduce_sum += prod_sum[ph][pw];
      }
    }
    grad_input2[n][c][h][w] = reduce_sum;
  }
}
#endif
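
As a cross-check of what the forward kernel computes, the following pure-PyTorch reference is a slow equivalent for the common kernel_size=1, stride=1, padding=0, dilation=1 case. It is not part of the commit, and the helper name correlation_reference is ours:

import torch
import torch.nn.functional as F


def correlation_reference(input1, input2, max_displacement=1, dilation_patch=1):
    # Output shape: (N, 2d+1, 2d+1, H, W), matching the CUDA kernel's layout.
    n, c, h, w = input1.shape
    d = max_displacement
    patch_size = 2 * d + 1
    pad = d * dilation_patch
    # Zero-pad input2 so that displaced positions outside the image contribute
    # zero, mirroring the WITHIN_BOUNDS check in the kernel.
    padded2 = F.pad(input2, (pad, pad, pad, pad))
    out = input1.new_zeros(n, patch_size, patch_size, h, w)
    for ph in range(patch_size):
        dy = (ph - d) * dilation_patch
        for pw in range(patch_size):
            dx = (pw - d) * dilation_patch
            shifted = padded2[:, :, pad + dy:pad + dy + h, pad + dx:pad + dx + w]
            out[:, ph, pw] = (input1 * shifted).sum(dim=1)  # sum over channels
    return out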
mmcv/ops/csrc/pytorch/correlation.cpp (new file, mode 100644)

// Copyright (c) OpenMMLab. All rights reserved.
#include <iostream>

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
                                          Tensor output, int kH, int kW,
                                          int patchH, int patchW, int padH,
                                          int padW, int dilationH,
                                          int dilationW, int dilation_patchH,
                                          int dilation_patchW, int dH, int dW);

void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1,
                                           Tensor input2, Tensor grad_input1,
                                           Tensor grad_input2, int kH, int kW,
                                           int patchH, int patchW, int padH,
                                           int padW, int dilationH,
                                           int dilationW, int dilation_patchH,
                                           int dilation_patchW, int dH,
                                           int dW);

void correlation_cuda_forward(Tensor input1, Tensor input2, Tensor output,
                              int kH, int kW, int patchH, int patchW, int padH,
                              int padW, int dilationH, int dilationW,
                              int dilation_patchH, int dilation_patchW, int dH,
                              int dW) {
  CorrelationForwardCUDAKernelLauncher(input1, input2, output, kH, kW, patchH,
                                       patchW, padH, padW, dilationH,
                                       dilationW, dilation_patchH,
                                       dilation_patchW, dH, dW);
}

void correlation_cuda_backward(Tensor grad_output, Tensor input1,
                               Tensor input2, Tensor grad_input1,
                               Tensor grad_input2, int kH, int kW, int patchH,
                               int patchW, int padH, int padW, int dilationH,
                               int dilationW, int dilation_patchH,
                               int dilation_patchW, int dH, int dW) {
  CorrelationBackwardCUDAKernelLauncher(
      grad_output, input1, input2, grad_input1, grad_input2, kH, kW, patchH,
      patchW, padH, padW, dilationH, dilationW, dilation_patchH,
      dilation_patchW, dH, dW);
}
#endif

void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH,
                         int kW, int patchH, int patchW, int padH, int padW,
                         int dilationH, int dilationW, int dilation_patchH,
                         int dilation_patchW, int dH, int dW) {
  if (input1.device().is_cuda() and input2.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(input1);
    CHECK_CUDA_INPUT(input2);
    correlation_cuda_forward(input1, input2, output, kH, kW, patchH, patchW,
                             padH, padW, dilationH, dilationW,
                             dilation_patchH, dilation_patchW, dH, dW);
#else
    AT_ERROR("Correlation is not compiled with GPU support");
#endif
  } else {
    AT_ERROR("Correlation is not implemented on CPU");
  }
}

void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
                          Tensor grad_input1, Tensor grad_input2, int kH,
                          int kW, int patchH, int patchW, int padH, int padW,
                          int dilationH, int dilationW, int dilation_patchH,
                          int dilation_patchW, int dH, int dW) {
  if (input1.device().is_cuda() and input2.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(grad_output);
    CHECK_CUDA_INPUT(input1);
    CHECK_CUDA_INPUT(input2);
    correlation_cuda_backward(grad_output, input1, input2, grad_input1,
                              grad_input2, kH, kW, patchH, patchW, padH, padW,
                              dilationH, dilationW, dilation_patchH,
                              dilation_patchW, dH, dW);
#else
    AT_ERROR("Correlation is not compiled with GPU support");
#endif
  } else {
    AT_ERROR("Correlation is not implemented on CPU");
  }
}
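
Because the dispatcher above has no CPU branch, calling the op on CPU tensors fails immediately. A small illustration (our own example, assuming a standard mmcv-full CUDA build where AT_ERROR surfaces as a Python RuntimeError):

import torch
from mmcv.ops import Correlation

corr = Correlation(max_displacement=1)
try:
    # CPU tensors take the else-branch of correlation_forward above.
    corr(torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
except RuntimeError as e:
    print(e)  # message includes "Correlation is not implemented on CPU"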
mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu (new file, mode 100644)

// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu
// Original licence: Under MIT License

#include "correlation_cuda.cuh"
#include "pytorch_cuda_helper.hpp"

void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
                                          Tensor output, int kH, int kW,
                                          int patchH, int patchW, int padH,
                                          int padW, int dilationH,
                                          int dilationW, int dilation_patchH,
                                          int dilation_patchW, int dH,
                                          int dW) {
  const int batch_size = input1.size(0);
  const int iH = input1.size(2);
  const int iW = input1.size(3);
  const int dilatedKH = (kH - 1) * dilationH + 1;
  const int dilatedKW = (kW - 1) * dilationW + 1;

  const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
  const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;

  auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
  auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();

  const int threads = THREADS_FORWARD;
  const dim3 blocks(batch_size, oH, oW);

  at::cuda::CUDAGuard device_guard(input1.device());

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input1.scalar_type(), "correlation_forward_cuda", ([&] {
        TensorAcc4R trInput1_acc =
            trInput1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
        TensorAcc4R trInput2_acc =
            trInput2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
        TensorAcc5R output_acc =
            output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();
        correlation_forward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW,
                padH, padW, dilationH, dilationW, dilation_patchH,
                dilation_patchW, dH, dW);
      }));
}

void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1,
                                           Tensor input2, Tensor grad_input1,
                                           Tensor grad_input2, int kH, int kW,
                                           int patchH, int patchW, int padH,
                                           int padW, int dilationH,
                                           int dilationW, int dilation_patchH,
                                           int dilation_patchW, int dH,
                                           int dW) {
  const int batch_size = input1.size(0);
  const int iH = input1.size(2);
  const int iW = input1.size(3);
  const int C = input1.size(1);

  const dim3 blocks(C, iH, iW);
  const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);

  at::cuda::CUDAGuard device_guard(input1.device());

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input1.scalar_type(), "correlation_backward_cuda", ([&] {
        TensorAcc4R input1_acc =
            input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
        TensorAcc4R input2_acc =
            input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
        TensorAcc4R grad_input1_acc =
            grad_input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
        TensorAcc4R grad_input2_acc =
            grad_input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
        TensorAcc5R grad_output_acc =
            grad_output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();

        for (int n = 0; n < batch_size; ++n) {
          correlation_backward_cuda_kernel_input1<scalar_t>
              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                  grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH,
                  patchW, padH, padW, dilationH, dilationW, dilation_patchH,
                  dilation_patchW, dH, dW, n);
        }

        for (int n = 0; n < batch_size; ++n) {
          correlation_backward_cuda_kernel_input2<scalar_t>
              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                  grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH,
                  patchW, padH, padW, dilationH, dilationW, dilation_patchH,
                  dilation_patchW, dH, dW, n);
        }
      }));
}
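
To make the launch configuration concrete, here is a small illustrative sketch of the grid and block shapes the launchers above choose; the input dimensions are hypothetical, not from the commit:

# Example geometry for input of shape (N, C, H, W) = (2, 64, 48, 64) with
# kernel_size=1, stride=1, padding=0, dilation=1.
N, C, H, W = 2, 64, 48, 64
THREADS_FORWARD, THREADS_BACKWARD = 32, 16

oH = (H + 2 * 0 - 1) // 1 + 1  # 48
oW = (W + 2 * 0 - 1) // 1 + 1  # 64

forward_blocks = (N, oH, oW)        # one block per output location
forward_threads = THREADS_FORWARD   # 32 threads split the channel sum

backward_blocks = (C, H, W)                               # one block per input element
backward_threads = (THREADS_BACKWARD, THREADS_BACKWARD)   # 16x16 tile over the patch window
# The backward launcher additionally loops over the batch on the host and
# passes the batch index n to each kernel launch.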
mmcv/ops/csrc/pytorch/pybind.cpp

@@ -225,6 +225,23 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
                            const Tensor &argmax_idx, Tensor grad_input,
                            const int pool_size);
 
+void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH,
+                         int kW, int patchH, int patchW, int padH, int padW,
+                         int dilationH, int dilationW, int dilation_patchH,
+                         int dilation_patchW, int dH, int dW);
+
+void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
+                          Tensor grad_input1, Tensor grad_input2, int kH,
+                          int kW, int patchH, int patchW, int padH, int padW,
+                          int dilationH, int dilationW, int dilation_patchH,
+                          int dilation_patchW, int dH, int dW);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
         py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@@ -452,4 +469,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "backward function of border_align", py::arg("grad_output"),
         py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"),
         py::arg("pool_size"));
+  m.def("correlation_forward", &correlation_forward, "Correlation forward");
+  m.def("correlation_backward", &correlation_backward, "Correlation backward");
 }
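
With these bindings in place, the Python wrapper reaches the op through the loaded extension. A minimal sketch of a direct call follows (illustrative only; in practice CorrelationFunction.forward in correlation.py does this for you, and a CUDA build is assumed):

import torch
from mmcv.utils import ext_loader

# Look up the two symbols registered above in the compiled extension.
ext_module = ext_loader.load_ext(
    '_ext', ['correlation_forward', 'correlation_backward'])

input1 = torch.randn(1, 4, 8, 8, device='cuda')
input2 = torch.randn(1, 4, 8, 8, device='cuda')
output = input1.new_zeros((1, 3, 3, 8, 8))  # (N, patchH, patchW, oH, oW)

# Argument order follows the C++ signature: input1, input2, output, kH, kW,
# patchH, patchW, padH, padW, dilationH, dilationW, dilation_patchH,
# dilation_patchW, dH, dW.
ext_module.correlation_forward(input1, input2, output, 1, 1, 3, 3, 0, 0, 1, 1,
                               1, 1, 1, 1)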
tests/test_ops/test_corr.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.ops import Correlation

_input1 = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
_input2 = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]
_input2_2 = [[[[1., 2.], [3., 1.], [8., 5.]]]]

gt_out_shape = (1, 1, 1, 3, 3)
_gt_out = [[[[[1., 4., 9.], [0., 1., 4.], [24., 25., 4.]]]]]
gt_input1_grad = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]

_ap_gt_out = [[[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]],
                [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]],
                [[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]]],
               [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                [[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]],
                [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]],
               [[[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]],
                [[5., 10., 15.], [15., 5., 10.], [40., 25., 10.]],
                [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]]]]


def assert_equal_tensor(tensor_a, tensor_b):
    assert tensor_a.eq(tensor_b).all()


class TestCorrelation:

    def _test_correlation(self, dtype=torch.float):

        layer = Correlation(max_displacement=0)

        input1 = torch.tensor(_input1, dtype=dtype).cuda()
        input2 = torch.tensor(_input2, dtype=dtype).cuda()
        input1.requires_grad = True
        input2.requires_grad = True
        out = layer(input1, input2)
        out.backward(torch.ones_like(out))

        gt_out = torch.tensor(_gt_out, dtype=dtype)
        assert_equal_tensor(out.cpu(), gt_out)
        assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu())
        assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu())

    def test_correlation(self):
        self._test_correlation(torch.float)
        self._test_correlation(torch.double)
        self._test_correlation(torch.half)
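
A quick sanity check of the expected values: with max_displacement=0 and a single channel, the correlation volume collapses to the element-wise product of the two inputs, and each input's gradient is simply the other input, which is exactly what the assertions verify. A small sketch of that arithmetic (not part of the commit):

import torch

a = torch.tensor([[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]])
b = torch.tensor([[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]])

# Element-wise product reproduces _gt_out above.
print(a * b)
# tensor([[ 1.,  4.,  9.],
#         [ 0.,  1.,  4.],
#         [24., 25.,  4.]])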