Unverified Commit 7dc8a690 authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Port test_models_detection_negative_samples.py to pytest (#4045)

parent e4eded48
...@@ -6,10 +6,11 @@ from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionPro ...@@ -6,10 +6,11 @@ from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionPro
from torchvision.models.detection.roi_heads import RoIHeads from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
import unittest import pytest
from _assert_utils import assert_equal
class Tester(unittest.TestCase): class TestModelsDetectionNegativeSamples:
def _make_empty_sample(self, add_masks=False, add_keypoints=False): def _make_empty_sample(self, add_masks=False, add_keypoints=False):
images = [torch.rand((3, 100, 100), dtype=torch.float32)] images = [torch.rand((3, 100, 100), dtype=torch.float32)]
...@@ -48,13 +49,13 @@ class Tester(unittest.TestCase): ...@@ -48,13 +49,13 @@ class Tester(unittest.TestCase):
labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets) labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)
self.assertEqual(labels[0].sum(), 0) assert labels[0].sum() == 0
self.assertEqual(labels[0].shape, torch.Size([anchors[0].shape[0]])) assert labels[0].shape == torch.Size([anchors[0].shape[0]])
self.assertEqual(labels[0].dtype, torch.float32) assert labels[0].dtype == torch.float32
self.assertEqual(matched_gt_boxes[0].sum(), 0) assert matched_gt_boxes[0].sum() == 0
self.assertEqual(matched_gt_boxes[0].shape, anchors[0].shape) assert matched_gt_boxes[0].shape == anchors[0].shape
self.assertEqual(matched_gt_boxes[0].dtype, torch.float32) assert matched_gt_boxes[0].dtype == torch.float32
def test_assign_targets_to_proposals(self): def test_assign_targets_to_proposals(self):
...@@ -88,25 +89,28 @@ class Tester(unittest.TestCase): ...@@ -88,25 +89,28 @@ class Tester(unittest.TestCase):
matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels) matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
self.assertEqual(matched_idxs[0].sum(), 0) assert matched_idxs[0].sum() == 0
self.assertEqual(matched_idxs[0].shape, torch.Size([proposals[0].shape[0]])) assert matched_idxs[0].shape == torch.Size([proposals[0].shape[0]])
self.assertEqual(matched_idxs[0].dtype, torch.int64) assert matched_idxs[0].dtype == torch.int64
self.assertEqual(labels[0].sum(), 0) assert labels[0].sum() == 0
self.assertEqual(labels[0].shape, torch.Size([proposals[0].shape[0]])) assert labels[0].shape == torch.Size([proposals[0].shape[0]])
self.assertEqual(labels[0].dtype, torch.int64) assert labels[0].dtype == torch.int64
def test_forward_negative_sample_frcnn(self): @pytest.mark.parametrize('name', [
for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn", "fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn"]: "fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
])
def test_forward_negative_sample_frcnn(self, name):
model = torchvision.models.detection.__dict__[name]( model = torchvision.models.detection.__dict__[name](
num_classes=2, min_size=100, max_size=100) num_classes=2, min_size=100, max_size=100)
images, targets = self._make_empty_sample() images, targets = self._make_empty_sample()
loss_dict = model(images, targets) loss_dict = model(images, targets)
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
def test_forward_negative_sample_mrcnn(self): def test_forward_negative_sample_mrcnn(self):
model = torchvision.models.detection.maskrcnn_resnet50_fpn( model = torchvision.models.detection.maskrcnn_resnet50_fpn(
...@@ -115,9 +119,9 @@ class Tester(unittest.TestCase): ...@@ -115,9 +119,9 @@ class Tester(unittest.TestCase):
images, targets = self._make_empty_sample(add_masks=True) images, targets = self._make_empty_sample(add_masks=True)
loss_dict = model(images, targets) loss_dict = model(images, targets)
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_mask"], torch.tensor(0.)) assert_equal(loss_dict["loss_mask"], torch.tensor(0.))
def test_forward_negative_sample_krcnn(self): def test_forward_negative_sample_krcnn(self):
model = torchvision.models.detection.keypointrcnn_resnet50_fpn( model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
...@@ -126,9 +130,9 @@ class Tester(unittest.TestCase): ...@@ -126,9 +130,9 @@ class Tester(unittest.TestCase):
images, targets = self._make_empty_sample(add_keypoints=True) images, targets = self._make_empty_sample(add_keypoints=True)
loss_dict = model(images, targets) loss_dict = model(images, targets)
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.)) assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.))
def test_forward_negative_sample_retinanet(self): def test_forward_negative_sample_retinanet(self):
model = torchvision.models.detection.retinanet_resnet50_fpn( model = torchvision.models.detection.retinanet_resnet50_fpn(
...@@ -137,7 +141,7 @@ class Tester(unittest.TestCase): ...@@ -137,7 +141,7 @@ class Tester(unittest.TestCase):
images, targets = self._make_empty_sample() images, targets = self._make_empty_sample()
loss_dict = model(images, targets) loss_dict = model(images, targets)
self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.)) assert_equal(loss_dict["bbox_regression"], torch.tensor(0.))
def test_forward_negative_sample_ssd(self): def test_forward_negative_sample_ssd(self):
model = torchvision.models.detection.ssd300_vgg16( model = torchvision.models.detection.ssd300_vgg16(
...@@ -146,8 +150,8 @@ class Tester(unittest.TestCase): ...@@ -146,8 +150,8 @@ class Tester(unittest.TestCase):
images, targets = self._make_empty_sample() images, targets = self._make_empty_sample()
loss_dict = model(images, targets) loss_dict = model(images, targets)
self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.)) assert_equal(loss_dict["bbox_regression"], torch.tensor(0.))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment