Unverified Commit c1da4a58 authored by Matheus Centa's avatar Matheus Centa Committed by GitHub
Browse files

Check target boxes input on generalized_rcnn.py (#2207)

* Check target boxes input on generalized_rcnn.py

* Fix target box validation in generalized_rcnn.py

* Add tests for input validation of detection models
parent 7aea80c9
...@@ -155,6 +155,24 @@ class ModelTester(TestCase): ...@@ -155,6 +155,24 @@ class ModelTester(TestCase):
# self.check_script(model, name) # self.check_script(model, name)
self.checkModule(model, name, ([x],)) self.checkModule(model, name, ([x],))
def _test_detection_model_validation(self, name):
set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
input_shape = (1, 3, 300, 300)
x = [torch.rand(input_shape)]
# validate that targets are present in training
self.assertRaises(ValueError, model, x)
# validate type
targets = [{'boxes': 0.}]
self.assertRaises(ValueError, model, x, targets=targets)
# validate boxes shape
for boxes in (torch.rand((4,)), torch.rand((1, 5))):
targets = [{'boxes': boxes}]
self.assertRaises(ValueError, model, x, targets=targets)
def _test_video_model(self, name): def _test_video_model(self, name):
# the default input shape is # the default input shape is
# bs * num_channels * clip_len * h *w # bs * num_channels * clip_len * h *w
...@@ -303,6 +321,11 @@ for model_name in get_available_detection_models(): ...@@ -303,6 +321,11 @@ for model_name in get_available_detection_models():
setattr(ModelTester, "test_" + model_name, do_test) setattr(ModelTester, "test_" + model_name, do_test)
def do_validation_test(self, model_name=model_name):
self._test_detection_model_validation(model_name)
setattr(ModelTester, "test_" + model_name + "_validation", do_validation_test)
for model_name in get_available_video_models(): for model_name in get_available_video_models():
......
...@@ -57,6 +57,19 @@ class GeneralizedRCNN(nn.Module): ...@@ -57,6 +57,19 @@ class GeneralizedRCNN(nn.Module):
""" """
if self.training and targets is None: if self.training and targets is None:
raise ValueError("In training mode, targets should be passed") raise ValueError("In training mode, targets should be passed")
if self.training:
assert targets is not None
for target in targets:
boxes = target["boxes"]
if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError("Expected target boxes to be a tensor"
"of shape [N, 4], got {:}.".format(
boxes.shape))
else:
raise ValueError("Expected target boxes to be of type "
"Tensor, got {:}.".format(type(boxes)))
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
for img in images: for img in images:
val = img.shape[-2:] val = img.shape[-2:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment