# test_load_model_zoo.py
# Copyright (c) OpenMMLab. All rights reserved.
2
3
4
5
import os
import os.path as osp
from unittest.mock import patch

6
import mmengine
7
import pytest
8
import torchvision
9
10
11
12

import mmcv
from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
                                    ENV_XDG_CACHE_HOME, _get_mmcv_home,
13
14
15
                                    _load_checkpoint,
                                    get_deprecated_model_names,
                                    get_external_models)
16
from mmcv.utils import digit_version
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33


@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch.dict(os.environ)
def test_set_mmcv_home():
    """``_get_mmcv_home`` returns the directory stored in ``ENV_MMCV_HOME``.

    ``patch.dict(os.environ)`` snapshots the environment and restores it when
    the test finishes, so the variable set here does not leak into other
    tests in the same session.
    """
    os.environ.pop(ENV_MMCV_HOME, None)
    mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
    os.environ[ENV_MMCV_HOME] = mmcv_home
    assert _get_mmcv_home() == mmcv_home


@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch.dict(os.environ)
def test_default_mmcv_home():
    """Without the home env vars, fall back to ``DEFAULT_CACHE_DIR/mmcv``.

    Also checks that ``get_external_models`` then yields exactly the contents
    of ``model_zoo/open_mmlab.json`` under the patched ``mmcv.__path__``.
    ``patch.dict(os.environ)`` restores any user-set variables popped here
    once the test ends.
    """
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    assert _get_mmcv_home() == osp.expanduser(
        osp.join(DEFAULT_CACHE_DIR, 'mmcv'))
    model_urls = get_external_models()
    assert model_urls == mmengine.load(
        osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))


@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch.dict(os.environ)
def test_get_external_models():
    """``get_external_models`` reads the model URLs under a custom MMCV home.

    The expected mapping comes from the test data directory
    ``data/model_zoo/mmcv_home/``. The environment is restored after the
    test by ``patch.dict(os.environ)``.
    """
    os.environ.pop(ENV_MMCV_HOME, None)
    mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
    os.environ[ENV_MMCV_HOME] = mmcv_home
    ext_urls = get_external_models()
    assert ext_urls == {
        'train': 'https://localhost/train.pth',
        'test': 'test.pth',
        'val': 'val.pth',
        'train_empty': 'train.pth'
    }


52
53
54
55
56
57
58
59
60
61
62
63
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch.dict(os.environ)
def test_get_deprecated_models():
    """``get_deprecated_model_names`` maps old model names to new ones.

    Uses the test data directory as MMCV home. The environment is restored
    after the test by ``patch.dict(os.environ)``.
    """
    os.environ.pop(ENV_MMCV_HOME, None)
    mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
    os.environ[ENV_MMCV_HOME] = mmcv_home
    dep_urls = get_deprecated_model_names()
    assert dep_urls == {
        'train_old': 'train',
        'test_old': 'test',
    }


def load_from_http(url, map_location=None):
    """Fake HTTP loader: return ``url`` tagged with ``'url:'``.

    Used to patch ``mmcv.runner.checkpoint.load_from_http`` so no download
    happens. Tests can then assert which address would have been fetched.
    ``map_location`` is accepted only for call compatibility and is ignored.
    """
    prefix = 'url:'
    return prefix + url


def load_url(url, map_location=None, model_dir=None):
    """Fake ``load_url`` that delegates to this module's ``load_from_http``.

    ``map_location`` and ``model_dir`` are accepted only for call
    compatibility and are not forwarded.
    """
    fetched = load_from_http(url)
    return fetched


72
73
74
75
76
def load(filepath, map_location=None):
    """Fake ``torch.load``: return ``filepath`` tagged with ``'local:'``.

    No file is read. ``map_location`` is accepted only for call
    compatibility and is ignored.
    """
    tag = 'local:'
    return tag + filepath


@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch('mmcv.runner.checkpoint.load_from_http', load_from_http)
@patch('mmcv.runner.checkpoint.load_url', load_url)
@patch('torch.load', load)
@patch.dict(os.environ)
def test_load_external_url():
    """Check how ``_load_checkpoint`` resolves each supported scheme.

    Downloads and disk reads are replaced by the module-level fakes
    ``load_from_http``, ``load_url`` and ``load`` (for ``torch.load``). They
    return ``'url:<url>'`` / ``'local:<path>'``, so each assertion checks the
    resolved location rather than real checkpoint contents.

    ``patch.dict(os.environ)`` restores the environment afterwards. This
    undoes both the popped user variables and the ``ENV_MMCV_HOME`` value
    set near the end of the test.
    """
    # test modelzoo://
    torchvision_version = torchvision.__version__
    if digit_version(torchvision_version) < digit_version('0.10.0a0'):
        assert (_load_checkpoint('modelzoo://resnet50') ==
                'url:https://download.pytorch.org/models/resnet50-19c8e'
                '357.pth')
        assert (_load_checkpoint('torchvision://resnet50') ==
                'url:https://download.pytorch.org/models/resnet50-19c8e'
                '357.pth')
    else:
        assert (_load_checkpoint('modelzoo://resnet50') ==
                'url:https://download.pytorch.org/models/resnet50-0676b'
                'a61.pth')
        assert (_load_checkpoint('torchvision://resnet50') ==
                'url:https://download.pytorch.org/models/resnet50-0676b'
                'a61.pth')

    if digit_version(torchvision_version) >= digit_version('0.13.0a0'):
        # Test load new format torchvision models.
        assert (
            _load_checkpoint('torchvision://resnet50.imagenet1k_v1') ==
            'url:https://download.pytorch.org/models/resnet50-0676ba61.pth')

        assert (
            _load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') ==
            'url:https://download.pytorch.org/models/resnet50-0676ba61.pth')

        # Previously the result was discarded, so this path went unchecked.
        # Compare against torchvision's own DEFAULT weights rather than a
        # hard-coded URL, because the default may change between releases.
        assert (_load_checkpoint('torchvision://resnet50.default') ==
                'url:' + torchvision.models.ResNet50_Weights.DEFAULT.url)

    # test open-mmlab:// with default MMCV_HOME
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    url = _load_checkpoint('open-mmlab://train')
    assert url == 'url:https://localhost/train.pth'

    # test open-mmlab:// with deprecated model name
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    with pytest.warns(
            Warning,
            match='open-mmlab://train_old is deprecated in favor of '
            'open-mmlab://train'):
        url = _load_checkpoint('open-mmlab://train_old')
        assert url == 'url:https://localhost/train.pth'

    # test openmmlab:// with deprecated model name
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    with pytest.warns(
            Warning,
            match='openmmlab://train_old is deprecated in favor of '
            'openmmlab://train'):
        url = _load_checkpoint('openmmlab://train_old')
        assert url == 'url:https://localhost/train.pth'

    # test open-mmlab:// with user-defined MMCV_HOME
    os.environ.pop(ENV_MMCV_HOME, None)
    mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home')
    os.environ[ENV_MMCV_HOME] = mmcv_home
    url = _load_checkpoint('open-mmlab://train')
    assert url == 'url:https://localhost/train.pth'
    with pytest.raises(FileNotFoundError, match='train.pth can not be found.'):
        _load_checkpoint('open-mmlab://train_empty')
    url = _load_checkpoint('open-mmlab://test')
    assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
    url = _load_checkpoint('open-mmlab://val')
    assert url == f'local:{osp.join(_get_mmcv_home(), "val.pth")}'

    # test http:// https://
    url = _load_checkpoint('http://localhost/train.pth')
    assert url == 'url:http://localhost/train.pth'

    # test local file
    with pytest.raises(FileNotFoundError, match='train.pth can not be found.'):
        _load_checkpoint('train.pth')
    url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth'))
    assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'