Unverified Commit 505cd695 authored by Guanheng George Zhang's avatar Guanheng George Zhang Committed by GitHub
Browse files

Check boxes shape in RoIPool / Align (#1968)



* add checkout/assert in roi_pool

* add checkout/assert in roi_align

* move check_roi_boxes_shape func to ops/_utils.py

* add tests

* fix CI

* fix CI
Co-authored-by: default avatarGuanheng Zhang <zhangguanheng@devfair0197.h2.fair>
parent ff81e2b1
...@@ -91,6 +91,22 @@ class RoIOpTester(OpTester): ...@@ -91,6 +91,22 @@ class RoIOpTester(OpTester):
self.assertTrue(gradcheck(func, (x,))) self.assertTrue(gradcheck(func, (x,)))
self.assertTrue(gradcheck(script_func, (x,))) self.assertTrue(gradcheck(script_func, (x,)))
def test_boxes_shape(self):
self._test_boxes_shape()
def _helper_boxes_shape(self, func):
# test boxes as Tensor[N, 5]
with self.assertRaises(AssertionError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
func(a, boxes, output_size=(2, 2))
# test boxes as List[Tensor[N, 4]]
with self.assertRaises(AssertionError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
ops.roi_pool(a, [boxes], output_size=(2, 2))
def fn(*args, **kwargs): def fn(*args, **kwargs):
pass pass
...@@ -139,6 +155,9 @@ class RoIPoolTester(RoIOpTester, unittest.TestCase): ...@@ -139,6 +155,9 @@ class RoIPoolTester(RoIOpTester, unittest.TestCase):
y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0] y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
return y return y
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_pool)
class PSRoIPoolTester(RoIOpTester, unittest.TestCase): class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
...@@ -183,6 +202,9 @@ class PSRoIPoolTester(RoIOpTester, unittest.TestCase): ...@@ -183,6 +202,9 @@ class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
y[roi_idx, c_out, i, j] = t / area y[roi_idx, c_out, i, j] = t / area
return y return y
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.ps_roi_pool)
def bilinear_interpolate(data, y, x, snap_border=False): def bilinear_interpolate(data, y, x, snap_border=False):
height, width = data.shape height, width = data.shape
...@@ -266,6 +288,9 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase): ...@@ -266,6 +288,9 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
out_data[r, channel, i, j] = val out_data[r, channel, i, j] = val
return out_data return out_data
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_align)
class PSRoIAlignTester(RoIOpTester, unittest.TestCase): class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
...@@ -317,6 +342,9 @@ class PSRoIAlignTester(RoIOpTester, unittest.TestCase): ...@@ -317,6 +342,9 @@ class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
out_data[r, c_out, i, j] = val out_data[r, c_out, i, j] = val
return out_data return out_data
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.ps_roi_align)
class NMSTester(unittest.TestCase): class NMSTester(unittest.TestCase):
def reference_nms(self, boxes, scores, iou_threshold): def reference_nms(self, boxes, scores, iou_threshold):
......
...@@ -24,3 +24,15 @@ def convert_boxes_to_roi_format(boxes): ...@@ -24,3 +24,15 @@ def convert_boxes_to_roi_format(boxes):
ids = _cat(temp, dim=0) ids = _cat(temp, dim=0)
rois = torch.cat([ids, concat_boxes], dim=1) rois = torch.cat([ids, concat_boxes], dim=1)
return rois return rois
def check_roi_boxes_shape(boxes):
if isinstance(boxes, list):
for _tensor in boxes:
assert _tensor.size(1) == 4, \
'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]'
elif isinstance(boxes, torch.Tensor):
assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]'
else:
assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]'
return
...@@ -4,7 +4,7 @@ from torch import nn, Tensor ...@@ -4,7 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.jit.annotations import List from torch.jit.annotations import List
from ._utils import convert_boxes_to_roi_format from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1): def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
...@@ -33,6 +33,7 @@ def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1 ...@@ -33,6 +33,7 @@ def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1
Returns: Returns:
output (Tensor[K, C, output_size[0], output_size[1]]) output (Tensor[K, C, output_size[0], output_size[1]])
""" """
check_roi_boxes_shape(boxes)
rois = boxes rois = boxes
output_size = _pair(output_size) output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor): if not isinstance(rois, torch.Tensor):
......
...@@ -4,7 +4,7 @@ from torch import nn, Tensor ...@@ -4,7 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.jit.annotations import List from torch.jit.annotations import List
from ._utils import convert_boxes_to_roi_format from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0): def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
...@@ -28,6 +28,7 @@ def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0): ...@@ -28,6 +28,7 @@ def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
Returns: Returns:
output (Tensor[K, C, output_size[0], output_size[1]]) output (Tensor[K, C, output_size[0], output_size[1]])
""" """
check_roi_boxes_shape(boxes)
rois = boxes rois = boxes
output_size = _pair(output_size) output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor): if not isinstance(rois, torch.Tensor):
......
...@@ -4,7 +4,7 @@ from torch import nn, Tensor ...@@ -4,7 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2 from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False): def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
...@@ -35,6 +35,7 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, a ...@@ -35,6 +35,7 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, a
Returns: Returns:
output (Tensor[K, C, output_size[0], output_size[1]]) output (Tensor[K, C, output_size[0], output_size[1]])
""" """
check_roi_boxes_shape(boxes)
rois = boxes rois = boxes
output_size = _pair(output_size) output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor): if not isinstance(rois, torch.Tensor):
......
...@@ -4,7 +4,7 @@ from torch import nn, Tensor ...@@ -4,7 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2 from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def roi_pool(input, boxes, output_size, spatial_scale=1.0): def roi_pool(input, boxes, output_size, spatial_scale=1.0):
...@@ -27,6 +27,7 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0): ...@@ -27,6 +27,7 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0):
Returns: Returns:
output (Tensor[K, C, output_size[0], output_size[1]]) output (Tensor[K, C, output_size[0], output_size[1]])
""" """
check_roi_boxes_shape(boxes)
rois = boxes rois = boxes
output_size = _pair(output_size) output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor): if not isinstance(rois, torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment