Unverified Commit e4eded48 authored by Vivek Kumar's avatar Vivek Kumar Committed by GitHub
Browse files

Port test_models_detection_anchor_utils.py to pytest (#4046)

parent 0e7ae64b
...@@ -3,6 +3,7 @@ from common_utils import TestCase ...@@ -3,6 +3,7 @@ from common_utils import TestCase
from _assert_utils import assert_equal from _assert_utils import assert_equal
from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator
from torchvision.models.detection.image_list import ImageList from torchvision.models.detection.image_list import ImageList
import pytest
class Tester(TestCase): class Tester(TestCase):
...@@ -13,7 +14,7 @@ class Tester(TestCase): ...@@ -13,7 +14,7 @@ class Tester(TestCase):
image1 = torch.randn(3, 800, 800) image1 = torch.randn(3, 800, 800)
image_list = ImageList(image1, [(800, 800)]) image_list = ImageList(image1, [(800, 800)])
feature_maps = [torch.randn(1, 50)] feature_maps = [torch.randn(1, 50)]
self.assertRaises(ValueError, anc, image_list, feature_maps) pytest.raises(ValueError, anc, image_list, feature_maps)
def _init_test_anchor_generator(self): def _init_test_anchor_generator(self):
anchor_sizes = ((10,),) anchor_sizes = ((10,),)
...@@ -59,10 +60,10 @@ class Tester(TestCase): ...@@ -59,10 +60,10 @@ class Tester(TestCase):
[0., 5., 10., 15.], [0., 5., 10., 15.],
[5., 5., 15., 15.]]) [5., 5., 15., 15.]])
self.assertEqual(num_anchors_estimated, 9) assert num_anchors_estimated == 9
self.assertEqual(len(anchors), 2) assert len(anchors) == 2
self.assertEqual(tuple(anchors[0].shape), (9, 4)) assert tuple(anchors[0].shape) == (9, 4)
self.assertEqual(tuple(anchors[1].shape), (9, 4)) assert tuple(anchors[1].shape) == (9, 4)
assert_equal(anchors[0], anchors_output) assert_equal(anchors[0], anchors_output)
assert_equal(anchors[1], anchors_output) assert_equal(anchors[1], anchors_output)
...@@ -83,8 +84,8 @@ class Tester(TestCase): ...@@ -83,8 +84,8 @@ class Tester(TestCase):
[6.7045, 5.9090, 8.2955, 9.0910] [6.7045, 5.9090, 8.2955, 9.0910]
]) ])
self.assertEqual(len(dboxes), 2) assert len(dboxes) == 2
self.assertEqual(tuple(dboxes[0].shape), (4, 4)) assert tuple(dboxes[0].shape) == (4, 4)
self.assertEqual(tuple(dboxes[1].shape), (4, 4)) assert tuple(dboxes[1].shape) == (4, 4)
torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8) torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8)
torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8) torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment