Unverified Commit 96892bdc authored by Haian Huang(深度眸), committed by GitHub

Refactor _load_checkpoint fn (#790)

* Refactor _load_checkpoint fn

* Update _load_checkpoint fn

* Update docs str and add unit test

* Fix unit test

* Fix lint

* Add comment and Optimize function

* Fix docs str

* Update load_ckpt and fix doc str

* Update doc str and add sort unit test

* Update and fix unit test

* Fix unit test

* Update and add unit test

* Fix openmmlab prefix error
parent 4450bd2e
# Copyright (c) Open-MMLab. All rights reserved.
from .base_runner import BaseRunner
from .builder import RUNNERS, build_runner
from .checkpoint import (CheckpointLoader, _load_checkpoint, load_checkpoint,
                         load_state_dict, save_checkpoint, weights_to_cpu)
from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
                         init_dist, master_only)
from .epoch_based_runner import EpochBasedRunner, Runner
@@ -33,5 +33,6 @@ __all__ = [
    'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
    'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
    'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
    'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler',
    'CheckpointLoader'
]
@@ -105,69 +105,6 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
            print(err_msg)


def load_url_dist(url, model_dir=None):
    """In distributed setting, this function only downloads the checkpoint at
    local rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    return checkpoint


def load_pavimodel_dist(model_path, map_location=None):
    """In distributed setting, this function only downloads the checkpoint at
    local rank 0."""
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            downloaded_file = osp.join(tmp_dir, model.name)
            model.download(downloaded_file)
            checkpoint = torch.load(downloaded_file, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            model = modelcloud.get(model_path)
            with TemporaryDirectory() as tmp_dir:
                downloaded_file = osp.join(tmp_dir, model.name)
                model.download(downloaded_file)
                checkpoint = torch.load(
                    downloaded_file, map_location=map_location)
    return checkpoint


def load_fileclient_dist(filename, backend, map_location):
    """In distributed setting, this function only downloads the checkpoint at
    local rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')
    if rank == 0:
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        checkpoint = torch.load(buffer, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            fileclient = FileClient(backend=backend)
            buffer = io.BytesIO(fileclient.get(filename))
            checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint


def get_torchvision_models():
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
@@ -221,67 +158,312 @@ def _process_mmcls_checkpoint(checkpoint):
    return new_checkpoint


class CheckpointLoader:
    """A general checkpoint loader to manage all schemes."""

    _schemes = {}

    @classmethod
    def _register_scheme(cls, prefixes, loader, force=False):
        if isinstance(prefixes, str):
            prefixes = [prefixes]
        else:
            assert isinstance(prefixes, (list, tuple))
        for prefix in prefixes:
            if (prefix not in cls._schemes) or force:
                cls._schemes[prefix] = loader
            else:
                raise KeyError(
                    f'{prefix} is already registered as a loader backend, '
                    'add "force=True" if you want to override it')
        # sort, longer prefixes take priority
        cls._schemes = OrderedDict(
            sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))

    @classmethod
    def register_scheme(cls, prefixes, loader=None, force=False):
        """Register a loader to CheckpointLoader.

        This method can be used as a normal class method or a decorator.

        Args:
            prefixes (str or list[str] or tuple[str]):
                The prefix of the registered loader.
            loader (function, optional): The loader function to be registered.
                When this method is used as a decorator, loader is None.
                Defaults to None.
            force (bool, optional): Whether to override the loader
                if the prefix has already been registered. Defaults to False.
        """
        if loader is not None:
            cls._register_scheme(prefixes, loader, force=force)
            return

        def _register(loader_cls):
            cls._register_scheme(prefixes, loader_cls, force=force)
            return loader_cls

        return _register

    @classmethod
    def _get_checkpoint_loader(cls, path):
        """Find a loader that supports the given path. Fall back to the local
        loader if no other loader is found.

        Args:
            path (str): checkpoint path

        Returns:
            loader (function): checkpoint loader
        """
        for p in cls._schemes:
            if path.startswith(p):
                return cls._schemes[p]

    @classmethod
    def load_checkpoint(cls, filename, map_location=None, logger=None):
        """Load checkpoint through the loader registered for its scheme.

        Args:
            filename (str): checkpoint file name with a given prefix
            map_location (str, optional): Same as :func:`torch.load`.
                Default: None
            logger (:mod:`logging.Logger`, optional): The logger for message.
                Default: None

        Returns:
            dict or OrderedDict: The loaded checkpoint.
        """
        checkpoint_loader = cls._get_checkpoint_loader(filename)
        class_name = checkpoint_loader.__name__
        mmcv.print_log(f'Use {class_name} loader', logger)
        return checkpoint_loader(filename, map_location)
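

# Illustrative usage sketch, not part of this commit: a downstream user can
# register a loader for a new scheme (the 'gs://' prefix and the toy loader
# below are hypothetical). register_scheme works both as a decorator and as a
# plain call, and _get_checkpoint_loader picks the longest registered prefix
# that matches the path.
@CheckpointLoader.register_scheme(prefixes='gs://')
def load_from_gs(filename, map_location=None):
    """Toy loader that merely echoes the path instead of downloading it."""
    return dict(filename=filename)


loader = CheckpointLoader._get_checkpoint_loader('gs://bucket/model.pth')
assert loader.__name__ == 'load_from_gs'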


@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location):
    """Load checkpoint from a local file path.

    Args:
        filename (str): local checkpoint file path
        map_location (str, optional): Same as :func:`torch.load`.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    if not osp.isfile(filename):
        raise IOError(f'{filename} is not a checkpoint file')
    checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint


@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(filename, map_location=None, model_dir=None):
    """Load checkpoint through an HTTP or HTTPS scheme path. In a distributed
    setting, this function downloads the checkpoint only at local rank 0.

    Args:
        filename (str): checkpoint url with http or https prefix
        map_location (str, optional): it is not used in this loader.
        model_dir (str, optional): directory in which to save the object.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(filename, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(filename, model_dir=model_dir)
    return checkpoint


@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
    """Load checkpoint through the file path prefixed with pavi. In a
    distributed setting, this function downloads the checkpoint only at local
    rank 0.

    Args:
        filename (str): checkpoint file path with pavi prefix
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    assert filename.startswith('pavi://'), \
        f'Expected filename startswith `pavi://`, but got {filename}'
    model_path = filename[7:]

    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')

    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            downloaded_file = osp.join(tmp_dir, model.name)
            model.download(downloaded_file)
            checkpoint = torch.load(downloaded_file, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            model = modelcloud.get(model_path)
            with TemporaryDirectory() as tmp_dir:
                downloaded_file = osp.join(tmp_dir, model.name)
                model.download(downloaded_file)
                checkpoint = torch.load(
                    downloaded_file, map_location=map_location)
    return checkpoint


@CheckpointLoader.register_scheme(prefixes='s3://')
def load_from_ceph(filename, map_location=None, backend='ceph'):
    """Load checkpoint through the file path prefixed with s3. In a
    distributed setting, this function downloads the checkpoint only at local
    rank 0.

    Args:
        filename (str): checkpoint file path with s3 prefix
        map_location (str, optional): Same as :func:`torch.load`.
        backend (str): The storage backend type; only 'ceph' is currently
            supported here. Default: 'ceph'

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')
    if rank == 0:
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        checkpoint = torch.load(buffer, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            fileclient = FileClient(backend=backend)
            buffer = io.BytesIO(fileclient.get(filename))
            checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint
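

# The http, pavi and s3 loaders above share one distributed download pattern:
# rank 0 fetches the checkpoint while the other ranks wait at a barrier, then
# read from the now-populated cache. A condensed sketch of that pattern; this
# helper and its fetch_fn argument are illustrative, not part of mmcv:
def _fetch_on_rank0_first(fetch_fn):
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        result = fetch_fn()  # only rank 0 triggers the actual download
    if world_size > 1:
        torch.distributed.barrier()  # everyone waits until rank 0 is done
        if rank > 0:
            result = fetch_fn()  # other ranks now hit the local cache
    return result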


@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None):
    """Load checkpoint through the file path prefixed with modelzoo or
    torchvision.

    Args:
        filename (str): checkpoint file path with modelzoo or
            torchvision prefix
        map_location (str, optional): it is not used in this loader.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    model_urls = get_torchvision_models()
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_name = filename[11:]
    else:
        model_name = filename[14:]
    return load_from_http(model_urls[model_name])


@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None):
    """Load checkpoint through the file path prefixed with open-mmlab or
    openmmlab.

    Args:
        filename (str): checkpoint file path with open-mmlab or
            openmmlab prefix
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    model_urls = get_external_models()
    prefix_str = 'open-mmlab://'
    if filename.startswith(prefix_str):
        model_name = filename[13:]
    else:
        model_name = filename[12:]
        prefix_str = 'openmmlab://'
    deprecated_urls = get_deprecated_model_names()
    if model_name in deprecated_urls:
        warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
                      f'of {prefix_str}{deprecated_urls[model_name]}')
        model_name = deprecated_urls[model_name]
    model_url = model_urls[model_name]
    # check if model_url is a url
    if model_url.startswith(('http://', 'https://')):
        checkpoint = load_from_http(model_url)
    else:
        filename = osp.join(_get_mmcv_home(), model_url)
        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint


@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None):
    """Load checkpoint through the file path prefixed with mmcls.

    Args:
        filename (str): checkpoint file path with mmcls prefix
        map_location (str, optional): it is not used in this loader.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    model_urls = get_mmcls_models()
    model_name = filename[8:]
    checkpoint = load_from_http(model_urls[model_name])
    checkpoint = _process_mmcls_checkpoint(checkpoint)
    return checkpoint


def _load_checkpoint(filename, map_location=None, logger=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None.
        logger (:mod:`logging.Logger`, optional): The logger for error message.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.
    """
    return CheckpointLoader.load_checkpoint(filename, map_location, logger)
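

# Dispatch sketch, not part of this commit: each illustrative path below is
# routed to a different registered loader; no download is involved.
for _path, _expected in [('torchvision://resnet50', 'load_from_torchvision'),
                         ('https://host/model.pth', 'load_from_http'),
                         ('work_dir/epoch_12.pth', 'load_from_local')]:
    assert CheckpointLoader._get_checkpoint_loader(_path).__name__ == _expected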


def load_checkpoint(model,
                    filename,
                    map_location=None,
@@ -302,7 +484,7 @@ def load_checkpoint(model,
    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location, logger)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
...
@@ -58,17 +58,22 @@ def test_get_deprecated_models():
    }


def load_from_http(url):
    return 'url:' + url


def load_url(url, model_dir=None):
    return load_from_http(url)


def load(filepath, map_location=None):
    return 'local:' + filepath


@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch('mmcv.runner.checkpoint.load_from_http', load_from_http)
@patch('torch.load', load)
@patch('torch.utils.model_zoo.load_url', load_url)
def test_load_external_url():
    # test modelzoo://
    url = _load_checkpoint('modelzoo://resnet50')
@@ -96,6 +101,16 @@ def test_load_external_url():
    url = _load_checkpoint('open-mmlab://train_old')
    assert url == 'url:https://localhost/train.pth'

    # test openmmlab:// with deprecated model name
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    with pytest.warns(
            Warning,
            match='openmmlab://train_old is deprecated in favor of '
            'openmmlab://train'):
        url = _load_checkpoint('openmmlab://train_old')
    assert url == 'url:https://localhost/train.pth'

    # test open-mmlab:// with user-defined MMCV_HOME
    os.environ.pop(ENV_MMCV_HOME, None)
    mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home')
...
@@ -8,7 +8,7 @@ import torch.nn as nn
from torch.nn.parallel import DataParallel

from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import get_state_dict, load_from_pavi


@MODULE_WRAPPERS.register_module()
@@ -143,9 +143,13 @@ def test_load_pavimodel_dist():
    pavimodel = Mockpavimodel()
    import pavi
    pavi.modelcloud.get = MagicMock(return_value=pavimodel)
    with pytest.raises(AssertionError):
        # test pavi prefix
        _ = load_from_pavi('MyPaviFolder/checkpoint.pth')

    with pytest.raises(FileNotFoundError):
        # there is no such checkpoint for us to load
        _ = load_from_pavi('pavi://checkpoint.pth')


def test_load_classes_name():
@@ -178,3 +182,72 @@ def test_load_classes_name():
    # remove the temp file
    os.remove(checkpoint_path)


def test_checkpoint_loader():
    from mmcv.runner import (CheckpointLoader, _load_checkpoint,
                             save_checkpoint)
    import tempfile
    import os

    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
    model = Model()
    save_checkpoint(model, checkpoint_path)
    checkpoint = _load_checkpoint(checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']
    # remove the temp file
    os.remove(checkpoint_path)

    filenames = [
        'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth',
        'modelzoo://xx.xx/xx.pth', 'torchvision://xx.xx/xx.pth',
        'open-mmlab://xx.xx/xx.pth', 'openmmlab://xx.xx/xx.pth',
        'mmcls://xx.xx/xx.pth', 'pavi://xx.xx/xx.pth', 's3://xx.xx/xx.pth',
        'ss3://xx.xx/xx.pth', ' s3://xx.xx/xx.pth'
    ]
    fn_names = [
        'load_from_http', 'load_from_http', 'load_from_torchvision',
        'load_from_torchvision', 'load_from_openmmlab', 'load_from_openmmlab',
        'load_from_mmcls', 'load_from_pavi', 'load_from_ceph',
        'load_from_local', 'load_from_local'
    ]
    for filename, fn_name in zip(filenames, fn_names):
        loader = CheckpointLoader._get_checkpoint_loader(filename)
        assert loader.__name__ == fn_name

    @CheckpointLoader.register_scheme(prefixes='ftp://')
    def load_from_ftp(filename, map_location):
        return dict(filename=filename)

    # test register_loader
    filename = 'ftp://xx.xx/xx.pth'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp'

    def load_from_ftp1(filename, map_location):
        return dict(filename=filename)

    # test duplicate registered error
    with pytest.raises(KeyError):
        CheckpointLoader.register_scheme('ftp://', load_from_ftp1)

    # test force param
    CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True)
    checkpoint = CheckpointLoader.load_checkpoint(filename)
    assert checkpoint['filename'] == filename

    # test print function name
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp1'

    # test sort
    @CheckpointLoader.register_scheme(prefixes='a/b')
    def load_from_ab(filename, map_location):
        return dict(filename=filename)

    @CheckpointLoader.register_scheme(prefixes='a/b/c')
    def load_from_abc(filename, map_location):
        return dict(filename=filename)

    filename = 'a/b/c/d'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_abc'