"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "554e67cb06db8e5476b8e7a7539b8a50b7550743"
Unverified Commit 4a9f8346 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Fix] Fix unittest in pt1.9 (#1146)

* fix test.txt

* fix unittest in pt1.9

* fix checkpoint filename error

* add comment

* fix unittest

* fix onnxruntime version
parent 6c63621a
...@@ -11,6 +11,7 @@ from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME, ...@@ -11,6 +11,7 @@ from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
_load_checkpoint, _load_checkpoint,
get_deprecated_model_names, get_deprecated_model_names,
get_external_models) get_external_models)
from mmcv.utils import TORCH_VERSION
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')]) @patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
...@@ -77,13 +78,23 @@ def load(filepath, map_location=None): ...@@ -77,13 +78,23 @@ def load(filepath, map_location=None):
def test_load_external_url(): def test_load_external_url():
# test modelzoo:// # test modelzoo://
url = _load_checkpoint('modelzoo://resnet50') url = _load_checkpoint('modelzoo://resnet50')
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \ if TORCH_VERSION < '1.9.0':
'.pth' assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
else:
# filename of checkpoint is renamed in torch1.9.0
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')
# test torchvision:// # test torchvision://
url = _load_checkpoint('torchvision://resnet50') url = _load_checkpoint('torchvision://resnet50')
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \ if TORCH_VERSION < '1.9.0':
'.pth' assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
else:
# filename of checkpoint is renamed in torch1.9.0
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')
# test open-mmlab:// with default MMCV_HOME # test open-mmlab:// with default MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None) os.environ.pop(ENV_MMCV_HOME, None)
......
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel from torch.nn.parallel import DataParallel, DistributedDataParallel
...@@ -15,7 +16,7 @@ def mock(*args, **kwargs): ...@@ -15,7 +16,7 @@ def mock(*args, **kwargs):
@patch('torch.distributed._broadcast_coalesced', mock) @patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock) @patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', MagicMock) @patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_module_wrapper(): def test_is_module_wrapper():
class Model(nn.Module): class Model(nn.Module):
...@@ -27,6 +28,12 @@ def test_is_module_wrapper(): ...@@ -27,6 +28,12 @@ def test_is_module_wrapper():
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
# _verify_model_across_ranks is added in torch1.9.0 so we should check
# wether _verify_model_across_ranks is the member of torch.distributed
# before mocking
if hasattr(torch.distributed, '_verify_model_across_ranks'):
torch.distributed._verify_model_across_ranks = mock
model = Model() model = Model()
assert not is_module_wrapper(model) assert not is_module_wrapper(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment