Unverified Commit d6ce3a2e authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

add _load_checkpoint() to handle the loading only (#178)

parent 712651aa
# Copyright (c) Open-MMLab. All rights reserved.
from .checkpoint import (load_checkpoint, load_state_dict, save_checkpoint,
weights_to_cpu)
from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
save_checkpoint, weights_to_cpu)
from .dist_utils import get_dist_info, init_dist, master_only
from .hooks import (CheckpointHook, ClosureHook, DistSamplerSeedHook, Hook,
IterTimerHook, LoggerHook, LrUpdaterHook, OptimizerHook,
......@@ -16,8 +16,8 @@ __all__ = [
'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
'save_checkpoint', 'parallel_test', 'Priority', 'get_priority',
'get_host_info', 'get_time_str', 'obj_from_dict', 'init_dist',
'get_dist_info', 'master_only'
'WandbLoggerHook', '_load_checkpoint', 'load_state_dict',
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'parallel_test',
'Priority', 'get_priority', 'get_host_info', 'get_time_str',
'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only'
]
......@@ -129,25 +129,18 @@ def get_torchvision_models():
return model_urls
def load_checkpoint(model,
filename,
map_location=None,
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
def _load_checkpoint(filename, map_location=None):
"""Load checkpoint from somewhere (modelzoo, file, url).
Args:
model (Module): Module to load checkpoint.
filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
filename (str): Either a filepath or URI.
map_location (str | None): Same as :func:`torch.load`. Default: None.
Returns:
dict or OrderedDict: The loaded checkpoint.
dict | OrderedDict: The loaded checkpoint. It can be either an
OrderedDict storing model weights or a dict containing other
information, which depends on the checkpoint.
"""
# load checkpoint from modelzoo or file or url
if filename.startswith('modelzoo://'):
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead')
......@@ -167,6 +160,28 @@ def load_checkpoint(model,
if not osp.isfile(filename):
raise IOError('{} is not a checkpoint file'.format(filename))
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
def load_checkpoint(model,
filename,
map_location=None,
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = _load_checkpoint(filename, map_location)
# get state_dict from checkpoint
if isinstance(checkpoint, OrderedDict):
state_dict = checkpoint
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment