Unverified Commit c8064cdb authored by NVS Abhilash's avatar NVS Abhilash Committed by GitHub
Browse files

check for degenerate boxes (fixes #2240) (#2258)

parent 5ba57eae
...@@ -158,7 +158,7 @@ class ModelTester(TestCase): ...@@ -158,7 +158,7 @@ class ModelTester(TestCase):
def _test_detection_model_validation(self, name): def _test_detection_model_validation(self, name):
set_rng_seed(0) set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False) model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
input_shape = (1, 3, 300, 300) input_shape = (3, 300, 300)
x = [torch.rand(input_shape)] x = [torch.rand(input_shape)]
# validate that targets are present in training # validate that targets are present in training
...@@ -173,6 +173,11 @@ class ModelTester(TestCase): ...@@ -173,6 +173,11 @@ class ModelTester(TestCase):
targets = [{'boxes': boxes}] targets = [{'boxes': boxes}]
self.assertRaises(ValueError, model, x, targets=targets) self.assertRaises(ValueError, model, x, targets=targets)
# validate that no degenerate boxes are present
boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])
targets = [{'boxes': boxes}]
self.assertRaises(ValueError, model, x, targets=targets)
def _test_video_model(self, name): def _test_video_model(self, name):
# the default input shape is # the default input shape is
# bs * num_channels * clip_len * h *w # bs * num_channels * clip_len * h *w
......
...@@ -77,6 +77,21 @@ class GeneralizedRCNN(nn.Module): ...@@ -77,6 +77,21 @@ class GeneralizedRCNN(nn.Module):
original_image_sizes.append((val[0], val[1])) original_image_sizes.append((val[0], val[1]))
images, targets = self.transform(images, targets) images, targets = self.transform(images, targets)
# Check for degenerate boxes
# TODO: Move this to a function
if targets is not None:
for target_idx, target in enumerate(targets):
boxes = target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
# print the first degenrate box
bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width."
" Found invaid box {} for target at index {}."
.format(degen_bb, target_idx))
features = self.backbone(images.tensors) features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor): if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)]) features = OrderedDict([('0', features)])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment