Unverified Commit 96892bdc authored by Haian Huang(深度眸)'s avatar Haian Huang(深度眸) Committed by GitHub
Browse files

Refactor _load_checkpoint fn (#790)

* Refactor _load_checkpoint fn

* Update _load_checkpoint fn

* Update docs str and add unit test

* Fix unit test

* Fix lint

* Add comment and Optimize function

* Fix docs str

* Update load_ckpt and fix doc str

* Update doc str and add sort unit test

* Update and fix unit test

* Fix unit test

* Update and add unit test

* Fix openmmlab prefix error
parent 4450bd2e
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .base_runner import BaseRunner from .base_runner import BaseRunner
from .builder import RUNNERS, build_runner from .builder import RUNNERS, build_runner
from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict, from .checkpoint import (CheckpointLoader, _load_checkpoint, load_checkpoint,
save_checkpoint, weights_to_cpu) load_state_dict, save_checkpoint, weights_to_cpu)
from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info, from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
init_dist, master_only) init_dist, master_only)
from .epoch_based_runner import EpochBasedRunner, Runner from .epoch_based_runner import EpochBasedRunner, Runner
...@@ -33,5 +33,6 @@ __all__ = [ ...@@ -33,5 +33,6 @@ __all__ = [
'build_optimizer', 'build_optimizer_constructor', 'IterLoader', 'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner', 'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler' 'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler',
'CheckpointLoader'
] ]
...@@ -105,23 +105,206 @@ def load_state_dict(module, state_dict, strict=False, logger=None): ...@@ -105,23 +105,206 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
print(err_msg) print(err_msg)
def load_url_dist(url, model_dir=None): def get_torchvision_models():
"""In distributed setting, this function only download checkpoint at local model_urls = dict()
rank 0.""" for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
return model_urls
def get_external_models():
    """Build the table of OpenMMLab model-zoo entries.

    The table shipped inside the mmcv package
    (``model_zoo/open_mmlab.json``) is loaded first. If a file named
    ``open_mmlab.json`` also exists in the mmcv home directory, its entries
    are merged on top and replace shipped entries that share the same key.

    Returns:
        dict: The merged table loaded from the JSON file(s).
    """
    home_dir = _get_mmcv_home()

    shipped_json = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    urls = load_file(shipped_json)
    assert isinstance(urls, dict)

    user_json = osp.join(home_dir, 'open_mmlab.json')
    if not osp.exists(user_json):
        return urls

    # user-provided entries take precedence over the shipped ones
    overrides = load_file(user_json)
    assert isinstance(overrides, dict)
    urls.update(overrides)
    return urls
def get_mmcls_models():
    """Load the mmcls model table.

    Reads ``model_zoo/mmcls.json`` from the mmcv package directory.

    Returns:
        The parsed content of the JSON file, as returned by ``load_file``.
    """
    return load_file(osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json'))
def get_deprecated_model_names():
    """Load the table of deprecated model names.

    Reads ``model_zoo/deprecated.json`` from the mmcv package directory.

    Returns:
        dict: The parsed table; callers look up a deprecated model name to
            obtain the name that replaces it.
    """
    json_path = osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json')
    name_table = load_file(json_path)
    assert isinstance(name_table, dict)
    return name_table
def _process_mmcls_checkpoint(checkpoint):
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('backbone.'):
new_state_dict[k[9:]] = v
new_checkpoint = dict(state_dict=new_state_dict)
return new_checkpoint
class CheckpointLoader:
    """A general checkpoint loader to manage all schemes.

    Loader functions are registered under one or more path prefixes (for
    example ``'http://'``). A checkpoint path is dispatched to the loader
    whose registered prefix the path starts with.
    """

    # Maps prefix -> loader function. Re-sorted after every registration so
    # that iteration visits a prefix before any shorter prefix it starts with.
    _schemes = {}

    @classmethod
    def _register_scheme(cls, prefixes, loader, force=False):
        """Register ``loader`` under every prefix in ``prefixes``.

        Args:
            prefixes (str or list[str] or tuple[str]):
                The prefix(es) to register the loader under.
            loader (function): The loader function to be registered.
            force (bool, optional): Whether to override the loader
                if a prefix has already been registered. Defaults to False.

        Raises:
            KeyError: If a prefix is already registered and ``force`` is
                False. The registry is left untouched in that case.
        """
        if isinstance(prefixes, str):
            prefixes = [prefixes]
        else:
            assert isinstance(prefixes, (list, tuple))

        # Validate every prefix before touching the registry, so a rejected
        # call cannot leave it partially updated and unsorted.
        if not force:
            seen = set()
            for prefix in prefixes:
                if prefix in cls._schemes or prefix in seen:
                    raise KeyError(
                        f'{prefix} is already registered as a loader backend, '
                        'add "force=True" if you want to override it')
                seen.add(prefix)

        for prefix in prefixes:
            cls._schemes[prefix] = loader

        # Descending lexicographic order places a prefix ahead of every
        # shorter prefix it starts with (e.g. 'a/b/c' before 'a/b', and ''
        # last), so the first match found on lookup is the longest one.
        cls._schemes = OrderedDict(
            sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))

    @classmethod
    def register_scheme(cls, prefixes, loader=None, force=False):
        """Register a loader to CheckpointLoader.

        This method can be used as a normal class method or a decorator.

        Args:
            prefixes (str or list[str] or tuple[str]):
                The prefix of the registered loader.
            loader (function, optional): The loader function to be registered.
                When this method is used as a decorator, loader is None.
                Defaults to None.
            force (bool, optional): Whether to override the loader
                if the prefix has already been registered. Defaults to False.

        Returns:
            None when ``loader`` is given; otherwise a decorator that
            registers the decorated function and returns it unchanged.
        """
        if loader is not None:
            cls._register_scheme(prefixes, loader, force=force)
            return

        def _register(loader_cls):
            cls._register_scheme(prefixes, loader_cls, force=force)
            return loader_cls

        return _register

    @classmethod
    def _get_checkpoint_loader(cls, path):
        """Find a loader that supports the given path.

        The empty prefix ``''`` matches every path, so a loader registered
        under it acts as the fallback when no other prefix matches.

        Args:
            path (str): checkpoint path

        Returns:
            loader (function): checkpoint loader, or None if no registered
                prefix matches ``path``.
        """
        for prefix, loader in cls._schemes.items():
            if path.startswith(prefix):
                return loader
        return None

    @classmethod
    def load_checkpoint(cls, filename, map_location=None, logger=None):
        """Load a checkpoint through the loader matching the path prefix.

        Args:
            filename (str): checkpoint file name with given prefix
            map_location (str, optional): Same as :func:`torch.load`.
                Default: None
            logger (:mod:`logging.Logger`, optional): The logger for message.
                Default: None

        Returns:
            dict or OrderedDict: The loaded checkpoint.
        """
        checkpoint_loader = cls._get_checkpoint_loader(filename)
        class_name = checkpoint_loader.__name__
        mmcv.print_log(f'Use {class_name} loader', logger)
        return checkpoint_loader(filename, map_location)
@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location=None):
    """Load a checkpoint from a local file path.

    Registered under the empty prefix, which every path starts with.

    Args:
        filename (str): local checkpoint file path
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint, as returned by
            :func:`torch.load`.

    Raises:
        IOError: If ``filename`` is not an existing regular file.
    """
    if not osp.isfile(filename):
        raise IOError(f'{filename} is not a checkpoint file')
    return torch.load(filename, map_location=map_location)
@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(filename, map_location=None, model_dir=None):
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0.
Args:
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional): it's not use.
model_dir (string, optional): directory in which to save the object,
Default: None
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
rank, world_size = get_dist_info() rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank)) rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0: if rank == 0:
checkpoint = model_zoo.load_url(url, model_dir=model_dir) checkpoint = model_zoo.load_url(filename, model_dir=model_dir)
if world_size > 1: if world_size > 1:
torch.distributed.barrier() torch.distributed.barrier()
if rank > 0: if rank > 0:
checkpoint = model_zoo.load_url(url, model_dir=model_dir) checkpoint = model_zoo.load_url(filename, model_dir=model_dir)
return checkpoint return checkpoint
def load_pavimodel_dist(model_path, map_location=None): @CheckpointLoader.register_scheme(prefixes='pavi://')
"""In distributed setting, this function only download checkpoint at local def load_from_pavi(filename, map_location=None):
rank 0.""" """load checkpoint through the file path prefixed with pavi. In distributed
setting, this function only download checkpoint at local rank 0.
Args:
filename (str): checkpoint file path with pavi prefix
map_location (str, optional): Same as :func:`torch.load`.
Default: None
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
assert filename.startswith('pavi://'), \
f'Expected filename startswith `pavi://`, but get {filename}'
model_path = filename[7:]
try: try:
from pavi import modelcloud from pavi import modelcloud
except ImportError: except ImportError:
...@@ -147,9 +330,20 @@ def load_pavimodel_dist(model_path, map_location=None): ...@@ -147,9 +330,20 @@ def load_pavimodel_dist(model_path, map_location=None):
return checkpoint return checkpoint
def load_fileclient_dist(filename, backend, map_location): @CheckpointLoader.register_scheme(prefixes='s3://')
"""In distributed setting, this function only download checkpoint at local def load_from_ceph(filename, map_location=None, backend='ceph'):
rank 0.""" """load checkpoint through the file path prefixed with s3. In distributed
setting, this function only download checkpoint at local rank 0.
Args:
filename (str): checkpoint file path with s3 prefix
map_location (str, optional): Same as :func:`torch.load`.
backend (str): The storage backend type. Options are "disk", "ceph",
"memcached" and "lmdb". Default: 'ceph'
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
rank, world_size = get_dist_info() rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank)) rank = int(os.environ.get('LOCAL_RANK', rank))
allowed_backends = ['ceph'] allowed_backends = ['ceph']
...@@ -168,118 +362,106 @@ def load_fileclient_dist(filename, backend, map_location): ...@@ -168,118 +362,106 @@ def load_fileclient_dist(filename, backend, map_location):
return checkpoint return checkpoint
def get_torchvision_models(): @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
model_urls = dict() def load_from_torchvision(filename, map_location=None):
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): """load checkpoint through the file path prefixed with modelzoo or
if ispkg: torchvision.
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
return model_urls
Args:
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional): it's not use.
def get_external_models(): Returns:
mmcv_home = _get_mmcv_home() dict or OrderedDict: The loaded checkpoint.
default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') """
default_urls = load_file(default_json_path) model_urls = get_torchvision_models()
assert isinstance(default_urls, dict) if filename.startswith('modelzoo://'):
external_json_path = osp.join(mmcv_home, 'open_mmlab.json') warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
if osp.exists(external_json_path): 'use "torchvision://" instead')
external_urls = load_file(external_json_path) model_name = filename[11:]
assert isinstance(external_urls, dict) else:
default_urls.update(external_urls) model_name = filename[14:]
return load_from_http(model_urls[model_name])
return default_urls
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None):
"""load checkpoint through the file path prefixed with open-mmlab or
openmmlab.
def get_mmcls_models(): Args:
mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') filename (str): checkpoint file path with open-mmlab or
mmcls_urls = load_file(mmcls_json_path) openmmlab prefix
map_location (str, optional): Same as :func:`torch.load`.
Default: None
return mmcls_urls Returns:
dict or OrderedDict: The loaded checkpoint.
"""
model_urls = get_external_models()
prefix_str = 'open-mmlab://'
if filename.startswith(prefix_str):
model_name = filename[13:]
else:
model_name = filename[12:]
prefix_str = 'openmmlab://'
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
f'of {prefix_str}{deprecated_urls[model_name]}')
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(('http://', 'https://')):
checkpoint = load_from_http(model_url)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
def get_deprecated_model_names():
deprecate_json_path = osp.join(mmcv.__path__[0],
'model_zoo/deprecated.json')
deprecate_urls = load_file(deprecate_json_path)
assert isinstance(deprecate_urls, dict)
return deprecate_urls @CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None):
"""load checkpoint through the file path prefixed with mmcls.
Args:
filename (str): checkpoint file path with mmcls prefix
map_location (str, optional): it's not use.
def _process_mmcls_checkpoint(checkpoint): Returns:
state_dict = checkpoint['state_dict'] dict or OrderedDict: The loaded checkpoint.
new_state_dict = OrderedDict() """
for k, v in state_dict.items():
if k.startswith('backbone.'):
new_state_dict[k[9:]] = v
new_checkpoint = dict(state_dict=new_state_dict)
return new_checkpoint model_urls = get_mmcls_models()
model_name = filename[8:]
checkpoint = load_from_http(model_urls[model_name])
checkpoint = _process_mmcls_checkpoint(checkpoint)
return checkpoint
def _load_checkpoint(filename, map_location=None): def _load_checkpoint(filename, map_location=None, logger=None):
"""Load checkpoint from somewhere (modelzoo, file, url). """Load checkpoint from somewhere (modelzoo, file, url).
Args: Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``, filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details. details.
map_location (str | None): Same as :func:`torch.load`. Default: None. map_location (str, optional): Same as :func:`torch.load`.
Default: None.
logger (:mod:`logging.Logger`, optional): The logger for error message.
Default: None
Returns: Returns:
dict | OrderedDict: The loaded checkpoint. It can be either an dict or OrderedDict: The loaded checkpoint. It can be either an
OrderedDict storing model weights or a dict containing other OrderedDict storing model weights or a dict containing other
information, which depends on the checkpoint. information, which depends on the checkpoint.
""" """
if filename.startswith('modelzoo://'): return CheckpointLoader.load_checkpoint(filename, map_location, logger)
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead')
model_urls = get_torchvision_models()
model_name = filename[11:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('torchvision://'):
model_urls = get_torchvision_models()
model_name = filename[14:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('open-mmlab://'):
model_urls = get_external_models()
model_name = filename[13:]
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
f'of open-mmlab://{deprecated_urls[model_name]}')
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(('http://', 'https://')):
checkpoint = load_url_dist(model_url)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
elif filename.startswith('mmcls://'):
model_urls = get_mmcls_models()
model_name = filename[8:]
checkpoint = load_url_dist(model_urls[model_name])
checkpoint = _process_mmcls_checkpoint(checkpoint)
elif filename.startswith(('http://', 'https://')):
checkpoint = load_url_dist(filename)
elif filename.startswith('pavi://'):
model_path = filename[7:]
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
elif filename.startswith('s3://'):
checkpoint = load_fileclient_dist(
filename, backend='ceph', map_location=map_location)
else:
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
def load_checkpoint(model, def load_checkpoint(model,
...@@ -302,7 +484,7 @@ def load_checkpoint(model, ...@@ -302,7 +484,7 @@ def load_checkpoint(model,
Returns: Returns:
dict or OrderedDict: The loaded checkpoint. dict or OrderedDict: The loaded checkpoint.
""" """
checkpoint = _load_checkpoint(filename, map_location) checkpoint = _load_checkpoint(filename, map_location, logger)
# OrderedDict is a subclass of dict # OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict): if not isinstance(checkpoint, dict):
raise RuntimeError( raise RuntimeError(
......
...@@ -58,17 +58,22 @@ def test_get_deprecated_models(): ...@@ -58,17 +58,22 @@ def test_get_deprecated_models():
} }
def load_url_dist(url): def load_from_http(url):
return 'url:' + url return 'url:' + url
def load_url(url, model_dir=None):
return load_from_http(url)
def load(filepath, map_location=None): def load(filepath, map_location=None):
return 'local:' + filepath return 'local:' + filepath
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')]) @patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch('mmcv.runner.checkpoint.load_url_dist', load_url_dist) @patch('mmcv.runner.checkpoint.load_from_http', load_from_http)
@patch('torch.load', load) @patch('torch.load', load)
@patch('torch.utils.model_zoo.load_url', load_url)
def test_load_external_url(): def test_load_external_url():
# test modelzoo:// # test modelzoo://
url = _load_checkpoint('modelzoo://resnet50') url = _load_checkpoint('modelzoo://resnet50')
...@@ -96,6 +101,16 @@ def test_load_external_url(): ...@@ -96,6 +101,16 @@ def test_load_external_url():
url = _load_checkpoint('open-mmlab://train_old') url = _load_checkpoint('open-mmlab://train_old')
assert url == 'url:https://localhost/train.pth' assert url == 'url:https://localhost/train.pth'
# test openmmlab:// with deprecated model name
os.environ.pop(ENV_MMCV_HOME, None)
os.environ.pop(ENV_XDG_CACHE_HOME, None)
with pytest.warns(
Warning,
match='openmmlab://train_old is deprecated in favor of '
'openmmlab://train'):
url = _load_checkpoint('openmmlab://train_old')
assert url == 'url:https://localhost/train.pth'
# test open-mmlab:// with user-defined MMCV_HOME # test open-mmlab:// with user-defined MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None) os.environ.pop(ENV_MMCV_HOME, None)
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home') mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home')
......
...@@ -8,7 +8,7 @@ import torch.nn as nn ...@@ -8,7 +8,7 @@ import torch.nn as nn
from torch.nn.parallel import DataParallel from torch.nn.parallel import DataParallel
from mmcv.parallel.registry import MODULE_WRAPPERS from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import get_state_dict, load_pavimodel_dist from mmcv.runner.checkpoint import get_state_dict, load_from_pavi
@MODULE_WRAPPERS.register_module() @MODULE_WRAPPERS.register_module()
...@@ -143,9 +143,13 @@ def test_load_pavimodel_dist(): ...@@ -143,9 +143,13 @@ def test_load_pavimodel_dist():
pavimodel = Mockpavimodel() pavimodel = Mockpavimodel()
import pavi import pavi
pavi.modelcloud.get = MagicMock(return_value=pavimodel) pavi.modelcloud.get = MagicMock(return_value=pavimodel)
with pytest.raises(AssertionError):
# test pavi prefix
_ = load_from_pavi('MyPaviFolder/checkpoint.pth')
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
# there is not such checkpoint for us to load # there is not such checkpoint for us to load
_ = load_pavimodel_dist('MyPaviFolder/checkpoint.pth') _ = load_from_pavi('pavi://checkpoint.pth')
def test_load_classes_name(): def test_load_classes_name():
...@@ -178,3 +182,72 @@ def test_load_classes_name(): ...@@ -178,3 +182,72 @@ def test_load_classes_name():
# remove the temp file # remove the temp file
os.remove(checkpoint_path) os.remove(checkpoint_path)
def test_checkpoint_loader():
    """Check local loading, prefix dispatch and scheme registration.

    NOTE: the schemes registered below ('ftp://', 'a/b', 'a/b/c') are added
    to the class-level registry and are not removed afterwards.
    """
    from mmcv.runner import _load_checkpoint, save_checkpoint, CheckpointLoader
    import tempfile
    import os

    # round-trip a checkpoint through the local-file loader
    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
    model = Model()
    save_checkpoint(model, checkpoint_path)
    checkpoint = _load_checkpoint(checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']
    # remove the temp file
    os.remove(checkpoint_path)

    # (path, name of the loader function the path must dispatch to);
    # the last two have no registered scheme prefix and go to the local loader
    expected_dispatch = [
        ('http://xx.xx/xx.pth', 'load_from_http'),
        ('https://xx.xx/xx.pth', 'load_from_http'),
        ('modelzoo://xx.xx/xx.pth', 'load_from_torchvision'),
        ('torchvision://xx.xx/xx.pth', 'load_from_torchvision'),
        ('open-mmlab://xx.xx/xx.pth', 'load_from_openmmlab'),
        ('openmmlab://xx.xx/xx.pth', 'load_from_openmmlab'),
        ('mmcls://xx.xx/xx.pth', 'load_from_mmcls'),
        ('pavi://xx.xx/xx.pth', 'load_from_pavi'),
        ('s3://xx.xx/xx.pth', 'load_from_ceph'),
        ('ss3://xx.xx/xx.pth', 'load_from_local'),
        (' s3://xx.xx/xx.pth', 'load_from_local'),
    ]
    for path, expected_name in expected_dispatch:
        found = CheckpointLoader._get_checkpoint_loader(path)
        assert found.__name__ == expected_name

    @CheckpointLoader.register_scheme(prefixes='ftp://')
    def load_from_ftp(filename, map_location):
        return dict(filename=filename)

    # test register_loader
    filename = 'ftp://xx.xx/xx.pth'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp'

    def load_from_ftp1(filename, map_location):
        return dict(filename=filename)

    # test duplicate registered error
    with pytest.raises(KeyError):
        CheckpointLoader.register_scheme('ftp://', load_from_ftp1)

    # test force param
    CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True)
    checkpoint = CheckpointLoader.load_checkpoint(filename)
    assert checkpoint['filename'] == filename

    # test print function name
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp1'

    # test sort: the longer of two nested prefixes must win
    @CheckpointLoader.register_scheme(prefixes='a/b')
    def load_from_ab(filename, map_location):
        return dict(filename=filename)

    @CheckpointLoader.register_scheme(prefixes='a/b/c')
    def load_from_abc(filename, map_location):
        return dict(filename=filename)

    filename = 'a/b/c/d'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_abc'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment