Unverified Commit 21049e90 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_models.py (#3879)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent b96d381c
...@@ -120,7 +120,7 @@ class ModelTester(TestCase): ...@@ -120,7 +120,7 @@ class ModelTester(TestCase):
# predictions match. # predictions match.
expected_file = self._get_expected_file(name) expected_file = self._get_expected_file(name)
expected = torch.load(expected_file) expected = torch.load(expected_file)
self.assertEqual(out.argmax(dim=1), expected.argmax(dim=1), prec=prec) torch.testing.assert_close(out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec)
return False # Partial validation performed return False # Partial validation performed
return True # Full validation performed return True # Full validation performed
...@@ -205,7 +205,8 @@ class ModelTester(TestCase): ...@@ -205,7 +205,8 @@ class ModelTester(TestCase):
# scores. # scores.
expected_file = self._get_expected_file(name) expected_file = self._get_expected_file(name)
expected = torch.load(expected_file) expected = torch.load(expected_file)
self.assertEqual(output[0]["scores"], expected[0]["scores"], prec=prec) torch.testing.assert_close(output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec,
check_device=False, check_dtype=False)
# Note: Fmassa proposed turning off NMS by adapting the threshold # Note: Fmassa proposed turning off NMS by adapting the threshold
# and then using the Hungarian algorithm as in DETR to find the # and then using the Hungarian algorithm as in DETR to find the
...@@ -301,10 +302,8 @@ class ModelTester(TestCase): ...@@ -301,10 +302,8 @@ class ModelTester(TestCase):
model2.eval() model2.eval()
out2 = model2(x) out2 = model2(x)
max_diff = (out1 - out2).abs().max()
self.assertTrue(num_params == num_grad) self.assertTrue(num_params == num_grad)
self.assertTrue(max_diff < 1e-5) torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)
def test_resnet_dilation(self): def test_resnet_dilation(self):
# TODO improve tests to also check that each layer has the right dimensionality # TODO improve tests to also check that each layer has the right dimensionality
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment