Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
0bcbeadb
Unverified
Commit
0bcbeadb
authored
Dec 24, 2021
by
zhouyue
Committed by
GitHub
Dec 24, 2021
Browse files
[Feature] Add RiRoIAlignRotated CUDA op for rotated detection. (#1599)
parent
2475dc34
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
634 additions
and
5 deletions
+634
-5
docs/en/understand_mmcv/ops.md
docs/en/understand_mmcv/ops.md
+1
-0
docs/zh_cn/understand_mmcv/ops.md
docs/zh_cn/understand_mmcv/ops.md
+1
-0
mmcv/ops/__init__.py
mmcv/ops/__init__.py
+6
-5
mmcv/ops/csrc/common/cuda/riroi_align_rotated_cuda_kernel.cuh
.../ops/csrc/common/cuda/riroi_align_rotated_cuda_kernel.cuh
+242
-0
mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
+75
-0
mmcv/ops/csrc/pytorch/cuda/riroi_align_rotated_cuda.cu
mmcv/ops/csrc/pytorch/cuda/riroi_align_rotated_cuda.cu
+53
-0
mmcv/ops/csrc/pytorch/pybind.cpp
mmcv/ops/csrc/pytorch/pybind.cpp
+22
-0
mmcv/ops/csrc/pytorch/riroi_align_rotated.cpp
mmcv/ops/csrc/pytorch/riroi_align_rotated.cpp
+42
-0
mmcv/ops/riroi_align_rotated.py
mmcv/ops/riroi_align_rotated.py
+119
-0
tests/test_ops/test_riroi_align_rotated.py
tests/test_ops/test_riroi_align_rotated.py
+73
-0
No files found.
docs/en/understand_mmcv/ops.md
View file @
0bcbeadb
...
@@ -21,6 +21,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
...
@@ -21,6 +21,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
-
MaskedConv
-
MaskedConv
-
NMS
-
NMS
-
PSAMask
-
PSAMask
-
RiRoIAlignRotated
-
RotatedFeatureAlign
-
RotatedFeatureAlign
-
RoIPointPool3d
-
RoIPointPool3d
-
RoIPool
-
RoIPool
...
...
docs/zh_cn/understand_mmcv/ops.md
View file @
0bcbeadb
...
@@ -23,6 +23,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
...
@@ -23,6 +23,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
-
RotatedFeatureAlign
-
RotatedFeatureAlign
-
RoIPointPool3d
-
RoIPointPool3d
-
RoIPool
-
RoIPool
-
RiRoIAlignRotated
-
RoIAlign
-
RoIAlign
-
RoIAwarePool3d
-
RoIAwarePool3d
-
SimpleRoIAlign
-
SimpleRoIAlign
...
...
mmcv/ops/__init__.py
View file @
0bcbeadb
...
@@ -40,6 +40,7 @@ from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
...
@@ -40,6 +40,7 @@ from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part
)
points_in_boxes_part
)
from
.points_sampler
import
PointsSampler
from
.points_sampler
import
PointsSampler
from
.psa_mask
import
PSAMask
from
.psa_mask
import
PSAMask
from
.riroi_align_rotated
import
RiRoIAlignRotated
,
riroi_align_rotated
from
.roi_align
import
RoIAlign
,
roi_align
from
.roi_align
import
RoIAlign
,
roi_align
from
.roi_align_rotated
import
RoIAlignRotated
,
roi_align_rotated
from
.roi_align_rotated
import
RoIAlignRotated
,
roi_align_rotated
from
.roi_pool
import
RoIPool
,
roi_pool
from
.roi_pool
import
RoIPool
,
roi_pool
...
@@ -71,11 +72,11 @@ __all__ = [
...
@@ -71,11 +72,11 @@ __all__ = [
'SAConv2d'
,
'TINShift'
,
'tin_shift'
,
'assign_score_withk'
,
'SAConv2d'
,
'TINShift'
,
'tin_shift'
,
'assign_score_withk'
,
'box_iou_rotated'
,
'RoIPointPool3d'
,
'nms_rotated'
,
'knn'
,
'ball_query'
,
'box_iou_rotated'
,
'RoIPointPool3d'
,
'nms_rotated'
,
'knn'
,
'ball_query'
,
'upfirdn2d'
,
'FusedBiasLeakyReLU'
,
'fused_bias_leakyrelu'
,
'upfirdn2d'
,
'FusedBiasLeakyReLU'
,
'fused_bias_leakyrelu'
,
'rotated_feature_align'
,
'RoIAlignRotated'
,
'roi_align_rotated'
,
'rotated_feature_align'
,
'
Ri
RoIAlignRotated'
,
'
ri
roi_align_rotated'
,
'
pixel_group'
,
'QueryAndGroup'
,
'GroupAll'
,
'grouping_operation
'
,
'
RoIAlignRotated'
,
'roi_align_rotated'
,
'pixel_group'
,
'QueryAndGroup
'
,
'contour_expand'
,
'three_nn'
,
'three_interpolate'
,
'GroupAll'
,
'grouping_operation'
,
'contour_expand'
,
'three_nn'
,
'MultiScaleDeformableAttention'
,
'BorderAlign'
,
'border_align'
,
'three_interpolate'
,
'MultiScaleDeformableAttention'
,
'BorderAlign'
,
'gather_points'
,
'furthest_point_sample'
,
'border_align'
,
'gather_points'
,
'furthest_point_sample'
,
'furthest_point_sample_with_dist'
,
'PointsSampler'
,
'Correlation'
,
'furthest_point_sample_with_dist'
,
'PointsSampler'
,
'Correlation'
,
'boxes_iou_bev'
,
'nms_bev'
,
'nms_normal_bev'
,
'Voxelization'
,
'boxes_iou_bev'
,
'nms_bev'
,
'nms_normal_bev'
,
'Voxelization'
,
'voxelization'
,
'dynamic_scatter'
,
'DynamicScatter'
,
'RoIAwarePool3d'
,
'voxelization'
,
'dynamic_scatter'
,
'DynamicScatter'
,
'RoIAwarePool3d'
,
...
...
mmcv/ops/csrc/common/cuda/riroi_align_rotated_cuda_kernel.cuh
0 → 100644
View file @
0bcbeadb
// Modified from
// https://github.com/csuhan/ReDet/blob/master/mmdet/ops/riroi_align/src/riroi_align_kernel.cu
#ifndef RIROI_ALIGN_ROTATED_CUDA_KERNEL_CUH
#define RIROI_ALIGN_ROTATED_CUDA_KERNEL_CUH
#include <float.h>
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp"
#endif // MMCV_USE_PARROTS
/*** Forward ***/
template
<
typename
scalar_t
>
__global__
void
riroi_align_rotated_forward_cuda_kernel
(
const
int
nthreads
,
const
scalar_t
*
bottom_data
,
const
scalar_t
*
bottom_rois
,
const
scalar_t
spatial_scale
,
const
int
num_samples
,
const
bool
clockwise
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
pooled_height
,
const
int
pooled_width
,
const
int
num_orientations
,
scalar_t
*
top_data
)
{
CUDA_1D_KERNEL_LOOP
(
index
,
nthreads
)
{
// (n, c, ph, pw) is an element in the pooled output
int
pw
=
index
%
pooled_width
;
int
ph
=
(
index
/
pooled_width
)
%
pooled_height
;
int
o
=
(
index
/
pooled_width
/
pooled_height
)
%
num_orientations
;
int
c
=
(
index
/
pooled_width
/
pooled_height
/
num_orientations
)
%
channels
;
int
n
=
index
/
pooled_width
/
pooled_height
/
num_orientations
/
channels
;
const
scalar_t
*
offset_bottom_rois
=
bottom_rois
+
n
*
6
;
int
roi_batch_ind
=
offset_bottom_rois
[
0
];
// Do not using rounding; this implementation detail is critical
scalar_t
roi_center_w
=
offset_bottom_rois
[
1
]
*
spatial_scale
;
scalar_t
roi_center_h
=
offset_bottom_rois
[
2
]
*
spatial_scale
;
scalar_t
roi_width
=
offset_bottom_rois
[
3
]
*
spatial_scale
;
scalar_t
roi_height
=
offset_bottom_rois
[
4
]
*
spatial_scale
;
// scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
scalar_t
theta
=
offset_bottom_rois
[
5
];
// Force malformed ROIs to be 1x1
roi_width
=
max
(
roi_width
,
(
scalar_t
)
1.
);
roi_height
=
max
(
roi_height
,
(
scalar_t
)
1.
);
scalar_t
bin_size_h
=
static_cast
<
scalar_t
>
(
roi_height
)
/
static_cast
<
scalar_t
>
(
pooled_height
);
scalar_t
bin_size_w
=
static_cast
<
scalar_t
>
(
roi_width
)
/
static_cast
<
scalar_t
>
(
pooled_width
);
// find aligned index
scalar_t
ind_float
=
theta
*
num_orientations
/
(
2
*
M_PI
);
int
ind
=
floor
(
ind_float
);
scalar_t
l_var
=
ind_float
-
(
scalar_t
)
ind
;
scalar_t
r_var
=
1.0
-
l_var
;
// correct start channel
ind
=
(
ind
+
num_orientations
)
%
num_orientations
;
// rotated channel
int
ind_rot
=
(
o
-
ind
+
num_orientations
)
%
num_orientations
;
int
ind_rot_plus
=
(
ind_rot
+
1
+
num_orientations
)
%
num_orientations
;
const
scalar_t
*
offset_bottom_data
=
bottom_data
+
(
roi_batch_ind
*
channels
*
num_orientations
+
c
*
num_orientations
+
ind_rot
)
*
height
*
width
;
const
scalar_t
*
offset_bottom_data_plus
=
bottom_data
+
(
roi_batch_ind
*
channels
*
num_orientations
+
c
*
num_orientations
+
ind_rot_plus
)
*
height
*
width
;
// We use roi_bin_grid to sample the grid and mimic integral
int
roi_bin_grid_h
=
(
num_samples
>
0
)
?
num_samples
:
ceilf
(
roi_height
/
pooled_height
);
// e.g., = 2
int
roi_bin_grid_w
=
(
num_samples
>
0
)
?
num_samples
:
ceilf
(
roi_width
/
pooled_width
);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after.
if
(
clockwise
)
{
theta
=
-
theta
;
// If clockwise, the angle needs to be reversed.
}
scalar_t
roi_start_h
=
-
roi_height
/
2.0
;
scalar_t
roi_start_w
=
-
roi_width
/
2.0
;
scalar_t
cosscalar_theta
=
cos
(
theta
);
scalar_t
sinscalar_theta
=
sin
(
theta
);
// We do average (integral) pooling inside a bin
const
scalar_t
count
=
max
(
roi_bin_grid_h
*
roi_bin_grid_w
,
1
);
// e.g. = 4
scalar_t
output_val
=
0.
;
for
(
int
iy
=
0
;
iy
<
roi_bin_grid_h
;
iy
++
)
{
// e.g., iy = 0, 1
const
scalar_t
yy
=
roi_start_h
+
ph
*
bin_size_h
+
static_cast
<
scalar_t
>
(
iy
+
.5
f
)
*
bin_size_h
/
static_cast
<
scalar_t
>
(
roi_bin_grid_h
);
// e.g., 0.5, 1.5
for
(
int
ix
=
0
;
ix
<
roi_bin_grid_w
;
ix
++
)
{
const
scalar_t
xx
=
roi_start_w
+
pw
*
bin_size_w
+
static_cast
<
scalar_t
>
(
ix
+
.5
f
)
*
bin_size_w
/
static_cast
<
scalar_t
>
(
roi_bin_grid_w
);
// Rotate by theta (counterclockwise) around the center and translate
scalar_t
y
=
yy
*
cosscalar_theta
-
xx
*
sinscalar_theta
+
roi_center_h
;
scalar_t
x
=
yy
*
sinscalar_theta
+
xx
*
cosscalar_theta
+
roi_center_w
;
scalar_t
val
=
bilinear_interpolate
<
scalar_t
>
(
offset_bottom_data
,
height
,
width
,
y
,
x
,
index
);
scalar_t
val_plus
=
bilinear_interpolate
<
scalar_t
>
(
offset_bottom_data_plus
,
height
,
width
,
y
,
x
,
index
);
output_val
+=
r_var
*
val
+
l_var
*
val_plus
;
}
}
output_val
/=
count
;
top_data
[
index
]
=
output_val
;
}
}
/*** Backward ***/
template
<
typename
scalar_t
>
__global__
void
riroi_align_rotated_backward_cuda_kernel
(
const
int
nthreads
,
const
scalar_t
*
top_diff
,
const
scalar_t
*
bottom_rois
,
const
scalar_t
spatial_scale
,
const
int
num_samples
,
const
bool
clockwise
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
pooled_height
,
const
int
pooled_width
,
const
int
num_orientations
,
scalar_t
*
bottom_diff
)
{
CUDA_1D_KERNEL_LOOP
(
index
,
nthreads
)
{
// (n, c, ph, pw) is an element in the pooled output
int
pw
=
index
%
pooled_width
;
int
ph
=
(
index
/
pooled_width
)
%
pooled_height
;
int
o
=
(
index
/
pooled_width
/
pooled_height
)
%
num_orientations
;
int
c
=
(
index
/
pooled_width
/
pooled_height
/
num_orientations
)
%
channels
;
int
n
=
index
/
pooled_width
/
pooled_height
/
num_orientations
/
channels
;
const
scalar_t
*
offset_bottom_rois
=
bottom_rois
+
n
*
6
;
int
roi_batch_ind
=
offset_bottom_rois
[
0
];
// Do not round
scalar_t
roi_center_w
=
offset_bottom_rois
[
1
]
*
spatial_scale
;
scalar_t
roi_center_h
=
offset_bottom_rois
[
2
]
*
spatial_scale
;
scalar_t
roi_width
=
offset_bottom_rois
[
3
]
*
spatial_scale
;
scalar_t
roi_height
=
offset_bottom_rois
[
4
]
*
spatial_scale
;
// scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
scalar_t
theta
=
offset_bottom_rois
[
5
];
// Force malformed ROIs to be 1x1
roi_width
=
max
(
roi_width
,
(
scalar_t
)
1.
);
roi_height
=
max
(
roi_height
,
(
scalar_t
)
1.
);
scalar_t
bin_size_h
=
static_cast
<
scalar_t
>
(
roi_height
)
/
static_cast
<
scalar_t
>
(
pooled_height
);
scalar_t
bin_size_w
=
static_cast
<
scalar_t
>
(
roi_width
)
/
static_cast
<
scalar_t
>
(
pooled_width
);
// find aligned index
scalar_t
ind_float
=
theta
*
num_orientations
/
(
2
*
M_PI
);
int
ind
=
floor
(
ind_float
);
scalar_t
l_var
=
ind_float
-
(
scalar_t
)
ind
;
scalar_t
r_var
=
1.0
-
l_var
;
// correct start channel
ind
=
(
ind
+
num_orientations
)
%
num_orientations
;
// rotated channel
int
ind_rot
=
(
o
-
ind
+
num_orientations
)
%
num_orientations
;
int
ind_rot_plus
=
(
ind_rot
+
1
+
num_orientations
)
%
num_orientations
;
scalar_t
*
offset_bottom_diff
=
bottom_diff
+
(
roi_batch_ind
*
channels
*
num_orientations
+
c
*
num_orientations
+
ind_rot
)
*
height
*
width
;
scalar_t
*
offset_bottom_diff_plus
=
bottom_diff
+
(
roi_batch_ind
*
channels
*
num_orientations
+
c
*
num_orientations
+
ind_rot_plus
)
*
height
*
width
;
int
top_offset
=
(
n
*
channels
*
num_orientations
+
c
*
num_orientations
+
o
)
*
pooled_height
*
pooled_width
;
const
scalar_t
*
offset_top_diff
=
top_diff
+
top_offset
;
const
scalar_t
top_diff_this_bin
=
offset_top_diff
[
ph
*
pooled_width
+
pw
];
// We use roi_bin_grid to sample the grid and mimic integral
int
roi_bin_grid_h
=
(
num_samples
>
0
)
?
num_samples
:
ceilf
(
roi_height
/
pooled_height
);
// e.g., = 2
int
roi_bin_grid_w
=
(
num_samples
>
0
)
?
num_samples
:
ceilf
(
roi_width
/
pooled_width
);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after.
if
(
clockwise
)
{
theta
=
-
theta
;
// If clockwise, the angle needs to be reversed.
}
scalar_t
roi_start_h
=
-
roi_height
/
2.0
;
scalar_t
roi_start_w
=
-
roi_width
/
2.0
;
scalar_t
cosTheta
=
cos
(
theta
);
scalar_t
sinTheta
=
sin
(
theta
);
// We do average (integral) pooling inside a bin
const
scalar_t
count
=
roi_bin_grid_h
*
roi_bin_grid_w
;
// e.g. = 4
for
(
int
iy
=
0
;
iy
<
roi_bin_grid_h
;
iy
++
)
{
// e.g., iy = 0, 1
const
scalar_t
yy
=
roi_start_h
+
ph
*
bin_size_h
+
static_cast
<
scalar_t
>
(
iy
+
.5
f
)
*
bin_size_h
/
static_cast
<
scalar_t
>
(
roi_bin_grid_h
);
// e.g., 0.5, 1.5
for
(
int
ix
=
0
;
ix
<
roi_bin_grid_w
;
ix
++
)
{
const
scalar_t
xx
=
roi_start_w
+
pw
*
bin_size_w
+
static_cast
<
scalar_t
>
(
ix
+
.5
f
)
*
bin_size_w
/
static_cast
<
scalar_t
>
(
roi_bin_grid_w
);
// Rotate by theta around the center and translate
scalar_t
y
=
yy
*
cosTheta
-
xx
*
sinTheta
+
roi_center_h
;
scalar_t
x
=
yy
*
sinTheta
+
xx
*
cosTheta
+
roi_center_w
;
scalar_t
w1
,
w2
,
w3
,
w4
;
int
x_low
,
x_high
,
y_low
,
y_high
;
bilinear_interpolate_gradient
<
scalar_t
>
(
height
,
width
,
y
,
x
,
w1
,
w2
,
w3
,
w4
,
x_low
,
x_high
,
y_low
,
y_high
,
index
);
scalar_t
g1
=
top_diff_this_bin
*
w1
/
count
;
scalar_t
g2
=
top_diff_this_bin
*
w2
/
count
;
scalar_t
g3
=
top_diff_this_bin
*
w3
/
count
;
scalar_t
g4
=
top_diff_this_bin
*
w4
/
count
;
if
(
x_low
>=
0
&&
x_high
>=
0
&&
y_low
>=
0
&&
y_high
>=
0
)
{
atomicAdd
(
offset_bottom_diff
+
y_low
*
width
+
x_low
,
g1
*
r_var
);
atomicAdd
(
offset_bottom_diff
+
y_low
*
width
+
x_high
,
g2
*
r_var
);
atomicAdd
(
offset_bottom_diff
+
y_high
*
width
+
x_low
,
g3
*
r_var
);
atomicAdd
(
offset_bottom_diff
+
y_high
*
width
+
x_high
,
g4
*
r_var
);
atomicAdd
(
offset_bottom_diff_plus
+
y_low
*
width
+
x_low
,
g1
*
l_var
);
atomicAdd
(
offset_bottom_diff_plus
+
y_low
*
width
+
x_high
,
g2
*
l_var
);
atomicAdd
(
offset_bottom_diff_plus
+
y_high
*
width
+
x_low
,
g3
*
l_var
);
atomicAdd
(
offset_bottom_diff_plus
+
y_high
*
width
+
x_high
,
g4
*
l_var
);
}
// if
}
// ix
}
// iy
}
// CUDA_1D_KERNEL_LOOP
}
// RiRoIAlignBackward
#endif // RIROI_ALIGN_ROTATED_CUDA_KERNEL_CUH
mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
View file @
0bcbeadb
...
@@ -992,6 +992,81 @@ REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA,
...
@@ -992,6 +992,81 @@ REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA,
REGISTER_DEVICE_IMPL
(
roi_align_rotated_backward_impl
,
CUDA
,
REGISTER_DEVICE_IMPL
(
roi_align_rotated_backward_impl
,
CUDA
,
roi_align_rotated_backward_cuda
);
roi_align_rotated_backward_cuda
);
void
RiROIAlignRotatedForwardCUDAKernelLauncher
(
const
at
::
Tensor
features
,
const
at
::
Tensor
rois
,
const
float
spatial_scale
,
const
int
num_samples
,
const
bool
clockwise
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
num_rois
,
const
int
pooled_height
,
const
int
pooled_width
,
const
int
num_orientations
,
at
::
Tensor
output
);
void
RiROIAlignRotatedBackwardCUDAKernelLauncher
(
const
at
::
Tensor
top_grad
,
const
at
::
Tensor
rois
,
const
float
spatial_scale
,
const
int
num_samples
,
const
bool
clockwise
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
num_rois
,
const
int
pooled_height
,
const
int
pooled_width
,
const
int
num_orientations
,
at
::
Tensor
bottom_grad
);
void
riroi_align_rotated_forward_cuda
(
Tensor
features
,
Tensor
rois
,
Tensor
output
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
num_samples
,
int
num_orientations
,
bool
clockwise
)
{
// Number of ROIs
int
num_rois
=
rois
.
size
(
0
);
int
size_rois
=
rois
.
size
(
1
);
if
(
size_rois
!=
6
)
{
AT_ERROR
(
"wrong roi size"
);
}
CHECK_CONTIGUOUS
(
features
);
CHECK_CONTIGUOUS
(
rois
);
int
num_channels
=
features
.
size
(
1
)
/
num_orientations
;
int
data_height
=
features
.
size
(
2
);
int
data_width
=
features
.
size
(
3
);
RiROIAlignRotatedForwardCUDAKernelLauncher
(
features
,
rois
,
spatial_scale
,
num_samples
,
clockwise
,
num_channels
,
data_height
,
data_width
,
num_rois
,
pooled_height
,
pooled_width
,
num_orientations
,
output
);
}
void
riroi_align_rotated_backward_cuda
(
Tensor
top_grad
,
Tensor
rois
,
Tensor
bottom_grad
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
num_samples
,
int
num_orientations
,
bool
clockwise
)
{
// Number of ROIs
int
num_rois
=
rois
.
size
(
0
);
int
size_rois
=
rois
.
size
(
1
);
if
(
size_rois
!=
6
)
{
AT_ERROR
(
"wrong roi size"
);
}
CHECK_CONTIGUOUS
(
top_grad
);
CHECK_CONTIGUOUS
(
rois
);
int
num_channels
=
bottom_grad
.
size
(
1
)
/
num_orientations
;
int
data_height
=
bottom_grad
.
size
(
2
);
int
data_width
=
bottom_grad
.
size
(
3
);
RiROIAlignRotatedBackwardCUDAKernelLauncher
(
top_grad
,
rois
,
spatial_scale
,
num_samples
,
clockwise
,
num_channels
,
data_height
,
data_width
,
num_rois
,
pooled_height
,
pooled_width
,
num_orientations
,
bottom_grad
);
}
void
riroi_align_rotated_forward_impl
(
Tensor
features
,
Tensor
rois
,
Tensor
output
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
num_samples
,
int
num_orientations
,
bool
clockwise
);
void
riroi_align_rotated_backward_impl
(
Tensor
top_grad
,
Tensor
rois
,
Tensor
bottom_grad
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
num_samples
,
int
num_orientations
,
bool
clockwise
);
REGISTER_DEVICE_IMPL
(
riroi_align_rotated_forward_impl
,
CUDA
,
riroi_align_rotated_forward_cuda
);
REGISTER_DEVICE_IMPL
(
riroi_align_rotated_backward_impl
,
CUDA
,
riroi_align_rotated_backward_cuda
);
void
RoiawarePool3dForwardCUDAKernelLauncher
(
void
RoiawarePool3dForwardCUDAKernelLauncher
(
int
boxes_num
,
int
pts_num
,
int
channels
,
int
max_pts_each_voxel
,
int
out_x
,
int
boxes_num
,
int
pts_num
,
int
channels
,
int
max_pts_each_voxel
,
int
out_x
,
int
out_y
,
int
out_z
,
const
Tensor
rois
,
const
Tensor
pts
,
int
out_y
,
int
out_z
,
const
Tensor
rois
,
const
Tensor
pts
,
...
...
mmcv/ops/csrc/pytorch/cuda/riroi_align_rotated_cuda.cu
0 → 100644
View file @
0bcbeadb
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cuda_helper.hpp"
#include "riroi_align_rotated_cuda_kernel.cuh"
void
RiROIAlignRotatedForwardCUDAKernelLauncher
(
const
at
::
Tensor
features
,
const
at
::
Tensor
rois
,
const
float
spatial_scale
,
const
int
num_samples
,
const
bool
clockwise
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
num_rois
,
const
int
pooled_height
,
const
int
pooled_width
,
const
int
num_orientations
,
at
::
Tensor
output
)
{
const
int
output_size
=
num_rois
*
pooled_height
*
pooled_width
*
channels
*
num_orientations
;
at
::
cuda
::
CUDAGuard
device_guard
(
features
.
device
());
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
features
.
scalar_type
(),
"riroi_align_rotated_forward_cuda_kernel"
,
([
&
]
{
const
scalar_t
*
bottom_data
=
features
.
data_ptr
<
scalar_t
>
();
const
scalar_t
*
rois_data
=
rois
.
data_ptr
<
scalar_t
>
();
scalar_t
*
top_data
=
output
.
data_ptr
<
scalar_t
>
();
riroi_align_rotated_forward_cuda_kernel
<
scalar_t
>
<<<
GET_BLOCKS
(
output_size
),
THREADS_PER_BLOCK
,
0
,
stream
>>>
(
output_size
,
bottom_data
,
rois_data
,
scalar_t
(
spatial_scale
),
num_samples
,
clockwise
,
channels
,
height
,
width
,
pooled_height
,
pooled_width
,
num_orientations
,
top_data
);
}));
AT_CUDA_CHECK
(
cudaGetLastError
());
}
void
RiROIAlignRotatedBackwardCUDAKernelLauncher
(
const
at
::
Tensor
top_grad
,
const
at
::
Tensor
rois
,
const
float
spatial_scale
,
const
int
num_samples
,
const
bool
clockwise
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
num_rois
,
const
int
pooled_height
,
const
int
pooled_width
,
const
int
num_orientations
,
at
::
Tensor
bottom_grad
)
{
const
int
output_size
=
num_rois
*
pooled_height
*
pooled_width
*
channels
*
num_orientations
;
at
::
cuda
::
CUDAGuard
device_guard
(
top_grad
.
device
());
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
top_grad
.
scalar_type
(),
"riroi_align_rotated_backward_cuda_kernel"
,
([
&
]
{
const
scalar_t
*
top_diff
=
top_grad
.
data_ptr
<
scalar_t
>
();
const
scalar_t
*
rois_data
=
rois
.
data_ptr
<
scalar_t
>
();
scalar_t
*
bottom_diff
=
bottom_grad
.
data_ptr
<
scalar_t
>
();
riroi_align_rotated_backward_cuda_kernel
<
scalar_t
>
<<<
GET_BLOCKS
(
output_size
),
THREADS_PER_BLOCK
,
0
,
stream
>>>
(
output_size
,
top_diff
,
rois_data
,
spatial_scale
,
num_samples
,
clockwise
,
channels
,
height
,
width
,
pooled_height
,
pooled_width
,
num_orientations
,
bottom_diff
);
}));
AT_CUDA_CHECK
(
cudaGetLastError
());
}
mmcv/ops/csrc/pytorch/pybind.cpp
View file @
0bcbeadb
...
@@ -350,6 +350,17 @@ void rotated_feature_align_backward(const Tensor top_grad,
...
@@ -350,6 +350,17 @@ void rotated_feature_align_backward(const Tensor top_grad,
const
float
spatial_scale
,
const
float
spatial_scale
,
const
int
points
);
const
int
points
);
void
riroi_align_rotated_forward
(
Tensor
features
,
Tensor
rois
,
Tensor
output
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
num_samples
,
int
num_orientations
,
bool
clockwise
);
void
riroi_align_rotated_backward
(
Tensor
top_grad
,
Tensor
rois
,
Tensor
bottom_grad
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
num_samples
,
int
num_orientations
,
bool
clockwise
);
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"upfirdn2d"
,
&
upfirdn2d
,
"upfirdn2d (CUDA)"
,
py
::
arg
(
"input"
),
m
.
def
(
"upfirdn2d"
,
&
upfirdn2d
,
"upfirdn2d (CUDA)"
,
py
::
arg
(
"input"
),
py
::
arg
(
"kernel"
),
py
::
arg
(
"up_x"
),
py
::
arg
(
"up_y"
),
py
::
arg
(
"down_x"
),
py
::
arg
(
"kernel"
),
py
::
arg
(
"up_x"
),
py
::
arg
(
"up_y"
),
py
::
arg
(
"down_x"
),
...
@@ -704,4 +715,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -704,4 +715,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Feature Refine backward (CUDA)"
,
py
::
arg
(
"top_grad"
),
"Feature Refine backward (CUDA)"
,
py
::
arg
(
"top_grad"
),
py
::
arg
(
"best_bboxes"
),
py
::
arg
(
"bottom_grad"
),
py
::
arg
(
"best_bboxes"
),
py
::
arg
(
"bottom_grad"
),
py
::
arg
(
"spatial_scale"
),
py
::
arg
(
"points"
));
py
::
arg
(
"spatial_scale"
),
py
::
arg
(
"points"
));
m
.
def
(
"riroi_align_rotated_forward"
,
&
riroi_align_rotated_forward
,
"riroi_align_rotated forward"
,
py
::
arg
(
"features"
),
py
::
arg
(
"rois"
),
py
::
arg
(
"output"
),
py
::
arg
(
"pooled_height"
),
py
::
arg
(
"pooled_width"
),
py
::
arg
(
"spatial_scale"
),
py
::
arg
(
"num_samples"
),
py
::
arg
(
"num_orientations"
),
py
::
arg
(
"clockwise"
));
m
.
def
(
"riroi_align_rotated_backward"
,
&
riroi_align_rotated_backward
,
"riroi_align_rotated backward"
,
py
::
arg
(
"top_grad"
),
py
::
arg
(
"rois"
),
py
::
arg
(
"bottom_grad"
),
py
::
arg
(
"pooled_height"
),
py
::
arg
(
"pooled_width"
),
py
::
arg
(
"spatial_scale"
),
py
::
arg
(
"num_samples"
),
py
::
arg
(
"num_orientations"
),
py
::
arg
(
"clockwise"
));
}
}
mmcv/ops/csrc/pytorch/riroi_align_rotated.cpp
0 → 100644
View file @
0bcbeadb
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void
riroi_align_rotated_forward_impl
(
Tensor
features
,
Tensor
rois
,
Tensor
output
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
num_samples
,
int
num_orientations
,
bool
clockwise
)
{
DISPATCH_DEVICE_IMPL
(
riroi_align_rotated_forward_impl
,
features
,
rois
,
output
,
pooled_height
,
pooled_width
,
spatial_scale
,
num_samples
,
num_orientations
,
clockwise
);
}
void
riroi_align_rotated_backward_impl
(
Tensor
top_grad
,
Tensor
rois
,
Tensor
bottom_grad
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
num_samples
,
int
num_orientations
,
bool
clockwise
)
{
DISPATCH_DEVICE_IMPL
(
riroi_align_rotated_backward_impl
,
top_grad
,
rois
,
bottom_grad
,
pooled_height
,
pooled_width
,
spatial_scale
,
num_samples
,
num_orientations
,
clockwise
);
}
void
riroi_align_rotated_forward
(
Tensor
features
,
Tensor
rois
,
Tensor
output
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
num_samples
,
int
num_orientations
,
bool
clockwise
)
{
riroi_align_rotated_forward_impl
(
features
,
rois
,
output
,
pooled_height
,
pooled_width
,
spatial_scale
,
num_samples
,
num_orientations
,
clockwise
);
}
void
riroi_align_rotated_backward
(
Tensor
top_grad
,
Tensor
rois
,
Tensor
bottom_grad
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
num_samples
,
int
num_orientations
,
bool
clockwise
)
{
riroi_align_rotated_backward_impl
(
top_grad
,
rois
,
bottom_grad
,
pooled_height
,
pooled_width
,
spatial_scale
,
num_samples
,
num_orientations
,
clockwise
);
}
mmcv/ops/riroi_align_rotated.py
0 → 100644
View file @
0bcbeadb
# Copyright (c) OpenMMLab. All rights reserved.
import
torch.nn
as
nn
from
torch.autograd
import
Function
from
..utils
import
ext_loader
,
is_tuple_of
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'riroi_align_rotated_forward'
,
'riroi_align_rotated_backward'
])
class
RiRoIAlignRotatedFunction
(
Function
):
@
staticmethod
def
forward
(
ctx
,
features
,
rois
,
out_size
,
spatial_scale
,
num_samples
=
0
,
num_orientations
=
8
,
clockwise
=
False
):
if
isinstance
(
out_size
,
int
):
out_h
=
out_size
out_w
=
out_size
elif
is_tuple_of
(
out_size
,
int
):
assert
len
(
out_size
)
==
2
out_h
,
out_w
=
out_size
else
:
raise
TypeError
(
f
'"out_size" should be an integer or tuple of integers,'
f
' but got
{
out_size
}
'
)
ctx
.
spatial_scale
=
spatial_scale
ctx
.
num_samples
=
num_samples
ctx
.
num_orientations
=
num_orientations
ctx
.
clockwise
=
clockwise
ctx
.
save_for_backward
(
rois
)
ctx
.
feature_size
=
features
.
size
()
batch_size
,
num_channels
,
_
,
_
=
features
.
size
()
num_rois
=
rois
.
size
(
0
)
output
=
features
.
new_zeros
(
num_rois
,
num_channels
,
out_h
,
out_w
)
ext_module
.
riroi_align_rotated_forward
(
features
,
rois
,
output
,
out_h
,
out_w
,
spatial_scale
,
num_samples
,
num_orientations
,
clockwise
)
return
output
@
staticmethod
def
backward
(
ctx
,
grad_output
):
feature_size
=
ctx
.
feature_size
spatial_scale
=
ctx
.
spatial_scale
num_orientations
=
ctx
.
num_orientations
clockwise
=
ctx
.
clockwise
num_samples
=
ctx
.
num_samples
rois
=
ctx
.
saved_tensors
[
0
]
assert
feature_size
is
not
None
batch_size
,
num_channels
,
feature_h
,
feature_w
=
feature_size
out_w
=
grad_output
.
size
(
3
)
out_h
=
grad_output
.
size
(
2
)
grad_input
=
grad_rois
=
None
if
ctx
.
needs_input_grad
[
0
]:
grad_input
=
rois
.
new_zeros
(
batch_size
,
num_channels
,
feature_h
,
feature_w
)
ext_module
.
riroi_align_rotated_backward
(
grad_output
.
contiguous
(),
rois
,
grad_input
,
out_h
,
out_w
,
spatial_scale
,
num_samples
,
num_orientations
,
clockwise
)
return
grad_input
,
grad_rois
,
None
,
None
,
None
,
None
,
None
riroi_align_rotated
=
RiRoIAlignRotatedFunction
.
apply
class
RiRoIAlignRotated
(
nn
.
Module
):
"""Rotation-invariant RoI align pooling layer for rotated proposals.
It accepts a feature map of shape (N, C, H, W) and rois with shape
(n, 6) with each roi decoded as (batch_index, center_x, center_y,
w, h, angle). The angle is in radian.
The details are described in the paper `ReDet: A Rotation-equivariant
Detector for Aerial Object Detection <https://arxiv.org/abs/2103.07733>`_.
Args:
out_size (tuple): fixed dimensional RoI output with shape (h, w).
spatial_scale (float): scale the input boxes by this number
num_samples (int): number of inputs samples to take for each
output sample. 0 to take samples densely for current models.
num_orientations (int): number of oriented channels.
clockwise (bool): If True, the angle in each proposal follows a
clockwise fashion in image space, otherwise, the angle is
counterclockwise. Default: False.
"""
def
__init__
(
self
,
out_size
,
spatial_scale
,
num_samples
=
0
,
num_orientations
=
8
,
clockwise
=
False
):
super
(
RiRoIAlignRotated
,
self
).
__init__
()
self
.
out_size
=
out_size
self
.
spatial_scale
=
float
(
spatial_scale
)
self
.
num_samples
=
int
(
num_samples
)
self
.
num_orientations
=
int
(
num_orientations
)
self
.
clockwise
=
clockwise
def
forward
(
self
,
features
,
rois
):
return
RiRoIAlignRotatedFunction
.
apply
(
features
,
rois
,
self
.
out_size
,
self
.
spatial_scale
,
self
.
num_samples
,
self
.
num_orientations
,
self
.
clockwise
)
tests/test_ops/test_riroi_align_rotated.py
0 → 100644
View file @
0bcbeadb
import
numpy
as
np
import
pytest
import
torch
from
torch.autograd
import
gradcheck
from
mmcv.ops
import
RiRoIAlignRotated
np_feature
=
np
.
array
([[[[
1
,
2
],
[
3
,
4
]],
[[
1
,
2
],
[
4
,
3
]],
[[
4
,
3
],
[
2
,
1
]],
[[
1
,
2
],
[
5
,
6
]],
[[
3
,
4
],
[
7
,
8
]],
[[
9
,
10
],
[
13
,
14
]],
[[
11
,
12
],
[
15
,
16
]],
[[
1
,
1
],
[
2
,
2
]]]])
np_rois
=
np
.
array
([[
0.
,
0.5
,
0.5
,
1.
,
1.
,
np
.
pi
/
3
],
[
0.
,
1.
,
1.
,
3.
,
3.
,
np
.
pi
/
2
]])
expect_output
=
np
.
array
([[[[
1.8425
,
1.3516
],
[
2.3151
,
1.8241
]],
[[
2.4779
,
1.7416
],
[
3.2173
,
2.5632
]],
[[
2.7149
,
2.2638
],
[
2.6540
,
2.3673
]],
[[
2.9461
,
2.8638
],
[
2.8028
,
2.7205
]],
[[
4.1943
,
2.7214
],
[
5.6119
,
4.1391
]],
[[
7.5276
,
6.0547
],
[
8.9453
,
7.4724
]],
[[
12.1943
,
10.7214
],
[
13.6119
,
12.1391
]],
[[
9.5489
,
8.4237
],
[
10.5763
,
9.4511
]]],
[[[
7.6562
,
12.5625
],
[
4.0000
,
6.6250
]],
[[
1.0000
,
1.3125
],
[
0.5000
,
0.6562
]],
[[
1.6562
,
1.9375
],
[
1.0000
,
1.3125
]],
[[
1.8438
,
2.0547
],
[
0.7500
,
1.1562
]],
[[
0.8438
,
3.0625
],
[
0.2500
,
1.1875
]],
[[
2.6562
,
2.5625
],
[
1.5000
,
1.6250
]],
[[
3.6562
,
4.5625
],
[
2.0000
,
2.6250
]],
[[
6.6562
,
10.5625
],
[
3.5000
,
5.6250
]]]])
expect_grad
=
np
.
array
([[[[
1.4727
,
1.5586
],
[
1.5586
,
1.6602
]],
[[
1.4727
,
1.5586
],
[
1.5586
,
1.6602
]],
[[
1.4727
,
1.5586
],
[
1.5586
,
1.6602
]],
[[
1.4727
,
1.5586
],
[
1.5586
,
1.6602
]],
[[
1.4727
,
1.5586
],
[
1.5586
,
1.6602
]],
[[
1.4727
,
1.5586
],
[
1.5586
,
1.6602
]],
[[
1.4727
,
1.5586
],
[
1.5586
,
1.6602
]],
[[
1.4727
,
1.5586
],
[
1.5586
,
1.6602
]]]])
pool_h
=
2
pool_w
=
2
spatial_scale
=
1.0
num_samples
=
2
sampling_ratio
=
2
num_orientations
=
8
clockwise
=
False
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires CUDA support'
)
def
test_roialign_rotated_gradcheck
():
x
=
torch
.
tensor
(
np_feature
,
dtype
=
torch
.
float
,
device
=
'cuda'
,
requires_grad
=
True
)
rois
=
torch
.
tensor
(
np_rois
,
dtype
=
torch
.
float
,
device
=
'cuda'
)
froipool
=
RiRoIAlignRotated
((
pool_h
,
pool_w
),
spatial_scale
,
num_samples
,
num_orientations
,
clockwise
)
gradcheck
(
froipool
,
(
x
,
rois
),
eps
=
1e-3
,
atol
=
1e-3
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires CUDA support'
)
def
test_roialign_rotated_allclose
():
x
=
torch
.
tensor
(
np_feature
,
dtype
=
torch
.
float
,
device
=
'cuda'
,
requires_grad
=
True
)
rois
=
torch
.
tensor
(
np_rois
,
dtype
=
torch
.
float
,
device
=
'cuda'
)
froipool
=
RiRoIAlignRotated
((
pool_h
,
pool_w
),
spatial_scale
,
num_samples
,
num_orientations
,
clockwise
)
output
=
froipool
(
x
,
rois
)
output
.
backward
(
torch
.
ones_like
(
output
))
assert
np
.
allclose
(
output
.
data
.
type
(
torch
.
float
).
cpu
().
numpy
(),
expect_output
,
atol
=
1e-3
)
assert
np
.
allclose
(
x
.
grad
.
data
.
type
(
torch
.
float
).
cpu
().
numpy
(),
expect_grad
,
atol
=
1e-3
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment