Unverified Commit 35affc35 authored by Jingwei Zhang, committed by GitHub

[Feature] Support PrRoIPool operation (#2113)



* support prroipooling

* fix build bug

* fix bug of precision

* fix lint

* add ut

* add license and citation

* polish op and add type hint

* refactor ut

* add docstring for prroipool

* polish code

* format cu and cuh code

* fix format

* fix format

* fix typo
Co-authored-by: zhouzaida <zhouzaida@163.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 22fadcee
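For context, a minimal usage sketch of the op this PR adds; shapes and values are illustrative, and note that the op is float-only and, in this PR, registered for CUDA only.

import torch
from mmcv.ops import PrRoIPool, prroi_pool

feats = torch.rand(1, 1, 4, 4, device='cuda')
# Each RoI is (batch_index, x1, y1, x2, y2) in input-feature coordinates.
rois = torch.tensor([[0., 0., 0., 3., 3.]], device='cuda')

pool = PrRoIPool(output_size=(2, 2), spatial_scale=1.0)
out = pool(feats, rois)  # shape: (num_rois, channels, 2, 2)
out_fn = prroi_pool(feats, rois, (2, 2), 1.0)  # equivalent functional form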
@@ -57,3 +57,4 @@ We implement common ops used in detection, segmentation, etc.
| TINShift | | √ | √ | |
| UpFirDn2d | | √ | | |
| Voxelization | √ | √ | | |
| PrRoIPool | | √ | | |
@@ -57,3 +57,4 @@ MMCV provides operators commonly used in tasks such as detection and segmentation
| TINShift | | √ | √ | |
| UpFirDn2d | | √ | | |
| Voxelization | √ | √ | | |
| PrRoIPool | | √ | | |
@@ -46,6 +46,7 @@ from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from .points_in_polygons import points_in_polygons
from .points_sampler import PointsSampler
from .prroi_pool import PrRoIPool, prroi_pool
from .psa_mask import PSAMask
from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
from .roi_align import RoIAlign, roi_align
@@ -100,5 +101,6 @@ __all__ = [
'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part',
'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
'PrRoIPool', 'prroi_pool'
]
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/vacancy/PreciseRoIPooling/blob/master/src/prroi_pooling_gpu_impl.cu
// Distributed under terms of the MIT license.
#ifndef PRROI_POOL_CUDA_KERNEL_CUH
#define PRROI_POOL_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
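// Returns input(h, w), or 0 for coordinates outside the feature map
// (implicit zero padding).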
template <typename T>
__device__ static __forceinline__ T PrRoIPoolingGetData(const T *data,
const int h,
const int w,
const int height,
const int width) {
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
T retVal = overflow ? 0.0f : data[h * width + w];
return retVal;
}
template <typename T>
__device__ static __forceinline__ T PrRoIPoolingGetCoeff(T dh, T dw) {
return (1.0f - abs(dh)) * (1.0f - abs(dw));
}
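// Exact integral over [s, t] (with 0 <= s <= t <= 1) of the linear
// interpolant c(x) = c1 + (c2 - c1) * x between two neighboring samples:
//   \int_s^t c(x) dx = 0.5 * (t^2 - s^2) * (c2 - c1) + (t - s) * c1.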
template <typename T>
__device__ static __forceinline__ T PrRoIPoolingSingleCoorIntegral(T s, T t,
T c1, T c2) {
return 0.5 * (t * t - s * s) * (c2 - c1) + (t - s) * c1;
}
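// Bilinear interpolation at the continuous coordinate (h, w): the four
// surrounding integer locations are fetched with PrRoIPoolingGetData and
// weighted by PrRoIPoolingGetCoeff.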
template <typename T>
__device__ static T PrRoIPoolingInterpolation(const T *data, const T h,
const T w, const int height,
const int width) {
T retVal = 0.0f;
int h1 = floorf(h);
int w1 = floorf(w);
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
h1 = floorf(h) + 1;
w1 = floorf(w);
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
h1 = floorf(h);
w1 = floorf(w) + 1;
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
h1 = floorf(h) + 1;
w1 = floorf(w) + 1;
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
return retVal;
}
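// Exact integral of the bilinear interpolant over the rectangle
// [x0, x1] x [y0, y1], which lies inside the single feature cell with
// corners (s_h, s_w) and (e_h, e_w). Each of the four terms integrates the
// bilinear weight of one corner in closed form and scales that corner's
// value by the result.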
template <typename T>
__device__ static T PrRoIPoolingMatCalculation(const T *this_data,
const int s_h, const int s_w,
const int e_h, const int e_w,
const T y0, const T x0,
const T y1, const T x1,
const int h0, const int w0) {
T alpha, beta, lim_alpha, lim_beta, tmp;
T sum_out = 0;
alpha = x0 - T(s_w);
beta = y0 - T(s_h);
lim_alpha = x1 - T(s_w);
lim_beta = y1 - T(s_h);
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp;
alpha = T(e_w) - x1;
lim_alpha = T(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp;
alpha = x0 - T(s_w);
beta = T(e_h) - y1;
lim_alpha = x1 - T(s_w);
lim_beta = T(e_h) - y0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp;
alpha = T(e_w) - x1;
lim_alpha = T(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp;
return sum_out;
}
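// Adjoint of PrRoIPoolingGetData: atomically accumulates top_diff * coeff
// into diff(h, w), skipping out-of-bounds locations.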
template <typename T>
__device__ static void PrRoIPoolingDistributeDiff(T *diff, const T top_diff,
const int h, const int w,
const int height,
const int width,
const T coeff) {
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
if (!overflow) atomicAdd(diff + h * width + w, top_diff * coeff);
}
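// Adjoint of PrRoIPoolingMatCalculation: distributes the upstream gradient
// to the four corners of a cell using the same closed-form integration
// weights as the forward pass.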
template <typename T>
__device__ static void PrRoIPoolingMatDistributeDiff(
T *diff, const T top_diff, const int s_h, const int s_w, const int e_h,
const int e_w, const T y0, const T x0, const T y1, const T x1, const int h0,
const int w0) {
T alpha, beta, lim_alpha, lim_beta, tmp;
alpha = x0 - T(s_w);
beta = y0 - T(s_h);
lim_alpha = x1 - T(s_w);
lim_beta = y1 - T(s_h);
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, s_h, s_w, h0, w0, tmp);
alpha = T(e_w) - x1;
lim_alpha = T(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp);
alpha = x0 - T(s_w);
beta = T(e_h) - y1;
lim_alpha = x1 - T(s_w);
lim_beta = T(e_h) - y0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp);
alpha = T(e_w) - x1;
lim_alpha = T(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp);
}
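// Forward pass. Each thread computes one pooled value (n, c, ph, pw): the
// bin is intersected with the integer grid, the bilinear interpolant is
// integrated exactly over every covered cell with PrRoIPoolingMatCalculation,
// and the sum is divided by the bin area (average pooling). Degenerate bins
// output 0.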
template <typename T>
__global__ void prroi_pool_forward_cuda_kernel(
const int nthreads, const T *input, const T *rois, T *output,
const int pooled_height, const int pooled_width, const T spatial_scale,
const int channels, const int height, const int width) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T *offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
T roi_x1 = offset_rois[1] * spatial_scale;
T roi_y1 = offset_rois[2] * spatial_scale;
T roi_x2 = offset_rois[3] * spatial_scale;
T roi_y2 = offset_rois[4] * spatial_scale;
T roi_width = max(roi_x2 - roi_x1, ((T)0.0));
T roi_height = max(roi_y2 - roi_y1, ((T)0.0));
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
const T *this_data =
input + (roi_batch_ind * channels + c) * height * width;
T *this_out = output + index;
T bin_x1 = roi_x1 + bin_size_w * pw;
T bin_y1 = roi_y1 + bin_size_h * ph;
T bin_x2 = bin_x1 + bin_size_w;
T bin_y2 = bin_y1 + bin_size_h;
T bin_size = max(T(0.0), bin_size_w * bin_size_h);
if (bin_size == 0) {
*this_out = 0;
continue;
}
T sum_out = 0;
int start_x, start_y, end_x, end_y;
start_x = floorf(bin_x1);
end_x = ceilf(bin_x2);
start_y = floorf(bin_y1);
end_y = ceilf(bin_y2);
for (int bin_x = start_x; bin_x < end_x; ++bin_x)
for (int bin_y = start_y; bin_y < end_y; ++bin_y)
sum_out += PrRoIPoolingMatCalculation(
this_data, bin_y, bin_x, bin_y + 1, bin_x + 1,
max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)),
min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height,
width);
*this_out = sum_out / bin_size;
}
}
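// Backward pass w.r.t. the input features: the upstream gradient, divided by
// the bin area, is scattered back through the same per-cell integration
// weights via PrRoIPoolingMatDistributeDiff (atomicAdd, since bins of
// different RoIs may overlap).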
template <typename T>
__global__ void prroi_pool_backward_cuda_kernel(
const int nthreads, const T *grad_output, const T *rois, T *grad_input,
const int pooled_height, const int pooled_width, const T spatial_scale,
const int channels, const int height, const int width) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
rois += n * 5;
int roi_batch_ind = rois[0];
T roi_x1 = rois[1] * spatial_scale;
T roi_y1 = rois[2] * spatial_scale;
T roi_x2 = rois[3] * spatial_scale;
T roi_y2 = rois[4] * spatial_scale;
T roi_width = max(roi_x2 - roi_x1, (T)0);
T roi_height = max(roi_y2 - roi_y1, (T)0);
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
const T *this_out_grad = grad_output + index;
T *this_data_grad =
grad_input + (roi_batch_ind * channels + c) * height * width;
T bin_x1 = roi_x1 + bin_size_w * pw;
T bin_y1 = roi_y1 + bin_size_h * ph;
T bin_x2 = bin_x1 + bin_size_w;
T bin_y2 = bin_y1 + bin_size_h;
T bin_size = max(T(0.0), bin_size_w * bin_size_h);
T sum_out = bin_size == T(0) ? T(0) : *this_out_grad / bin_size;
int start_x, start_y, end_x, end_y;
start_x = floorf(bin_x1);
end_x = ceilf(bin_x2);
start_y = floorf(bin_y1);
end_y = ceilf(bin_y2);
for (int bin_x = start_x; bin_x < end_x; ++bin_x)
for (int bin_y = start_y; bin_y < end_y; ++bin_y)
PrRoIPoolingMatDistributeDiff(
this_data_grad, sum_out, bin_y, bin_x, bin_y + 1, bin_x + 1,
max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)),
min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height,
width);
}
}
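// Backward pass w.r.t. the RoI coordinates. By the Leibniz rule, the
// derivative of the bin integral w.r.t. a moving boundary is a 1-D integral
// of the interpolated feature along that boundary (grad_x1_y etc.); combined
// with the derivative of the bin area, this yields the partial derivatives
// for (x1, y1, x2, y2), which are accumulated into grad_rois with atomicAdd.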
template <typename T>
__global__ void prroi_pool_coor_backward_cuda_kernel(
const int nthreads, const T *output, const T *grad_output, const T *input,
const T *rois, T *grad_rois, const int pooled_height,
const int pooled_width, const T spatial_scale, const int channels,
const int height, const int width) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
rois += n * 5;
int roi_batch_ind = rois[0];
T roi_x1 = rois[1] * spatial_scale;
T roi_y1 = rois[2] * spatial_scale;
T roi_x2 = rois[3] * spatial_scale;
T roi_y2 = rois[4] * spatial_scale;
T roi_width = max(roi_x2 - roi_x1, (T)0);
T roi_height = max(roi_y2 - roi_y1, (T)0);
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
const T output_grad_val = grad_output[index];
const T *this_input_data =
input + (roi_batch_ind * channels + c) * height * width;
const T output_val = output[index];
T *this_rois_grad = grad_rois + n * 5;
T bin_x1 = roi_x1 + bin_size_w * pw;
T bin_y1 = roi_y1 + bin_size_h * ph;
T bin_x2 = bin_x1 + bin_size_w;
T bin_y2 = bin_y1 + bin_size_h;
T bin_size = max(T(0.0), bin_size_w * bin_size_h);
T sum_out = bin_size == T(0) ? T(0) : output_grad_val / bin_size;
// WARNING: to be discussed
if (sum_out == 0) return;
int start_x, start_y, end_x, end_y;
start_x = floorf(bin_x1);
end_x = ceilf(bin_x2);
start_y = floorf(bin_y1);
end_y = ceilf(bin_y2);
T grad_x1_y = 0, grad_x2_y = 0, grad_x_y1 = 0, grad_x_y2 = 0;
for (int bin_y = start_y; bin_y < end_y; ++bin_y) {
grad_x1_y += PrRoIPoolingSingleCoorIntegral(
max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y,
PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x1,
height, width),
PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x1,
height, width));
grad_x2_y += PrRoIPoolingSingleCoorIntegral(
max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y,
PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x2,
height, width),
PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x2,
height, width));
}
for (int bin_x = start_x; bin_x < end_x; ++bin_x) {
grad_x_y1 += PrRoIPoolingSingleCoorIntegral(
max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x,
PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x),
height, width),
PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x + 1),
height, width));
grad_x_y2 += PrRoIPoolingSingleCoorIntegral(
max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x,
PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x),
height, width),
PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x + 1),
height, width));
}
T partial_x1 = -grad_x1_y + (bin_y2 - bin_y1) * output_val;
T partial_y1 = -grad_x_y1 + (bin_x2 - bin_x1) * output_val;
T partial_x2 = grad_x2_y - (bin_y2 - bin_y1) * output_val;
T partial_y2 = grad_x_y2 - (bin_x2 - bin_x1) * output_val;
partial_x1 = partial_x1 / bin_size * spatial_scale;
partial_x2 = partial_x2 / bin_size * spatial_scale;
partial_y1 = partial_y1 / bin_size * spatial_scale;
partial_y2 = partial_y2 / bin_size * spatial_scale;
// (index, x1, y1, x2, y2)
this_rois_grad[0] = 0;
atomicAdd(this_rois_grad + 1,
(partial_x1 * (1.0f - T(pw) / pooled_width) +
partial_x2 * (1.0f - T(pw + 1) / pooled_width)) *
output_grad_val);
atomicAdd(this_rois_grad + 2,
(partial_y1 * (1.0f - T(ph) / pooled_height) +
partial_y2 * (1.0f - T(ph + 1) / pooled_height)) *
output_grad_val);
atomicAdd(this_rois_grad + 3, (partial_x2 * T(pw + 1) / pooled_width +
partial_x1 * T(pw) / pooled_width) *
output_grad_val);
atomicAdd(this_rois_grad + 4, (partial_y2 * T(ph + 1) / pooled_height +
partial_y1 * T(ph) / pooled_height) *
output_grad_val);
}
}
#endif  // PRROI_POOL_CUDA_KERNEL_CUH
@@ -1737,3 +1737,54 @@ REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
chamfer_distance_forward_cuda);
REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA,
chamfer_distance_backward_cuda);
void PrROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale);
void PrROIPoolBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale);
void PrROIPoolCoorBackwardCUDAKernelLauncher(
Tensor output, Tensor grad_output, Tensor input, Tensor rois,
Tensor grad_rois, int pooled_height, int pooled_width, float spatial_scale);
void prroi_pool_forward_cuda(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale) {
PrROIPoolForwardCUDAKernelLauncher(input, rois, output, pooled_height,
pooled_width, spatial_scale);
}
void prroi_pool_backward_cuda(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale) {
PrROIPoolBackwardCUDAKernelLauncher(grad_output, rois, grad_input,
pooled_height, pooled_width,
spatial_scale);
}
void prroi_pool_coor_backward_cuda(Tensor output, Tensor grad_output,
Tensor input, Tensor rois, Tensor grad_rois,
int pooled_height, int pooled_width,
float spatial_scale) {
PrROIPoolCoorBackwardCUDAKernelLauncher(output, grad_output, input, rois,
grad_rois, pooled_height,
pooled_width, spatial_scale);
}
void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale);
void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale);
void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
Tensor input, Tensor rois, Tensor grad_rois,
int pooled_height, int pooled_width,
float spatial_scale);
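// The generic *_impl entry points (defined in prroi_pool.cpp) route through
// DISPATCH_DEVICE_IMPL; registering them here binds the CUDA launchers above
// as the implementation used for CUDA tensors.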
REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, CUDA, prroi_pool_forward_cuda);
REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, CUDA, prroi_pool_backward_cuda);
REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, CUDA,
prroi_pool_coor_backward_cuda);
// Copyright (c) OpenMMLab. All rights reserved
#include "prroi_pool_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void PrROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale) {
int output_size = output.numel();
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
prroi_pool_forward_cuda_kernel<float>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<float>(), rois.data_ptr<float>(),
output.data_ptr<float>(), pooled_height, pooled_width,
static_cast<float>(spatial_scale), channels, height, width);
AT_CUDA_CHECK(cudaGetLastError());
}
void PrROIPoolBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width,
float spatial_scale) {
int output_size = grad_output.numel();
int channels = grad_input.size(1);
int height = grad_input.size(2);
int width = grad_input.size(3);
at::cuda::CUDAGuard device_guard(grad_output.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
prroi_pool_backward_cuda_kernel<float>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<float>(), rois.data_ptr<float>(),
grad_input.data_ptr<float>(), pooled_height, pooled_width,
static_cast<float>(spatial_scale), channels, height, width);
AT_CUDA_CHECK(cudaGetLastError());
}
void PrROIPoolCoorBackwardCUDAKernelLauncher(Tensor output, Tensor grad_output,
Tensor input, Tensor rois,
Tensor grad_rois,
int pooled_height,
int pooled_width,
float spatial_scale) {
int output_size = grad_output.numel();
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
at::cuda::CUDAGuard device_guard(grad_output.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
prroi_pool_coor_backward_cuda_kernel<float>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, output.data_ptr<float>(), grad_output.data_ptr<float>(),
input.data_ptr<float>(), rois.data_ptr<float>(),
grad_rois.data_ptr<float>(), pooled_height, pooled_width,
static_cast<float>(spatial_scale), channels, height, width);
AT_CUDA_CHECK(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale) {
DISPATCH_DEVICE_IMPL(prroi_pool_forward_impl, input, rois, output,
pooled_height, pooled_width, spatial_scale);
}
void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale) {
DISPATCH_DEVICE_IMPL(prroi_pool_backward_impl, grad_output, rois, grad_input,
pooled_height, pooled_width, spatial_scale);
}
void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
Tensor input, Tensor rois, Tensor grad_rois,
int pooled_height, int pooled_width,
float spatial_scale) {
DISPATCH_DEVICE_IMPL(prroi_pool_coor_backward_impl, output, grad_output,
input, rois, grad_rois, pooled_height, pooled_width,
spatial_scale);
}
void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale) {
prroi_pool_forward_impl(input, rois, output, pooled_height, pooled_width,
spatial_scale);
}
void prroi_pool_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
int pooled_height, int pooled_width,
float spatial_scale) {
prroi_pool_backward_impl(grad_output, rois, grad_input, pooled_height,
pooled_width, spatial_scale);
}
void prroi_pool_coor_backward(Tensor output, Tensor grad_output, Tensor input,
Tensor rois, Tensor grad_rois, int pooled_height,
int pooled_width, float spatial_scale) {
prroi_pool_coor_backward_impl(output, grad_output, input, rois, grad_rois,
pooled_height, pooled_width, spatial_scale);
}
@@ -240,6 +240,18 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
Tensor idx_tensor, int b, int n, int m,
float min_radius, float max_radius, int nsample);
void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale);
void prroi_pool_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
int pooled_height, int pooled_width,
float spatial_scale);
void prroi_pool_coor_backward(Tensor output, Tensor grad_output, Tensor input,
Tensor rois, Tensor grad_rois, int pooled_height,
int pooled_width, float spatial_scale);
template <unsigned NDim>
std::vector<torch::Tensor> get_indice_pairs_forward(
torch::Tensor indices, int64_t batchSize,
@@ -828,4 +840,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"chamfer_distance_backward", py::arg("xyz1"), py::arg("xyz2"),
py::arg("gradxyz1"), py::arg("gradxyz2"), py::arg("graddist1"),
py::arg("graddist2"), py::arg("idx1"), py::arg("idx2"));
m.def("prroi_pool_forward", &prroi_pool_forward, "prroi_pool forward",
py::arg("input"), py::arg("rois"), py::arg("output"),
py::arg("pooled_height"), py::arg("pooled_width"),
py::arg("spatial_scale"));
m.def("prroi_pool_backward", &prroi_pool_backward, "prroi_pool_backward",
py::arg("grad_output"), py::arg("rois"), py::arg("grad_input"),
py::arg("pooled_height"), py::arg("pooled_width"),
py::arg("spatial_scale"));
m.def("prroi_pool_coor_backward", &prroi_pool_coor_backward,
"prroi_pool_coor_backward", py::arg("output"), py::arg("grad_output"),
py::arg("input"), py::arg("rois"), py::arg("grad_rois"),
py::arg("pooled_height"), py::arg("pooled_width"),
py::arg("spatial_scale"));
}
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext',
['prroi_pool_forward', 'prroi_pool_backward', 'prroi_pool_coor_backward'])
class PrRoIPoolFunction(Function):
@staticmethod
def symbolic(g, features, rois, output_size, spatial_scale):
return g.op(
'mmcv::PrRoIPool',
features,
rois,
pooled_height_i=int(output_size[0]),
pooled_width_i=int(output_size[1]),
spatial_scale_f=float(spatial_scale))
@staticmethod
def forward(ctx,
features: torch.Tensor,
rois: torch.Tensor,
output_size: Tuple,
spatial_scale: float = 1.0) -> torch.Tensor:
if 'FloatTensor' not in features.type(
) or 'FloatTensor' not in rois.type():
raise ValueError(
'Precise RoI Pooling only takes float input, got '
f'{features.type()} for features and {rois.type()} for rois.')
pooled_height = int(output_size[0])
pooled_width = int(output_size[1])
spatial_scale = float(spatial_scale)
features = features.contiguous()
rois = rois.contiguous()
output_shape = (rois.size(0), features.size(1), pooled_height,
pooled_width)
output = features.new_zeros(output_shape)
params = (pooled_height, pooled_width, spatial_scale)
ext_module.prroi_pool_forward(features, rois, output, *params)
ctx.params = params
# everything here is contiguous.
ctx.save_for_backward(features, rois, output)
return output
@staticmethod
@once_differentiable
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
features, rois, output = ctx.saved_tensors
grad_input = grad_output.new_zeros(*features.shape)
grad_coor = grad_output.new_zeros(*rois.shape)
if features.requires_grad:
grad_output = grad_output.contiguous()
ext_module.prroi_pool_backward(grad_output, rois, grad_input,
*ctx.params)
if rois.requires_grad:
grad_output = grad_output.contiguous()
ext_module.prroi_pool_coor_backward(output, grad_output, features,
rois, grad_coor, *ctx.params)
return grad_input, grad_coor, None, None, None
prroi_pool = PrRoIPoolFunction.apply
class PrRoIPool(nn.Module):
"""The operation of precision RoI pooling. The implementation of PrRoIPool
is modified from https://github.com/vacancy/PreciseRoIPooling/
Precise RoI Pooling (PrRoIPool) is an integration-based (bilinear
interpolation) average pooling method for RoI Pooling. It avoids any
quantization and has a continuous gradient on bounding box coordinates.
It is:
1. different from the original RoI Pooling proposed in Fast R-CNN. PrRoI
Pooling uses average pooling instead of max pooling for each bin and has a
continuous gradient on bounding box coordinates. That is, one can take the
derivatives of some loss function w.r.t the coordinates of each RoI and
optimize the RoI coordinates.
2. different from the RoI Align proposed in Mask R-CNN. PrRoI Pooling uses
a full integration-based average pooling instead of sampling a constant
number of points. This makes the gradient w.r.t. the coordinates
continuous.
Args:
output_size (Union[int, tuple]): The pooled output size (h, w).
spatial_scale (float, optional): scale the input boxes by this number.
Defaults to 1.0.
"""
def __init__(self,
output_size: Union[int, tuple],
spatial_scale: float = 1.0):
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
def forward(self, features: torch.Tensor,
rois: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
features (torch.Tensor): The feature map.
rois (torch.Tensor): The RoI bboxes in [tl_x, tl_y, br_x, br_y]
format.
Returns:
torch.Tensor: The pooled results.
"""
return prroi_pool(features, rois, self.output_size, self.spatial_scale)
def __repr__(self):
s = self.__class__.__name__
s += f'(output_size={self.output_size}, '
s += f'spatial_scale={self.spatial_scale})'
return s
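Because the pooled value is a continuous function of the box, gradients also flow to the RoI coordinates; a brief sketch with illustrative values, under the same CUDA/float assumptions as above:

import torch
from mmcv.ops import prroi_pool

feats = torch.rand(1, 4, 8, 8, device='cuda')
rois = torch.tensor([[0., 1., 1., 6., 6.]], device='cuda',
                    requires_grad=True)
out = prroi_pool(feats, rois, (2, 2), 1.0)
out.sum().backward()
# rois.grad holds d(sum of pooled values)/d(ind, x1, y1, x2, y2);
# the batch-index slot remains 0.
print(rois.grad)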
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE
_USING_PARROTS = True
try:
from parrots.autograd import gradcheck
except ImportError:
from torch.autograd import gradcheck
_USING_PARROTS = False
inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
([[[[1., 2.], [3., 4.]], [[4., 3.], [2.,
1.]]]], [[0., 0., 0., 1., 1.]]),
([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
[11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]
outputs = [
([[[[1.75, 2.25], [2.75, 3.25]]]], [[[[1., 1.],
[1., 1.]]]], [[0., 2., 4., 2., 4.]]),
([[[[1.75, 2.25], [2.75, 3.25]],
[[3.25, 2.75], [2.25, 1.75]]]], [[[[1., 1.], [1., 1.]],
[[1., 1.],
[1., 1.]]]], [[0., 0., 0., 0., 0.]]),
([[[[3.75, 6.91666651],
[10.08333302,
13.25]]]], [[[[0.11111111, 0.22222224, 0.22222222, 0.11111111],
[0.22222224, 0.44444448, 0.44444448, 0.22222224],
[0.22222224, 0.44444448, 0.44444448, 0.22222224],
[0.11111111, 0.22222224, 0.22222224, 0.11111111]]]],
[[0.0, 3.33333302, 6.66666603, 3.33333349, 6.66666698]])
]
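# Sanity check for the first case by hand: with corner values [[1., 2.],
# [3., 4.]] the bilinear interpolant on the unit cell is f(y, x) = 1 + x + 2*y,
# so the first 0.5 x 0.5 bin of roi (0, 0, 0, 1, 1) averages to
# 1 + mean(x) + 2 * mean(y) = 1 + 0.25 + 2 * 0.25 = 1.75, matching outputs[0].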
class TestPrRoiPool:
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
])
def test_roipool_gradcheck(self, device):
from mmcv.ops import PrRoIPool
pool_h = 2
pool_w = 2
spatial_scale = 1.0
for case in inputs:
np_input = np.array(case[0], dtype=np.float32)
np_rois = np.array(case[1], dtype=np.float32)
x = torch.tensor(np_input, device=device, requires_grad=True)
rois = torch.tensor(np_rois, device=device)
froipool = PrRoIPool((pool_h, pool_w), spatial_scale)
if _USING_PARROTS:
pass
# gradcheck(froipool, (x, rois), no_grads=[rois])
else:
gradcheck(froipool, (x, rois), eps=1e-2, atol=1e-2)
def _test_roipool_allclose(self, device, dtype=torch.float):
from mmcv.ops import prroi_pool
pool_h = 2
pool_w = 2
spatial_scale = 1.0
for case, output in zip(inputs, outputs):
np_input = np.array(case[0], dtype=np.float32)
np_rois = np.array(case[1], dtype=np.float32)
np_output = np.array(output[0], dtype=np.float32)
np_input_grad = np.array(output[1], dtype=np.float32)
np_rois_grad = np.array(output[2], dtype=np.float32)
x = torch.tensor(
np_input, dtype=dtype, device=device, requires_grad=True)
rois = torch.tensor(
np_rois, dtype=dtype, device=device, requires_grad=True)
output = prroi_pool(x, rois, (pool_h, pool_w), spatial_scale)
output.backward(torch.ones_like(output))
assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
assert np.allclose(x.grad.data.cpu().numpy(), np_input_grad, 1e-3)
assert np.allclose(rois.grad.data.cpu().numpy(), np_rois_grad,
1e-3)
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
])
def test_roipool_allclose_float(self, device):
self._test_roipool_allclose(device, dtype=torch.float)