Unverified commit b83bdb0c, authored by Hakjin Lee and committed by GitHub
Browse files

[Refactor] Refactor the interface for RoIAlignRotated (#1662)

* fix interface for RoIAlignRotated

* Add a unit test for RoIAlignRotated

* Make a unit test for RoIAlignRotated concise

* fix interface for RoIAlignRotated

* Refactor ext_module.nms_rotated

* Lint cpp files
parent fccb1091
...@@ -20,7 +20,7 @@ template <typename scalar_t> ...@@ -20,7 +20,7 @@ template <typename scalar_t>
__global__ void roi_align_rotated_forward_cuda_kernel( __global__ void roi_align_rotated_forward_cuda_kernel(
const int nthreads, const scalar_t *bottom_data, const int nthreads, const scalar_t *bottom_data,
const scalar_t *bottom_rois, const scalar_t spatial_scale, const scalar_t *bottom_rois, const scalar_t spatial_scale,
const int sample_num, const bool aligned, const bool clockwise, const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, scalar_t *top_data) { const int pooled_height, const int pooled_width, scalar_t *top_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) { CUDA_1D_KERNEL_LOOP(index, nthreads) {
...@@ -58,11 +58,11 @@ __global__ void roi_align_rotated_forward_cuda_kernel( ...@@ -58,11 +58,11 @@ __global__ void roi_align_rotated_forward_cuda_kernel(
bottom_data + (roi_batch_ind * channels + c) * height * width; bottom_data + (roi_batch_ind * channels + c) * height * width;
// We use roi_bin_grid to sample the grid and mimic integral // We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sample_num > 0) int roi_bin_grid_h = (sampling_ratio > 0)
? sample_num ? sampling_ratio
: ceilf(roi_height / pooled_height); // e.g., = 2 : ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = int roi_bin_grid_w =
(sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width); (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after. // Appropriate translation needs to be applied after.
...@@ -104,7 +104,7 @@ __global__ void roi_align_rotated_forward_cuda_kernel( ...@@ -104,7 +104,7 @@ __global__ void roi_align_rotated_forward_cuda_kernel(
template <typename scalar_t> template <typename scalar_t>
__global__ void roi_align_rotated_backward_cuda_kernel( __global__ void roi_align_rotated_backward_cuda_kernel(
const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois, const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
const scalar_t spatial_scale, const int sample_num, const bool aligned, const scalar_t spatial_scale, const int sampling_ratio, const bool aligned,
const bool clockwise, const int channels, const int height, const int width, const bool clockwise, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, scalar_t *bottom_diff) { const int pooled_height, const int pooled_width, scalar_t *bottom_diff) {
CUDA_1D_KERNEL_LOOP(index, nthreads) { CUDA_1D_KERNEL_LOOP(index, nthreads) {
...@@ -146,11 +146,11 @@ __global__ void roi_align_rotated_backward_cuda_kernel( ...@@ -146,11 +146,11 @@ __global__ void roi_align_rotated_backward_cuda_kernel(
const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
// We use roi_bin_grid to sample the grid and mimic integral // We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sample_num > 0) int roi_bin_grid_h = (sampling_ratio > 0)
? sample_num ? sampling_ratio
: ceilf(roi_height / pooled_height); // e.g., = 2 : ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = int roi_bin_grid_w =
(sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width); (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after. // Appropriate translation needs to be applied after.
......
...@@ -14,14 +14,14 @@ void roi_align_rotated_forward_cuda_parrots(CudaContext& ctx, ...@@ -14,14 +14,14 @@ void roi_align_rotated_forward_cuda_parrots(CudaContext& ctx,
int pooled_height; int pooled_height;
int pooled_width; int pooled_width;
float spatial_scale; float spatial_scale;
int sample_num; int sampling_ratio;
bool aligned; bool aligned;
bool clockwise; bool clockwise;
SSAttrs(attr) SSAttrs(attr)
.get<int>("pooled_height", pooled_height) .get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width) .get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale) .get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num) .get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned) .get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise) .get<bool>("clockwise", clockwise)
.done(); .done();
...@@ -30,7 +30,7 @@ void roi_align_rotated_forward_cuda_parrots(CudaContext& ctx, ...@@ -30,7 +30,7 @@ void roi_align_rotated_forward_cuda_parrots(CudaContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]); const auto& rois = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]); auto output = buildATensor(ctx, outs[0]);
roi_align_rotated_forward_cuda(input, rois, output, pooled_height, roi_align_rotated_forward_cuda(input, rois, output, pooled_height,
pooled_width, spatial_scale, sample_num, pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise); aligned, clockwise);
} }
...@@ -41,14 +41,14 @@ void roi_align_rotated_backward_cuda_parrots(CudaContext& ctx, ...@@ -41,14 +41,14 @@ void roi_align_rotated_backward_cuda_parrots(CudaContext& ctx,
int pooled_height; int pooled_height;
int pooled_width; int pooled_width;
float spatial_scale; float spatial_scale;
int sample_num; int sampling_ratio;
bool aligned; bool aligned;
bool clockwise; bool clockwise;
SSAttrs(attr) SSAttrs(attr)
.get<int>("pooled_height", pooled_height) .get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width) .get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale) .get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num) .get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned) .get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise) .get<bool>("clockwise", clockwise)
.done(); .done();
...@@ -57,7 +57,7 @@ void roi_align_rotated_backward_cuda_parrots(CudaContext& ctx, ...@@ -57,7 +57,7 @@ void roi_align_rotated_backward_cuda_parrots(CudaContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]); const auto& rois = buildATensor(ctx, ins[1]);
auto grad_input = buildATensor(ctx, outs[0]); auto grad_input = buildATensor(ctx, outs[0]);
roi_align_rotated_backward_cuda(grad_output, rois, grad_input, pooled_height, roi_align_rotated_backward_cuda(grad_output, rois, grad_input, pooled_height,
pooled_width, spatial_scale, sample_num, pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise); aligned, clockwise);
} }
#endif #endif
...@@ -69,14 +69,14 @@ void roi_align_rotated_forward_cpu_parrots(HostContext& ctx, ...@@ -69,14 +69,14 @@ void roi_align_rotated_forward_cpu_parrots(HostContext& ctx,
int pooled_height; int pooled_height;
int pooled_width; int pooled_width;
float spatial_scale; float spatial_scale;
int sample_num; int sampling_ratio;
bool aligned; bool aligned;
bool clockwise; bool clockwise;
SSAttrs(attr) SSAttrs(attr)
.get<int>("pooled_height", pooled_height) .get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width) .get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale) .get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num) .get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned) .get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise) .get<bool>("clockwise", clockwise)
.done(); .done();
...@@ -85,7 +85,7 @@ void roi_align_rotated_forward_cpu_parrots(HostContext& ctx, ...@@ -85,7 +85,7 @@ void roi_align_rotated_forward_cpu_parrots(HostContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]); const auto& rois = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]); auto output = buildATensor(ctx, outs[0]);
roi_align_rotated_forward_cpu(input, rois, output, pooled_height, roi_align_rotated_forward_cpu(input, rois, output, pooled_height,
pooled_width, spatial_scale, sample_num, pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise); aligned, clockwise);
} }
...@@ -96,14 +96,14 @@ void roi_align_rotated_backward_cpu_parrots(HostContext& ctx, ...@@ -96,14 +96,14 @@ void roi_align_rotated_backward_cpu_parrots(HostContext& ctx,
int pooled_height; int pooled_height;
int pooled_width; int pooled_width;
float spatial_scale; float spatial_scale;
int sample_num; int sampling_ratio;
bool aligned; bool aligned;
bool clockwise; bool clockwise;
SSAttrs(attr) SSAttrs(attr)
.get<int>("pooled_height", pooled_height) .get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width) .get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale) .get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num) .get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned) .get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise) .get<bool>("clockwise", clockwise)
.done(); .done();
...@@ -112,7 +112,7 @@ void roi_align_rotated_backward_cpu_parrots(HostContext& ctx, ...@@ -112,7 +112,7 @@ void roi_align_rotated_backward_cpu_parrots(HostContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]); const auto& rois = buildATensor(ctx, ins[1]);
auto grad_input = buildATensor(ctx, outs[0]); auto grad_input = buildATensor(ctx, outs[0]);
roi_align_rotated_backward_cpu(grad_output, rois, grad_input, pooled_height, roi_align_rotated_backward_cpu(grad_output, rois, grad_input, pooled_height,
pooled_width, spatial_scale, sample_num, pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise); aligned, clockwise);
} }
...@@ -120,7 +120,7 @@ PARROTS_EXTENSION_REGISTER(roi_align_rotated_forward) ...@@ -120,7 +120,7 @@ PARROTS_EXTENSION_REGISTER(roi_align_rotated_forward)
.attr("pooled_height") .attr("pooled_height")
.attr("pooled_width") .attr("pooled_width")
.attr("spatial_scale") .attr("spatial_scale")
.attr("sample_num") .attr("sampling_ratio")
.attr("aligned") .attr("aligned")
.attr("clockwise") .attr("clockwise")
.input(2) .input(2)
...@@ -135,7 +135,7 @@ PARROTS_EXTENSION_REGISTER(roi_align_rotated_backward) ...@@ -135,7 +135,7 @@ PARROTS_EXTENSION_REGISTER(roi_align_rotated_backward)
.attr("pooled_height") .attr("pooled_height")
.attr("pooled_width") .attr("pooled_width")
.attr("spatial_scale") .attr("spatial_scale")
.attr("sample_num") .attr("sampling_ratio")
.attr("aligned") .attr("aligned")
.attr("clockwise") .attr("clockwise")
.input(2) .input(2)
......
...@@ -5,27 +5,27 @@ ...@@ -5,27 +5,27 @@
using namespace at; using namespace at;
#ifdef MMCV_WITH_CUDA #ifdef MMCV_WITH_CUDA
void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output, void roi_align_rotated_forward_cuda(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width, int pooled_height, int pooled_width,
float spatial_scale, int sample_num, float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise); bool aligned, bool clockwise);
void roi_align_rotated_backward_cuda(Tensor grad_output, Tensor rois, void roi_align_rotated_backward_cuda(Tensor grad_output, Tensor rois,
Tensor bottom_grad, int pooled_height, Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale, int pooled_width, float spatial_scale,
int sample_num, bool aligned, int sampling_ratio, bool aligned,
bool clockwise); bool clockwise);
#endif #endif
void roi_align_rotated_forward_cpu(Tensor features, Tensor rois, Tensor output, void roi_align_rotated_forward_cpu(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width, int pooled_height, int pooled_width,
float spatial_scale, int sample_num, float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise); bool aligned, bool clockwise);
void roi_align_rotated_backward_cpu(Tensor grad_output, Tensor rois, void roi_align_rotated_backward_cpu(Tensor grad_output, Tensor rois,
Tensor bottom_grad, int pooled_height, Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale, int pooled_width, float spatial_scale,
int sample_num, bool aligned, int sampling_ratio, bool aligned,
bool clockwise); bool clockwise);
#endif // ROI_ALIGN_ROTATED_PYTORCH_H #endif // ROI_ALIGN_ROTATED_PYTORCH_H
...@@ -439,15 +439,15 @@ void roi_align_rotated_backward_cpu(Tensor top_grad, Tensor rois, ...@@ -439,15 +439,15 @@ void roi_align_rotated_backward_cpu(Tensor top_grad, Tensor rois,
sampling_ratio, aligned, clockwise); sampling_ratio, aligned, clockwise);
} }
void roi_align_rotated_forward_impl(Tensor features, Tensor rois, Tensor output, void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width, int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio, float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise); bool aligned, bool clockwise);
void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois, void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height, Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale, int aligned_width, float spatial_scale,
int sample_ratio, bool aligned, int sampling_ratio, bool aligned,
bool clockwise); bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CPU, REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CPU,
roi_align_rotated_forward_cpu); roi_align_rotated_forward_cpu);
......
...@@ -924,20 +924,20 @@ REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda); ...@@ -924,20 +924,20 @@ REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);
REGISTER_DEVICE_IMPL(roi_align_backward_impl, CUDA, roi_align_backward_cuda); REGISTER_DEVICE_IMPL(roi_align_backward_impl, CUDA, roi_align_backward_cuda);
void ROIAlignRotatedForwardCUDAKernelLauncher( void ROIAlignRotatedForwardCUDAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale, const at::Tensor input, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise, const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois, const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output); const int pooled_height, const int pooled_width, at::Tensor output);
void ROIAlignRotatedBackwardCUDAKernelLauncher( void ROIAlignRotatedBackwardCUDAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale, const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise, const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois, const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad); const int pooled_height, const int pooled_width, at::Tensor bottom_grad);
void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output, void roi_align_rotated_forward_cuda(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width, int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio, float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise) { bool aligned, bool clockwise) {
// Number of ROIs // Number of ROIs
int num_rois = rois.size(0); int num_rois = rois.size(0);
...@@ -947,11 +947,11 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output, ...@@ -947,11 +947,11 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
AT_ERROR("wrong roi size"); AT_ERROR("wrong roi size");
} }
int num_channels = features.size(1); int num_channels = input.size(1);
int data_height = features.size(2); int data_height = input.size(2);
int data_width = features.size(3); int data_width = input.size(3);
ROIAlignRotatedForwardCUDAKernelLauncher( ROIAlignRotatedForwardCUDAKernelLauncher(
features, rois, spatial_scale, sample_ratio, aligned, clockwise, input, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height, num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, output); aligned_width, output);
} }
...@@ -959,7 +959,7 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output, ...@@ -959,7 +959,7 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois, void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height, Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale, int aligned_width, float spatial_scale,
int sample_ratio, bool aligned, int sampling_ratio, bool aligned,
bool clockwise) { bool clockwise) {
// Number of ROIs // Number of ROIs
int num_rois = rois.size(0); int num_rois = rois.size(0);
...@@ -972,20 +972,20 @@ void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois, ...@@ -972,20 +972,20 @@ void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
int data_height = bottom_grad.size(2); int data_height = bottom_grad.size(2);
int data_width = bottom_grad.size(3); int data_width = bottom_grad.size(3);
ROIAlignRotatedBackwardCUDAKernelLauncher( ROIAlignRotatedBackwardCUDAKernelLauncher(
top_grad, rois, spatial_scale, sample_ratio, aligned, clockwise, top_grad, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height, num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, bottom_grad); aligned_width, bottom_grad);
} }
void roi_align_rotated_forward_impl(Tensor features, Tensor rois, Tensor output, void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width, int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio, float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise); bool aligned, bool clockwise);
void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois, void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height, Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale, int aligned_width, float spatial_scale,
int sample_ratio, bool aligned, int sampling_ratio, bool aligned,
bool clockwise); bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA, REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA,
roi_align_rotated_forward_cuda); roi_align_rotated_forward_cuda);
......
...@@ -3,21 +3,21 @@ ...@@ -3,21 +3,21 @@
#include "roi_align_rotated_cuda_kernel.cuh" #include "roi_align_rotated_cuda_kernel.cuh"
void ROIAlignRotatedForwardCUDAKernelLauncher( void ROIAlignRotatedForwardCUDAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale, const at::Tensor input, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise, const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois, const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output) { const int pooled_height, const int pooled_width, at::Tensor output) {
const int output_size = num_rois * pooled_height * pooled_width * channels; const int output_size = num_rois * pooled_height * pooled_width * channels;
AT_DISPATCH_FLOATING_TYPES_AND_HALF( AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "ROIAlignRotatedLaucherForward", ([&] { input.scalar_type(), "ROIAlignRotatedLaucherForward", ([&] {
const scalar_t *bottom_data = features.data_ptr<scalar_t>(); const scalar_t *bottom_data = input.data_ptr<scalar_t>();
const scalar_t *rois_data = rois.data_ptr<scalar_t>(); const scalar_t *rois_data = rois.data_ptr<scalar_t>();
scalar_t *top_data = output.data_ptr<scalar_t>(); scalar_t *top_data = output.data_ptr<scalar_t>();
roi_align_rotated_forward_cuda_kernel<scalar_t> roi_align_rotated_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>( <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
output_size, bottom_data, rois_data, scalar_t(spatial_scale), output_size, bottom_data, rois_data, scalar_t(spatial_scale),
sample_num, aligned, clockwise, channels, height, width, sampling_ratio, aligned, clockwise, channels, height, width,
pooled_height, pooled_width, top_data); pooled_height, pooled_width, top_data);
})); }));
...@@ -26,7 +26,7 @@ void ROIAlignRotatedForwardCUDAKernelLauncher( ...@@ -26,7 +26,7 @@ void ROIAlignRotatedForwardCUDAKernelLauncher(
void ROIAlignRotatedBackwardCUDAKernelLauncher( void ROIAlignRotatedBackwardCUDAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale, const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise, const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois, const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad) { const int pooled_height, const int pooled_width, at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_height * pooled_width * channels; const int output_size = num_rois * pooled_height * pooled_width * channels;
...@@ -37,7 +37,7 @@ void ROIAlignRotatedBackwardCUDAKernelLauncher( ...@@ -37,7 +37,7 @@ void ROIAlignRotatedBackwardCUDAKernelLauncher(
scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>(); scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
roi_align_rotated_backward_cuda_kernel<scalar_t> roi_align_rotated_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>( <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
output_size, top_diff, rois_data, spatial_scale, sample_num, output_size, top_diff, rois_data, spatial_scale, sampling_ratio,
aligned, clockwise, channels, height, width, pooled_height, aligned, clockwise, channels, height, width, pooled_height,
pooled_width, bottom_diff); pooled_width, bottom_diff);
})); }));
......
...@@ -312,13 +312,14 @@ Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias, ...@@ -312,13 +312,14 @@ Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias,
void roi_align_rotated_forward(Tensor input, Tensor rois, Tensor output, void roi_align_rotated_forward(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width, int pooled_height, int pooled_width,
float spatial_scale, int sample_num, float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise); bool aligned, bool clockwise);
void roi_align_rotated_backward(Tensor grad_output, Tensor rois, void roi_align_rotated_backward(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height, Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale, int pooled_width, float spatial_scale,
int sample_num, bool aligned, bool clockwise); int sampling_ratio, bool aligned,
bool clockwise);
std::vector<torch::Tensor> dynamic_point_to_voxel_forward( std::vector<torch::Tensor> dynamic_point_to_voxel_forward(
const torch::Tensor &feats, const torch::Tensor &coors, const torch::Tensor &feats, const torch::Tensor &coors,
...@@ -736,13 +737,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -736,13 +737,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Python binding for roi_align_rotated_forward. All nine parameters are
// exposed as named arguments; the names are listed in the same order as the
// C++ declaration (input, rois, output, pooled_height, pooled_width,
// spatial_scale, <sampling parameter>, aligned, clockwise).
m.def("roi_align_rotated_forward", &roi_align_rotated_forward, m.def("roi_align_rotated_forward", &roi_align_rotated_forward,
"roi_align_rotated forward", py::arg("input"), py::arg("rois"), "roi_align_rotated forward", py::arg("input"), py::arg("rois"),
py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"), py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"),
py::arg("spatial_scale"), py::arg("sample_num"), py::arg("aligned"), py::arg("spatial_scale"), py::arg("sampling_ratio"), py::arg("aligned"),
py::arg("clockwise")); py::arg("clockwise"));
// Python binding for roi_align_rotated_backward.
// NOTE(review): the first three py::arg names are listed as
// (rois, grad_input, grad_output), but the bound C++ function is declared
// with its tensors in the order (grad_output, rois, grad_input). py::arg
// names are attached to parameters by position, so a caller passing these
// three tensors by keyword would have them routed to the wrong parameters;
// purely positional callers are unaffected. Confirm the intended order before
// renaming, since changing the names alters the keyword interface.
m.def("roi_align_rotated_backward", &roi_align_rotated_backward, m.def("roi_align_rotated_backward", &roi_align_rotated_backward,
"roi_align_rotated backward", py::arg("rois"), py::arg("grad_input"), "roi_align_rotated backward", py::arg("rois"), py::arg("grad_input"),
py::arg("grad_output"), py::arg("pooled_height"), py::arg("grad_output"), py::arg("pooled_height"),
py::arg("pooled_width"), py::arg("spatial_scale"), py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise")); py::arg("sampling_ratio"), py::arg("aligned"), py::arg("clockwise"));
m.def("dynamic_point_to_voxel_forward", &dynamic_point_to_voxel_forward, m.def("dynamic_point_to_voxel_forward", &dynamic_point_to_voxel_forward,
"dynamic_point_to_voxel_forward", py::arg("feats"), py::arg("coors"), "dynamic_point_to_voxel_forward", py::arg("feats"), py::arg("coors"),
py::arg("reduce_type")); py::arg("reduce_type"));
......
...@@ -2,23 +2,23 @@ ...@@ -2,23 +2,23 @@
#include "pytorch_cpp_helper.hpp" #include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp" #include "pytorch_device_registry.hpp"
// Dispatch entry point for the rotated RoIAlign forward pass.
// The body is a single DISPATCH_DEVICE_IMPL invocation keyed by
// roi_align_rotated_forward_impl; all nine arguments are passed through
// unchanged and in declaration order. Nothing is returned.
// NOTE(review): presumably the macro selects the implementation registered
// with REGISTER_DEVICE_IMPL for the tensors' device -- confirm against
// pytorch_device_registry.hpp.
void roi_align_rotated_forward_impl(Tensor features, Tensor rois, Tensor output, void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width, int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio, float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise) { bool aligned, bool clockwise) {
DISPATCH_DEVICE_IMPL(roi_align_rotated_forward_impl, features, rois, output, DISPATCH_DEVICE_IMPL(roi_align_rotated_forward_impl, input, rois, output,
aligned_height, aligned_width, spatial_scale, aligned_height, aligned_width, spatial_scale,
sample_ratio, aligned, clockwise); sampling_ratio, aligned, clockwise);
} }
// Dispatch entry point for the rotated RoIAlign backward pass.
// The body is a single DISPATCH_DEVICE_IMPL invocation keyed by
// roi_align_rotated_backward_impl; top_grad, rois, bottom_grad and the six
// scalar options are passed through unchanged and in declaration order.
// Nothing is returned.
// NOTE(review): presumably bottom_grad is the tensor the selected
// implementation writes into -- confirm against the registered
// implementations.
void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois, void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height, Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale, int aligned_width, float spatial_scale,
int sample_ratio, bool aligned, int sampling_ratio, bool aligned,
bool clockwise) { bool clockwise) {
DISPATCH_DEVICE_IMPL(roi_align_rotated_backward_impl, top_grad, rois, DISPATCH_DEVICE_IMPL(roi_align_rotated_backward_impl, top_grad, rois,
bottom_grad, aligned_height, aligned_width, bottom_grad, aligned_height, aligned_width,
spatial_scale, sample_ratio, aligned, clockwise); spatial_scale, sampling_ratio, aligned, clockwise);
} }
void roi_align_rotated_forward(Tensor input, Tensor rois, Tensor output, void roi_align_rotated_forward(Tensor input, Tensor rois, Tensor output,
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn import torch.nn as nn
from torch.autograd import Function from torch.autograd import Function
from torch.nn.modules.utils import _pair
from ..utils import ext_loader from ..utils import deprecated_api_warning, ext_loader
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
'_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward']) '_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward'])
...@@ -11,80 +12,67 @@ ext_module = ext_loader.load_ext( ...@@ -11,80 +12,67 @@ ext_module = ext_loader.load_ext(
class RoIAlignRotatedFunction(Function): class RoIAlignRotatedFunction(Function):
@staticmethod
def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
             aligned, clockwise):
    """Emit the ``mmcv::MMCVRoIAlignRotated`` node for graph export.

    Args:
        g: Graph-building object; only its ``op`` method is used.
        input: Value passed as the first positional input of the node.
        rois: Value passed as the second positional input of the node.
        output_size (int | tuple[int, int]): A single int is used for both
            height and width; a tuple is unpacked as ``(height, width)``.
        spatial_scale: Forwarded as the ``spatial_scale_f`` attribute.
        sampling_ratio: Forwarded as the ``sampling_ratio_i`` attribute.
        aligned: Forwarded as the ``aligned_i`` attribute.
        clockwise: Forwarded as the ``clockwise_i`` attribute.

    Returns:
        Whatever ``g.op`` returns for the created node.

    Raises:
        TypeError: If ``output_size`` is neither an int nor a tuple.
        AssertionError: If ``output_size`` is a tuple that does not hold
            exactly two ints.
    """
    if isinstance(output_size, int):
        out_h = out_w = output_size
    elif isinstance(output_size, tuple):
        assert len(output_size) == 2
        assert isinstance(output_size[0], int)
        assert isinstance(output_size[1], int)
        out_h, out_w = output_size
    else:
        raise TypeError(
            '"output_size" must be an integer or tuple of integers')
    return g.op(
        'mmcv::MMCVRoIAlignRotated',
        input,
        rois,
        output_height_i=out_h,
        # The width attribute must carry out_w; passing out_h here would
        # export a square output for any non-square output_size.
        output_width_i=out_w,
        spatial_scale_f=spatial_scale,
        sampling_ratio_i=sampling_ratio,
        aligned_i=aligned,
        clockwise_i=clockwise)
@staticmethod
def forward(ctx,
            input,
            rois,
            output_size,
            spatial_scale,
            sampling_ratio=0,
            aligned=True,
            clockwise=False):
    """Run the rotated RoIAlign forward pass through the compiled extension.

    The options and ``rois`` are stored on ``ctx`` before the extension is
    called, and the extension call reads the options back from ``ctx``.

    Args:
        ctx: Autograd context object that receives the saved state.
        input: Feature tensor; its ``size()`` must unpack into four values.
        rois: RoI tensor; ``rois.size(0)`` sets the first output dimension.
        output_size: Passed through ``_pair``; elements 0 and 1 of the
            result are used as pooled height and width.
        spatial_scale: Stored on ``ctx`` and forwarded to the extension.
        sampling_ratio: Stored on ``ctx`` and forwarded to the extension.
        aligned: Stored on ``ctx`` and forwarded to the extension.
        clockwise: Stored on ``ctx`` and forwarded to the extension.

    Returns:
        A tensor created with ``input.new_zeros`` of shape
        ``(rois.size(0), input.size(1), pooled_height, pooled_width)``,
        after it has been handed to the extension (presumably filled in
        place -- confirm against the extension implementation).
    """
    ctx.output_size = _pair(output_size)
    ctx.spatial_scale = spatial_scale
    ctx.sampling_ratio = sampling_ratio
    ctx.aligned = aligned
    ctx.clockwise = clockwise
    ctx.save_for_backward(rois)

    input_shape = input.size()
    ctx.feature_size = input_shape
    # Four-way unpacking doubles as a check that the size has 4 entries.
    _, channel_count, _, _ = input_shape

    pooled = input.new_zeros(
        rois.size(0), channel_count, ctx.output_size[0], ctx.output_size[1])
    ext_kwargs = dict(
        pooled_height=ctx.output_size[0],
        pooled_width=ctx.output_size[1],
        spatial_scale=ctx.spatial_scale,
        sampling_ratio=ctx.sampling_ratio,
        aligned=ctx.aligned,
        clockwise=ctx.clockwise)
    ext_module.roi_align_rotated_forward(input, rois, pooled, **ext_kwargs)
    return pooled
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
feature_size = ctx.feature_size feature_size = ctx.feature_size
spatial_scale = ctx.spatial_scale
aligned = ctx.aligned
clockwise = ctx.clockwise
sample_num = ctx.sample_num
rois = ctx.saved_tensors[0] rois = ctx.saved_tensors[0]
assert feature_size is not None assert feature_size is not None
batch_size, num_channels, data_height, data_width = feature_size batch_size, num_channels, data_height, data_width = feature_size
...@@ -103,10 +91,10 @@ class RoIAlignRotatedFunction(Function): ...@@ -103,10 +91,10 @@ class RoIAlignRotatedFunction(Function):
grad_input, grad_input,
pooled_height=out_h, pooled_height=out_h,
pooled_width=out_w, pooled_width=out_w,
spatial_scale=spatial_scale, spatial_scale=ctx.spatial_scale,
sample_num=sample_num, sampling_ratio=ctx.sampling_ratio,
aligned=aligned, aligned=ctx.aligned,
clockwise=clockwise) clockwise=ctx.clockwise)
return grad_input, grad_rois, None, None, None, None, None return grad_input, grad_rois, None, None, None, None, None
...@@ -121,9 +109,9 @@ class RoIAlignRotated(nn.Module): ...@@ -121,9 +109,9 @@ class RoIAlignRotated(nn.Module):
w, h, angle). The angle is in radians. w, h, angle). The angle is in radians.
Args: Args:
out_size (tuple): h, w output_size (tuple): h, w
spatial_scale (float): scale the input boxes by this number spatial_scale (float): scale the input boxes by this number
sample_num (int): number of input samples to take for each sampling_ratio (int): number of input samples to take for each
output sample. 0 to take samples densely for current models. output sample. 0 to take samples densely for current models.
aligned (bool): if False, use the legacy implementation in aligned (bool): if False, use the legacy implementation in
MMDetection. If True, align the results more perfectly. MMDetection. If True, align the results more perfectly.
...@@ -156,22 +144,37 @@ class RoIAlignRotated(nn.Module): ...@@ -156,22 +144,37 @@ class RoIAlignRotated(nn.Module):
performance if ROIAlign is used together with conv layers. performance if ROIAlign is used together with conv layers.
""" """
@deprecated_api_warning(
{
'out_size': 'output_size',
'sample_num': 'sampling_ratio'
},
cls_name='RoIAlignRotated')
def __init__(self, def __init__(self,
out_size, output_size,
spatial_scale, spatial_scale,
sample_num=0, sampling_ratio=0,
aligned=True, aligned=True,
clockwise=False): clockwise=False):
super(RoIAlignRotated, self).__init__() super(RoIAlignRotated, self).__init__()
self.out_size = out_size self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale) self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num) self.sampling_ratio = int(sampling_ratio)
self.aligned = aligned self.aligned = aligned
self.clockwise = clockwise self.clockwise = clockwise
def forward(self, features, rois): def forward(self, input, rois):
return RoIAlignRotatedFunction.apply(features, rois, self.out_size, return RoIAlignRotatedFunction.apply(input, rois, self.output_size,
self.spatial_scale, self.spatial_scale,
self.sample_num, self.aligned, self.sampling_ratio, self.aligned,
self.clockwise) self.clockwise)
def __repr__(self):
s = self.__class__.__name__
s += f'(output_size={self.output_size}, '
s += f'spatial_scale={self.spatial_scale}, '
s += f'sampling_ratio={self.sampling_ratio}, '
s += f'aligned={self.aligned}, '
s += f'clockwise={self.clockwise})'
return s
...@@ -80,7 +80,7 @@ def _test_roialign_rotated_allclose(device, dtype): ...@@ -80,7 +80,7 @@ def _test_roialign_rotated_allclose(device, dtype):
if not torch.cuda.is_available() and device == 'cuda': if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('unittest does not support GPU yet.') pytest.skip('unittest does not support GPU yet.')
try: try:
from mmcv.ops import roi_align_rotated from mmcv.ops import roi_align_rotated, RoIAlignRotated
except ModuleNotFoundError: except ModuleNotFoundError:
pytest.skip('test requires compilation') pytest.skip('test requires compilation')
pool_h = 2 pool_h = 2
...@@ -106,6 +106,25 @@ def _test_roialign_rotated_allclose(device, dtype): ...@@ -106,6 +106,25 @@ def _test_roialign_rotated_allclose(device, dtype):
assert np.allclose( assert np.allclose(
x.grad.data.type(torch.float).cpu().numpy(), np_grad, atol=1e-3) x.grad.data.type(torch.float).cpu().numpy(), np_grad, atol=1e-3)
# Test deprecated parameters
roi_align_rotated_module_deprecated = RoIAlignRotated(
out_size=(pool_h, pool_w),
spatial_scale=spatial_scale,
sample_num=sampling_ratio)
output_1 = roi_align_rotated_module_deprecated(x, rois)
roi_align_rotated_module_new = RoIAlignRotated(
output_size=(pool_h, pool_w),
spatial_scale=spatial_scale,
sampling_ratio=sampling_ratio)
output_2 = roi_align_rotated_module_new(x, rois)
assert np.allclose(
output_1.data.type(torch.float).cpu().numpy(),
output_2.data.type(torch.float).cpu().numpy())
@pytest.mark.parametrize('device', ['cuda', 'cpu']) @pytest.mark.parametrize('device', ['cuda', 'cpu'])
@pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half]) @pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment