Unverified Commit 51d694e1 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add tests for negative samples for Mask R-CNN and Keypoint R-CNN (#2069)

* Add tests for negative samples for Mask R-CNN and Keypoint R-CNN

* Fix lint
parent 57c789f8
...@@ -11,16 +11,27 @@ import unittest ...@@ -11,16 +11,27 @@ import unittest
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
def test_targets_to_anchors(self): def _make_empty_sample(self, add_masks=False, add_keypoints=False):
images = [torch.rand((3, 100, 100), dtype=torch.float32)]
boxes = torch.zeros((0, 4), dtype=torch.float32) boxes = torch.zeros((0, 4), dtype=torch.float32)
negative_target = {"boxes": boxes, negative_target = {"boxes": boxes,
"labels": torch.zeros((1, 1), dtype=torch.int64), "labels": torch.zeros(0, dtype=torch.int64),
"image_id": 4, "image_id": 4,
"area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
"iscrowd": torch.zeros((0,), dtype=torch.int64)} "iscrowd": torch.zeros((0,), dtype=torch.int64)}
anchors = [torch.randint(-50, 50, (3, 4), dtype=torch.float32)] if add_masks:
negative_target["masks"] = torch.zeros(0, 100, 100, dtype=torch.uint8)
if add_keypoints:
negative_target["keypoints"] = torch.zeros(17, 0, 3, dtype=torch.float32)
targets = [negative_target] targets = [negative_target]
return images, targets
def test_targets_to_anchors(self):
_, targets = self._make_empty_sample()
anchors = [torch.randint(-50, 50, (3, 4), dtype=torch.float32)]
anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
...@@ -85,25 +96,37 @@ class Tester(unittest.TestCase): ...@@ -85,25 +96,37 @@ class Tester(unittest.TestCase):
self.assertEqual(labels[0].shape, torch.Size([proposals[0].shape[0]])) self.assertEqual(labels[0].shape, torch.Size([proposals[0].shape[0]]))
self.assertEqual(labels[0].dtype, torch.int64) self.assertEqual(labels[0].dtype, torch.int64)
def test_forward_negative_sample(self): def test_forward_negative_sample_frcnn(self):
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
in_features = model.roi_heads.box_predictor.cls_score.in_features num_classes=2, min_size=100, max_size=100)
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
images = [torch.rand((3, 100, 100), dtype=torch.float32)] images, targets = self._make_empty_sample()
boxes = torch.zeros((0, 4), dtype=torch.float32) loss_dict = model(images, targets)
negative_target = {"boxes": boxes,
"labels": torch.zeros((1, 1), dtype=torch.int64),
"image_id": 4,
"area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
"iscrowd": torch.zeros((0,), dtype=torch.int64)}
targets = [negative_target] self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
def test_forward_negative_sample_mrcnn(self):
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)
images, targets = self._make_empty_sample(add_masks=True)
loss_dict = model(images, targets)
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_mask"], torch.tensor(0.))
def test_forward_negative_sample_krcnn(self):
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)
images, targets = self._make_empty_sample(add_keypoints=True)
loss_dict = model(images, targets) loss_dict = model(images, targets)
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.)) self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -113,7 +113,7 @@ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corne ...@@ -113,7 +113,7 @@ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corne
) )
output_shape = _output_size(2, input, size, scale_factor) output_shape = _output_size(2, input, size, scale_factor)
output_shape = input.shape[:-2] + output_shape output_shape = list(input.shape[:-2]) + output_shape
return _new_empty_tensor(input, output_shape) return _new_empty_tensor(input, output_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment