# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.utils import model_zoo

from mmcv.utils import TORCH_VERSION, digit_version, load_url


@pytest.mark.skipif(
    torch.__version__ == 'parrots', reason='not necessary in parrots test')
def test_load_url():
    """Check loading checkpoints saved by different PyTorch versions from URLs.

    Two checkpoints are downloaded. Judging by their file names, ``url1`` was
    saved with PyTorch 1.5 and ``url2`` with PyTorch 1.6, which uses the
    zipfile-based serialization format. Both ``torch.utils.model_zoo.load_url``
    and mmcv's ``load_url`` are exercised. The installed PyTorch version
    decides whether loading ``url2`` is expected to succeed or to raise
    ``RuntimeError``.

    Note:
        This test downloads files from ``download.openmmlab.com`` and
        therefore needs network access.
    """
    url1 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.5.pth'
    url2 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.6.pth'

    # The 1.5-era checkpoint is expected to load through model_zoo on every
    # version, so it is checked once, outside the version branch.
    model_zoo.load_url(url1)

    # The 1.6 release of PyTorch switched torch.save to use a new zipfile-based
    # file format. It will cause RuntimeError when a checkpoint was saved in
    # torch >= 1.6.0 but loaded through model_zoo in torch < 1.7.0.
    # More details at https://github.com/open-mmlab/mmpose/issues/904
    if digit_version(TORCH_VERSION) < digit_version('1.7.0'):
        with pytest.raises(RuntimeError):
            model_zoo.load_url(url2)
    else:
        # Newer PyTorch can load checkpoints from a URL regardless of which
        # version they were saved in.
        model_zoo.load_url(url2)

    # mmcv's load_url is expected to handle the 1.5-era checkpoint on every
    # version.
    load_url(url1)
    # If a checkpoint was saved in torch >= 1.6.0 but is loaded in
    # torch < 1.5.0, mmcv's load_url raises RuntimeError; otherwise it loads.
    if digit_version(TORCH_VERSION) < digit_version('1.5.0'):
        with pytest.raises(RuntimeError):
            load_url(url2)
    else:
        load_url(url2)