Unverified Commit 446eac61 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix torchhub due to numerical changes in torch.sum (#2361)

parent bb14c2bd
......@@ -13,7 +13,7 @@ def sum_of_model_parameters(model):
return s
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.99609375
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625
@unittest.skipIf('torchvision' in sys.modules,
......@@ -31,8 +31,9 @@ class TestHub(unittest.TestCase):
'resnet18',
pretrained=True,
progress=False)
self.assertEqual(sum_of_model_parameters(hub_model).item(),
SUM_OF_PRETRAINED_RESNET18_PARAMS)
self.assertAlmostEqual(sum_of_model_parameters(hub_model).item(),
SUM_OF_PRETRAINED_RESNET18_PARAMS,
places=2)
def test_set_dir(self):
temp_dir = tempfile.gettempdir()
......@@ -42,8 +43,9 @@ class TestHub(unittest.TestCase):
'resnet18',
pretrained=True,
progress=False)
self.assertEqual(sum_of_model_parameters(hub_model).item(),
SUM_OF_PRETRAINED_RESNET18_PARAMS)
self.assertAlmostEqual(sum_of_model_parameters(hub_model).item(),
SUM_OF_PRETRAINED_RESNET18_PARAMS,
places=2)
self.assertTrue(os.path.exists(temp_dir + '/pytorch_vision_master'))
shutil.rmtree(temp_dir + '/pytorch_vision_master')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment