Unverified Commit a09d129c authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

fix: Restored support of tuple of Tensors for region pooling ops (#2199)

* feat: Restored support of tuple of Tensors for roi_align & roi_pool

* test: Added unittest for Tensor sequence support by region pooling

* test: Fixed typo in unittest

* test: Fixed data type

* test: Fixed roi pooling tensor unittest

* test: Fixed box format conversion
parent caf15cd0
...@@ -548,5 +548,30 @@ class FrozenBNTester(unittest.TestCase): ...@@ -548,5 +548,30 @@ class FrozenBNTester(unittest.TestCase):
self.assertEqual(t.__repr__(), expected_string) self.assertEqual(t.__repr__(), expected_string)
class BoxConversionTester(unittest.TestCase):
@staticmethod
def _get_box_sequences():
# Define here the argument type of `boxes` supported by region pooling operations
box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
box_list = [torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
torch.tensor([[0, 0, 100, 100]], dtype=torch.float)]
box_tuple = tuple(box_list)
return box_tensor, box_list, box_tuple
def test_check_roi_boxes_shape(self):
# Ensure common sequences of tensors are supported
for box_sequence in self._get_box_sequences():
self.assertIsNone(ops._utils.check_roi_boxes_shape(box_sequence))
def test_convert_boxes_to_roi_format(self):
# Ensure common sequences of tensors yield the same result
ref_tensor = None
for box_sequence in self._get_box_sequences():
if ref_tensor is None:
ref_tensor = box_sequence
else:
self.assertTrue(torch.equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence)))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -27,7 +27,7 @@ def convert_boxes_to_roi_format(boxes): ...@@ -27,7 +27,7 @@ def convert_boxes_to_roi_format(boxes):
def check_roi_boxes_shape(boxes): def check_roi_boxes_shape(boxes):
if isinstance(boxes, list): if isinstance(boxes, (list, tuple)):
for _tensor in boxes: for _tensor in boxes:
assert _tensor.size(1) == 4, \ assert _tensor.size(1) == 4, \
'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]' 'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment