test_hub.py 1.36 KB
Newer Older
limm's avatar
limm committed
1
# Copyright (c) OpenMMLab. All rights reserved.
limm's avatar
limm committed
2
import pytest
limm's avatar
limm committed
3
import torch
limm's avatar
limm committed
4
5
6
7
8
from torch.utils import model_zoo

from mmcv.utils import TORCH_VERSION, digit_version, load_url


limm's avatar
limm committed
9
10
@pytest.mark.skipif(
    torch.__version__ == 'parrots', reason='not necessary in parrots test')
limm's avatar
limm committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def test_load_url():
    url1 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.5.pth'
    url2 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.6.pth'

    # The 1.6 release of PyTorch switched torch.save to use a new zipfile-based
    # file format. It will cause RuntimeError when a checkpoint was saved in
    # torch >= 1.6.0 but loaded in torch < 1.7.0.
    # More details at https://github.com/open-mmlab/mmpose/issues/904
    if digit_version(TORCH_VERSION) < digit_version('1.7.0'):
        model_zoo.load_url(url1)
        with pytest.raises(RuntimeError):
            model_zoo.load_url(url2)
    else:
        # high version of PyTorch can load checkpoints from url, regardless
        # of which version they were saved in
        model_zoo.load_url(url1)
        model_zoo.load_url(url2)

    load_url(url1)
    # if a checkpoint was saved in torch >= 1.6.0 but loaded in torch < 1.5.0,
    # it will raise a RuntimeError
    if digit_version(TORCH_VERSION) < digit_version('1.5.0'):
        with pytest.raises(RuntimeError):
            load_url(url2)
    else:
        load_url(url2)