"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3e35628873e5a1723fdbb84a8789e99f243b4858"
Unverified Commit 250fadc2 authored by Jiazhen Wang's avatar Jiazhen Wang Committed by GitHub
Browse files

[Fix] Fix checkpoint local files detect (#1549)

* [fix] fix checkpoint local files detect

* [Fix] add support for path like '~/xx/file'

* [Fix] fix some details

* [Fix] fix unittest
parent 84c7dc34
...@@ -262,9 +262,9 @@ def load_from_local(filename, map_location): ...@@ -262,9 +262,9 @@ def load_from_local(filename, map_location):
Returns: Returns:
dict or OrderedDict: The loaded checkpoint. dict or OrderedDict: The loaded checkpoint.
""" """
filename = osp.expanduser(filename)
if not osp.isfile(filename): if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file') raise FileNotFoundError(f'{filename} can not be found.')
checkpoint = torch.load(filename, map_location=map_location) checkpoint = torch.load(filename, map_location=map_location)
return checkpoint return checkpoint
...@@ -432,7 +432,7 @@ def load_from_openmmlab(filename, map_location=None): ...@@ -432,7 +432,7 @@ def load_from_openmmlab(filename, map_location=None):
else: else:
filename = osp.join(_get_mmcv_home(), model_url) filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename): if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file') raise FileNotFoundError(f'{filename} can not be found.')
checkpoint = torch.load(filename, map_location=map_location) checkpoint = torch.load(filename, map_location=map_location)
return checkpoint return checkpoint
...@@ -692,8 +692,7 @@ def save_checkpoint(model, ...@@ -692,8 +692,7 @@ def save_checkpoint(model,
'file_client_args should be "None" if filename starts with' 'file_client_args should be "None" if filename starts with'
f'"pavi://", but got {file_client_args}') f'"pavi://", but got {file_client_args}')
try: try:
from pavi import modelcloud from pavi import exception, modelcloud
from pavi import exception
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Please install pavi to load checkpoint from modelcloud.') 'Please install pavi to load checkpoint from modelcloud.')
......
...@@ -128,7 +128,7 @@ def test_load_external_url(): ...@@ -128,7 +128,7 @@ def test_load_external_url():
os.environ[ENV_MMCV_HOME] = mmcv_home os.environ[ENV_MMCV_HOME] = mmcv_home
url = _load_checkpoint('open-mmlab://train') url = _load_checkpoint('open-mmlab://train')
assert url == 'url:https://localhost/train.pth' assert url == 'url:https://localhost/train.pth'
with pytest.raises(IOError, match='train.pth is not a checkpoint file'): with pytest.raises(FileNotFoundError, match='train.pth can not be found.'):
_load_checkpoint('open-mmlab://train_empty') _load_checkpoint('open-mmlab://train_empty')
url = _load_checkpoint('open-mmlab://test') url = _load_checkpoint('open-mmlab://test')
assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}' assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
...@@ -140,7 +140,7 @@ def test_load_external_url(): ...@@ -140,7 +140,7 @@ def test_load_external_url():
assert url == 'url:http://localhost/train.pth' assert url == 'url:http://localhost/train.pth'
# test local file # test local file
with pytest.raises(IOError, match='train.pth is not a checkpoint file'): with pytest.raises(FileNotFoundError, match='train.pth can not be found.'):
_load_checkpoint('train.pth') _load_checkpoint('train.pth')
url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth')) url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth'))
assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}' assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
...@@ -13,7 +13,8 @@ from mmcv.fileio.file_client import PetrelBackend ...@@ -13,7 +13,8 @@ from mmcv.fileio.file_client import PetrelBackend
from mmcv.parallel.registry import MODULE_WRAPPERS from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix, from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
get_state_dict, load_checkpoint, get_state_dict, load_checkpoint,
load_from_pavi, save_checkpoint) load_from_local, load_from_pavi,
save_checkpoint)
sys.modules['petrel_client'] = MagicMock() sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock() sys.modules['petrel_client.client'] = MagicMock()
...@@ -196,8 +197,9 @@ def test_load_checkpoint_with_prefix(): ...@@ -196,8 +197,9 @@ def test_load_checkpoint_with_prefix():
def test_load_checkpoint(): def test_load_checkpoint():
import os import os
import tempfile
import re import re
import tempfile
class PrefixModel(nn.Module): class PrefixModel(nn.Module):
...@@ -228,6 +230,7 @@ def test_load_checkpoint(): ...@@ -228,6 +230,7 @@ def test_load_checkpoint():
def test_load_checkpoint_metadata(): def test_load_checkpoint_metadata():
import os import os
import tempfile import tempfile
from mmcv.runner import load_checkpoint, save_checkpoint from mmcv.runner import load_checkpoint, save_checkpoint
...@@ -331,9 +334,11 @@ def test_load_classes_name(): ...@@ -331,9 +334,11 @@ def test_load_classes_name():
def test_checkpoint_loader(): def test_checkpoint_loader():
from mmcv.runner import _load_checkpoint, save_checkpoint, CheckpointLoader
import tempfile
import os import os
import tempfile
from mmcv.runner import CheckpointLoader, _load_checkpoint, save_checkpoint
checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth') checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
model = Model() model = Model()
save_checkpoint(model, checkpoint_path) save_checkpoint(model, checkpoint_path)
...@@ -433,3 +438,18 @@ def test_save_checkpoint(tmp_path): ...@@ -433,3 +438,18 @@ def test_save_checkpoint(tmp_path):
save_checkpoint( save_checkpoint(
model, filename, file_client_args={'backend': 'petrel'}) model, filename, file_client_args={'backend': 'petrel'})
mock_method.assert_called() mock_method.assert_called()
def test_load_from_local():
import os
home_path = os.path.expanduser('~')
checkpoint_path = os.path.join(
home_path, 'dummy_checkpoint_used_to_test_load_from_local.pth')
model = Model()
save_checkpoint(model, checkpoint_path)
checkpoint = load_from_local(
'~/dummy_checkpoint_used_to_test_load_from_local.pth',
map_location=None)
assert_tensor_equal(checkpoint['state_dict']['block.conv.weight'],
model.block.conv.weight)
os.remove(checkpoint_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment