Unverified Commit 505cd695 authored by Guanheng George Zhang's avatar Guanheng George Zhang Committed by GitHub
Browse files

Check boxes shape in RoIPool / Align (#1968)



* add checkout/assert in roi_pool

* add checkout/assert in roi_align

* move check_roi_boxes_shape func to ops/_utils.py

* add tests

* fix CI

* fix CI
Co-authored-by: default avatarGuanheng Zhang <zhangguanheng@devfair0197.h2.fair>
parent ff81e2b1
......@@ -91,6 +91,22 @@ class RoIOpTester(OpTester):
self.assertTrue(gradcheck(func, (x,)))
self.assertTrue(gradcheck(script_func, (x,)))
def test_boxes_shape(self):
self._test_boxes_shape()
def _helper_boxes_shape(self, func):
# test boxes as Tensor[N, 5]
with self.assertRaises(AssertionError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
func(a, boxes, output_size=(2, 2))
# test boxes as List[Tensor[N, 4]]
with self.assertRaises(AssertionError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
ops.roi_pool(a, [boxes], output_size=(2, 2))
def fn(*args, **kwargs):
pass
......@@ -139,6 +155,9 @@ class RoIPoolTester(RoIOpTester, unittest.TestCase):
y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
return y
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_pool)
class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
......@@ -183,6 +202,9 @@ class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
y[roi_idx, c_out, i, j] = t / area
return y
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.ps_roi_pool)
def bilinear_interpolate(data, y, x, snap_border=False):
height, width = data.shape
......@@ -266,6 +288,9 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
out_data[r, channel, i, j] = val
return out_data
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_align)
class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
......@@ -317,6 +342,9 @@ class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
out_data[r, c_out, i, j] = val
return out_data
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.ps_roi_align)
class NMSTester(unittest.TestCase):
def reference_nms(self, boxes, scores, iou_threshold):
......
......@@ -24,3 +24,15 @@ def convert_boxes_to_roi_format(boxes):
ids = _cat(temp, dim=0)
rois = torch.cat([ids, concat_boxes], dim=1)
return rois
def check_roi_boxes_shape(boxes):
if isinstance(boxes, list):
for _tensor in boxes:
assert _tensor.size(1) == 4, \
'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]'
elif isinstance(boxes, torch.Tensor):
assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]'
else:
assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]'
return
......@@ -4,7 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List
from ._utils import convert_boxes_to_roi_format
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
......@@ -33,6 +33,7 @@ def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
......
......@@ -4,7 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List
from ._utils import convert_boxes_to_roi_format
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
......@@ -28,6 +28,7 @@ def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
......
......@@ -4,7 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
......@@ -35,6 +35,7 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, a
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
......
......@@ -4,7 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def roi_pool(input, boxes, output_size, spatial_scale=1.0):
......@@ -27,6 +27,7 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0):
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment