Unverified Commit e4eded48 authored by Vivek Kumar's avatar Vivek Kumar Committed by GitHub
Browse files

Port test_models_detection_anchor_utils.py to pytest (#4046)

parent 0e7ae64b
......@@ -3,6 +3,7 @@ from common_utils import TestCase
from _assert_utils import assert_equal
from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator
from torchvision.models.detection.image_list import ImageList
import pytest
class Tester(TestCase):
......@@ -13,7 +14,7 @@ class Tester(TestCase):
image1 = torch.randn(3, 800, 800)
image_list = ImageList(image1, [(800, 800)])
feature_maps = [torch.randn(1, 50)]
self.assertRaises(ValueError, anc, image_list, feature_maps)
pytest.raises(ValueError, anc, image_list, feature_maps)
def _init_test_anchor_generator(self):
anchor_sizes = ((10,),)
......@@ -59,10 +60,10 @@ class Tester(TestCase):
[0., 5., 10., 15.],
[5., 5., 15., 15.]])
self.assertEqual(num_anchors_estimated, 9)
self.assertEqual(len(anchors), 2)
self.assertEqual(tuple(anchors[0].shape), (9, 4))
self.assertEqual(tuple(anchors[1].shape), (9, 4))
assert num_anchors_estimated == 9
assert len(anchors) == 2
assert tuple(anchors[0].shape) == (9, 4)
assert tuple(anchors[1].shape) == (9, 4)
assert_equal(anchors[0], anchors_output)
assert_equal(anchors[1], anchors_output)
......@@ -83,8 +84,8 @@ class Tester(TestCase):
[6.7045, 5.9090, 8.2955, 9.0910]
])
self.assertEqual(len(dboxes), 2)
self.assertEqual(tuple(dboxes[0].shape), (4, 4))
self.assertEqual(tuple(dboxes[1].shape), (4, 4))
assert len(dboxes) == 2
assert tuple(dboxes[0].shape) == (4, 4)
assert tuple(dboxes[1].shape) == (4, 4)
torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8)
torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment