Unverified Commit e1e975f9 authored by AhnDW's avatar AhnDW Committed by GitHub
Browse files

`aligned` flag in ROIAlign (#1908)

* Aligned flag in the interfaces

* Aligned flag in the impl, and remove unused comments

* Handling empty bin in forward

* Remove raise error in roi_width

* Aligned flag in the Testcodes
parent b2e95657
......@@ -217,9 +217,9 @@ def bilinear_interpolate(data, y, x, snap_border=False):
class RoIAlignTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
return ops.RoIAlign((pool_h, pool_w), spatial_scale=spatial_scale,
sampling_ratio=sampling_ratio)(x, rois)
sampling_ratio=sampling_ratio, aligned=aligned)(x, rois)
def get_script_fn(self, rois, pool_size):
@torch.jit.script
......@@ -228,16 +228,18 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
return ops.roi_align(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False,
device=None, dtype=torch.float64):
if device is None:
device = torch.device("cpu")
n_channels = in_data.size(1)
out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)
offset = 0.5 if aligned else 0.
for r, roi in enumerate(rois):
batch_idx = int(roi[0])
j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale for x in roi[1:])
j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - offset for x in roi[1:])
roi_h = i_end - i_begin
roi_w = j_end - j_begin
......
......@@ -14,7 +14,8 @@ at::Tensor ROIAlign_forward(
// scaled to this.
const int64_t pooled_height, // The height of the pooled feature map.
const int64_t pooled_width, // The width of the pooled feature
const int64_t sampling_ratio) // The number of points to sample in each bin
const int64_t sampling_ratio, // The number of points to sample in each bin
const bool aligned) // The flag for pixel shift
// along each axis.
{
if (input.type().is_cuda()) {
......@@ -25,13 +26,14 @@ at::Tensor ROIAlign_forward(
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
sampling_ratio,
aligned);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return ROIAlign_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned);
}
at::Tensor ROIAlign_backward(
......@@ -44,7 +46,8 @@ at::Tensor ROIAlign_backward(
const int channels,
const int height,
const int width,
const int sampling_ratio) {
const int sampling_ratio,
const bool aligned) {
if (grad.type().is_cuda()) {
#ifdef WITH_CUDA
return ROIAlign_backward_cuda(
......@@ -57,7 +60,8 @@ at::Tensor ROIAlign_backward(
channels,
height,
width,
sampling_ratio);
sampling_ratio,
aligned);
#else
AT_ERROR("Not compiled with GPU support");
#endif
......@@ -72,7 +76,8 @@ at::Tensor ROIAlign_backward(
channels,
height,
width,
sampling_ratio);
sampling_ratio,
aligned);
}
using namespace at;
......@@ -90,11 +95,13 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
const int64_t sampling_ratio,
const bool aligned) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["aligned"] = aligned;
ctx->saved_data["input_shape"] = input.sizes();
ctx->save_for_backward({rois});
auto result = ROIAlign_forward(
......@@ -103,7 +110,8 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
sampling_ratio,
aligned);
return {result};
}
......@@ -124,9 +132,10 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
input_shape[1],
input_shape[2],
input_shape[3],
ctx->saved_data["sampling_ratio"].toInt());
ctx->saved_data["sampling_ratio"].toInt(),
ctx->saved_data["aligned"].toBool());
return {
grad_in, Variable(), Variable(), Variable(), Variable(), Variable()};
grad_in, Variable(), Variable(), Variable(), Variable(), Variable(), Variable()};
}
};
......@@ -136,12 +145,14 @@ Tensor roi_align(
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
const int64_t sampling_ratio,
const bool aligned) {
return ROIAlignFunction::apply(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio)[0];
sampling_ratio,
aligned)[0];
}
......@@ -121,6 +121,7 @@ void ROIAlignForward(
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const bool aligned,
const T* rois,
T* output) {
int n_rois = nthreads / channels / pooled_width / pooled_height;
......@@ -134,18 +135,16 @@ void ROIAlignForward(
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
// T roi_start_w = round(offset_rois[0] * spatial_scale);
// T roi_start_h = round(offset_rois[1] * spatial_scale);
// T roi_end_w = round(offset_rois[2] * spatial_scale);
// T roi_end_h = round(offset_rois[3] * spatial_scale);
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
// Force malformed ROIs to be 1x1
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
......@@ -157,7 +156,8 @@ void ROIAlignForward(
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
// When the grid is empty, output zeros.
const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
// we want to precalculate indices and weights shared by all channels,
// this is the key point of optimization
......@@ -285,6 +285,7 @@ void ROIAlignBackward(
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const bool aligned,
T* grad_input,
const T* rois,
const int n_stride,
......@@ -302,14 +303,16 @@ void ROIAlignBackward(
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
// Force malformed ROIs to be 1x1
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
......@@ -381,7 +384,8 @@ at::Tensor ROIAlign_forward_cpu(
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
const int sampling_ratio,
const bool aligned) {
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
......@@ -414,6 +418,7 @@ at::Tensor ROIAlign_forward_cpu(
pooled_height,
pooled_width,
sampling_ratio,
aligned,
rois.contiguous().data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
});
......@@ -430,7 +435,8 @@ at::Tensor ROIAlign_backward_cpu(
const int channels,
const int height,
const int width,
const int sampling_ratio) {
const int sampling_ratio,
const bool aligned) {
AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
......@@ -464,6 +470,7 @@ at::Tensor ROIAlign_backward_cpu(
pooled_height,
pooled_width,
sampling_ratio,
aligned,
grad_input.data_ptr<scalar_t>(),
rois.contiguous().data_ptr<scalar_t>(),
n_stride,
......
......@@ -26,7 +26,8 @@ at::Tensor ROIAlign_forward_cpu(
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio);
const int sampling_ratio,
const bool aligned);
at::Tensor ROIAlign_backward_cpu(
const at::Tensor& grad,
......@@ -38,7 +39,8 @@ at::Tensor ROIAlign_backward_cpu(
const int channels,
const int height,
const int width,
const int sampling_ratio);
const int sampling_ratio,
const bool aligned);
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
const at::Tensor& input,
......
......@@ -71,6 +71,7 @@ __global__ void RoIAlignForward(
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const bool aligned,
const T* rois,
T* output) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
......@@ -84,14 +85,16 @@ __global__ void RoIAlignForward(
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
// Force malformed ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
......@@ -106,7 +109,8 @@ __global__ void RoIAlignForward(
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
// When the grid is empty, output zeros.
const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
......@@ -201,6 +205,7 @@ __global__ void RoIAlignBackward(
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const bool aligned,
T* grad_input,
const T* rois,
const int n_stride,
......@@ -218,14 +223,16 @@ __global__ void RoIAlignBackward(
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
// Force malformed ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
......@@ -303,7 +310,8 @@ at::Tensor ROIAlign_forward_cuda(
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
const int sampling_ratio,
const bool aligned) {
AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
......@@ -348,6 +356,7 @@ at::Tensor ROIAlign_forward_cuda(
pooled_height,
pooled_width,
sampling_ratio,
aligned,
rois.contiguous().data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
});
......@@ -365,7 +374,8 @@ at::Tensor ROIAlign_backward_cuda(
const int channels,
const int height,
const int width,
const int sampling_ratio) {
const int sampling_ratio,
const bool aligned) {
AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
......@@ -410,6 +420,7 @@ at::Tensor ROIAlign_backward_cuda(
pooled_height,
pooled_width,
sampling_ratio,
aligned,
grad_input.data_ptr<scalar_t>(),
rois.contiguous().data_ptr<scalar_t>(),
n_stride,
......
......@@ -8,7 +8,8 @@ at::Tensor ROIAlign_forward_cuda(
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio);
const int sampling_ratio,
const bool aligned);
at::Tensor ROIAlign_backward_cuda(
const at::Tensor& grad,
......@@ -20,7 +21,8 @@ at::Tensor ROIAlign_backward_cuda(
const int channels,
const int height,
const int width,
const int sampling_ratio);
const int sampling_ratio,
const bool aligned);
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
const at::Tensor& input,
......
......@@ -42,7 +42,7 @@ int64_t _cuda_version() {
static auto registry =
torch::RegisterOperators()
.op("torchvision::nms", &nms)
.op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
.op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor",
&roi_align)
.op("torchvision::roi_pool", &roi_pool)
.op("torchvision::_new_empty_tensor_op", &new_empty_tensor)
......
......@@ -7,8 +7,8 @@ from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
# type: (Tensor, Tensor, BroadcastingList2[int], float, int) -> Tensor
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
# type: (Tensor, Tensor, BroadcastingList2[int], float, int, bool) -> Tensor
"""
Performs Region of Interest (RoI) Align operator described in Mask R-CNN
......@@ -28,6 +28,9 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
then exactly sampling_ratio x sampling_ratio grid points are used. If
<= 0, then an adaptive number of grid points are used (computed as
ceil(roi_width / pooled_w), and likewise for height). Default: -1
aligned (bool): If False, use the legacy implementation.
If True, shift the box coordinates by -0.5 pixel so the sampling grid aligns
more closely with the two neighboring pixel indices. This is the version used in Detectron2.
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
......@@ -38,26 +41,28 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
rois = convert_boxes_to_roi_format(rois)
return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
output_size[0], output_size[1],
sampling_ratio)
sampling_ratio, aligned)
class RoIAlign(nn.Module):
    """
    Module wrapper for :func:`roi_align`.

    Arguments:
        output_size (int or Tuple[int, int]): size of the output after pooling
        spatial_scale (float): scale factor mapping input box coordinates to
            the coordinate system of the feature map
        sampling_ratio (int): number of sampling points per output bin; if <= 0,
            an adaptive number of points is used
        aligned (bool): if True, shift the box coordinates by -0.5 pixel so the
            sampling grid aligns more closely with neighboring pixel indices
            (Detectron2-style); if False, use the legacy behavior. Default: False
    """

    def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=False):
        super(RoIAlign, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

    def forward(self, input, rois):
        # Delegate to the functional op with this module's stored configuration.
        return roi_align(input, rois, self.output_size, self.spatial_scale,
                         self.sampling_ratio, self.aligned)

    def __repr__(self):
        tmpstr = self.__class__.__name__ + '('
        tmpstr += 'output_size=' + str(self.output_size)
        tmpstr += ', spatial_scale=' + str(self.spatial_scale)
        tmpstr += ', sampling_ratio=' + str(self.sampling_ratio)
        tmpstr += ', aligned=' + str(self.aligned)
        tmpstr += ')'
        return tmpstr
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment