Unverified Commit 446eac61 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix torchhub due to numerical changes in torch.sum (#2361)

parent bb14c2bd
...@@ -13,7 +13,7 @@ def sum_of_model_parameters(model): ...@@ -13,7 +13,7 @@ def sum_of_model_parameters(model):
return s return s
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.99609375 SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625
@unittest.skipIf('torchvision' in sys.modules, @unittest.skipIf('torchvision' in sys.modules,
...@@ -31,8 +31,9 @@ class TestHub(unittest.TestCase): ...@@ -31,8 +31,9 @@ class TestHub(unittest.TestCase):
'resnet18', 'resnet18',
pretrained=True, pretrained=True,
progress=False) progress=False)
self.assertEqual(sum_of_model_parameters(hub_model).item(), self.assertAlmostEqual(sum_of_model_parameters(hub_model).item(),
SUM_OF_PRETRAINED_RESNET18_PARAMS) SUM_OF_PRETRAINED_RESNET18_PARAMS,
places=2)
def test_set_dir(self): def test_set_dir(self):
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
...@@ -42,8 +43,9 @@ class TestHub(unittest.TestCase): ...@@ -42,8 +43,9 @@ class TestHub(unittest.TestCase):
'resnet18', 'resnet18',
pretrained=True, pretrained=True,
progress=False) progress=False)
self.assertEqual(sum_of_model_parameters(hub_model).item(), self.assertAlmostEqual(sum_of_model_parameters(hub_model).item(),
SUM_OF_PRETRAINED_RESNET18_PARAMS) SUM_OF_PRETRAINED_RESNET18_PARAMS,
places=2)
self.assertTrue(os.path.exists(temp_dir + '/pytorch_vision_master')) self.assertTrue(os.path.exists(temp_dir + '/pytorch_vision_master'))
shutil.rmtree(temp_dir + '/pytorch_vision_master') shutil.rmtree(temp_dir + '/pytorch_vision_master')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment