Unverified Commit c307db4b authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_detection_utils.py (#3881)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 86d45414
...@@ -4,6 +4,7 @@ from torchvision.models.detection import _utils ...@@ -4,6 +4,7 @@ from torchvision.models.detection import _utils
from torchvision.models.detection.transform import GeneralizedRCNNTransform from torchvision.models.detection.transform import GeneralizedRCNNTransform
import unittest import unittest
from torchvision.models.detection import backbone_utils from torchvision.models.detection import backbone_utils
from _assert_utils import assert_equal
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -55,8 +56,8 @@ class Tester(unittest.TestCase): ...@@ -55,8 +56,8 @@ class Tester(unittest.TestCase):
targets = [{'boxes': torch.rand(3, 4)}, {'boxes': torch.rand(2, 4)}] targets = [{'boxes': torch.rand(3, 4)}, {'boxes': torch.rand(2, 4)}]
targets_copy = copy.deepcopy(targets) targets_copy = copy.deepcopy(targets)
out = transform(image, targets) # noqa: F841 out = transform(image, targets) # noqa: F841
self.assertTrue(torch.equal(targets[0]['boxes'], targets_copy[0]['boxes'])) assert_equal(targets[0]['boxes'], targets_copy[0]['boxes'])
self.assertTrue(torch.equal(targets[1]['boxes'], targets_copy[1]['boxes'])) assert_equal(targets[1]['boxes'], targets_copy[1]['boxes'])
def test_not_float_normalize(self): def test_not_float_normalize(self):
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment