# test_hub.py
import torch.hub as hub
import tempfile
import shutil
import os
import sys
import unittest


def sum_of_model_parameters(model):
    """Return the total of all elements across every parameter of *model*.

    Each parameter's ``.sum()`` is added onto a running total that starts
    at the integer ``0``, so a model exposing no parameters yields ``0``.
    """
    return sum(param.sum() for param in model.parameters())


# Sum of all parameter elements of the pretrained resnet18 fetched from
# pytorch/vision; the tests compare against it (to 2 decimal places) as a
# fingerprint that the expected pretrained weights were loaded.
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625


@unittest.skipIf('torchvision' in sys.modules,
                 'TestHub must start without torchvision imported')
class TestHub(unittest.TestCase):
    """Tests for ``torch.hub`` loading and listing models from pytorch/vision.

    These tests fetch the ``pytorch/vision`` repository (and pretrained
    weights) through ``torch.hub``.
    """
    # The skipIf condition above is evaluated only ONCE, when the class is
    # defined, i.e. before any test starts:
    # - If torchvision is already imported before the tests start, we might
    #   pick up the installed wheel's _C.so, which does not exist in the
    #   downloaded zip.
    # - After the first test runs, torchvision is in sys.modules anyway,
    #   because all hub tests run in the same Python process and share its
    #   module cache.

    def _assert_pretrained_resnet18(self, model):
        """Assert that *model*'s parameters sum to the known resnet18 value."""
        self.assertAlmostEqual(sum_of_model_parameters(model).item(),
                               SUM_OF_PRETRAINED_RESNET18_PARAMS,
                               places=2)

    def test_load_from_github(self):
        hub_model = hub.load(
            'pytorch/vision',
            'resnet18',
            pretrained=True,
            progress=False)
        self._assert_pretrained_resnet18(hub_model)

    def test_set_dir(self):
        # Use a fresh private directory instead of the shared system temp
        # dir, so cleanup can never remove a pytorch_vision_master checkout
        # that belongs to someone else. addCleanup runs even when an
        # assertion below fails, so the download never leaks.
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        # NOTE(review): hub.set_dir mutates process-wide state that this test
        # does not restore; hub calls made later in the same process will
        # still target temp_dir.
        hub.set_dir(temp_dir)
        hub_model = hub.load(
            'pytorch/vision',
            'resnet18',
            pretrained=True,
            progress=False)
        self._assert_pretrained_resnet18(hub_model)
        self.assertTrue(
            os.path.exists(os.path.join(temp_dir, 'pytorch_vision_master')))

    def test_list_entrypoints(self):
        entry_lists = hub.list('pytorch/vision', force_reload=True)
        self.assertIn('resnet18', entry_lists)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()