Unverified Commit 2d6931ab authored by Anirudh's avatar Anirudh Committed by GitHub
Browse files

Port test_hub.py to pytest (#4038)

parent ec40ac3a
......@@ -242,6 +242,7 @@ jobs:
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
# need to install torchvision dependencies due to transitive imports
pip install --user --progress-bar off --editable .
pip install pytest
python test/test_hub.py
torch_onnx_test:
......
......@@ -242,6 +242,7 @@ jobs:
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
# need to install torchvision dependencies due to transitive imports
pip install --user --progress-bar off --editable .
pip install pytest
python test/test_hub.py
torch_onnx_test:
......
......@@ -3,7 +3,7 @@ import tempfile
import shutil
import os
import sys
import unittest
import pytest
def sum_of_model_parameters(model):
......@@ -16,9 +16,9 @@ def sum_of_model_parameters(model):
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625
@unittest.skipIf('torchvision' in sys.modules,
'TestHub must start without torchvision imported')
class TestHub(unittest.TestCase):
@pytest.mark.skipif('torchvision' in sys.modules,
reason='TestHub must start without torchvision imported')
class TestHub:
# Only run this check ONCE before all tests start.
# - If torchvision is imported before all tests start, e.g. we might find _C.so
# which doesn't exist in downloaded zip but in the installed wheel.
......@@ -31,9 +31,7 @@ class TestHub(unittest.TestCase):
'resnet18',
pretrained=True,
progress=False)
self.assertAlmostEqual(sum_of_model_parameters(hub_model).item(),
SUM_OF_PRETRAINED_RESNET18_PARAMS,
places=2)
assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)
def test_set_dir(self):
temp_dir = tempfile.gettempdir()
......@@ -43,16 +41,14 @@ class TestHub(unittest.TestCase):
'resnet18',
pretrained=True,
progress=False)
self.assertAlmostEqual(sum_of_model_parameters(hub_model).item(),
SUM_OF_PRETRAINED_RESNET18_PARAMS,
places=2)
self.assertTrue(os.path.exists(temp_dir + '/pytorch_vision_master'))
assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)
assert os.path.exists(temp_dir + '/pytorch_vision_master')
shutil.rmtree(temp_dir + '/pytorch_vision_master')
def test_list_entrypoints(self):
entry_lists = hub.list('pytorch/vision', force_reload=True)
self.assertIn('resnet18', entry_lists)
assert 'resnet18' in entry_lists
if __name__ == "__main__":
unittest.main()
pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment