Unverified Commit d6ce3a2e authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

add _load_checkpoint() to handle the loading only (#178)

parent 712651aa
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .checkpoint import (load_checkpoint, load_state_dict, save_checkpoint, from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
weights_to_cpu) save_checkpoint, weights_to_cpu)
from .dist_utils import get_dist_info, init_dist, master_only from .dist_utils import get_dist_info, init_dist, master_only
from .hooks import (CheckpointHook, ClosureHook, DistSamplerSeedHook, Hook, from .hooks import (CheckpointHook, ClosureHook, DistSamplerSeedHook, Hook,
IterTimerHook, LoggerHook, LrUpdaterHook, OptimizerHook, IterTimerHook, LoggerHook, LrUpdaterHook, OptimizerHook,
...@@ -16,8 +16,8 @@ __all__ = [ ...@@ -16,8 +16,8 @@ __all__ = [
'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook', 'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook', 'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'WandbLoggerHook', '_load_checkpoint', 'load_state_dict',
'save_checkpoint', 'parallel_test', 'Priority', 'get_priority', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'parallel_test',
'get_host_info', 'get_time_str', 'obj_from_dict', 'init_dist', 'Priority', 'get_priority', 'get_host_info', 'get_time_str',
'get_dist_info', 'master_only' 'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only'
] ]
...@@ -129,25 +129,18 @@ def get_torchvision_models(): ...@@ -129,25 +129,18 @@ def get_torchvision_models():
return model_urls return model_urls
def load_checkpoint(model, def _load_checkpoint(filename, map_location=None):
filename, """Load checkpoint from somewhere (modelzoo, file, url).
map_location=None,
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
Args: Args:
model (Module): Module to load checkpoint. filename (str): Either a filepath or URI.
filename (str): Either a filepath or URL or modelzoo://xxxxxxx. map_location (str | None): Same as :func:`torch.load`. Default: None.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns: Returns:
dict or OrderedDict: The loaded checkpoint. dict | OrderedDict: The loaded checkpoint. It can be either an
OrderedDict storing model weights or a dict containing other
information, which depends on the checkpoint.
""" """
# load checkpoint from modelzoo or file or url
if filename.startswith('modelzoo://'): if filename.startswith('modelzoo://'):
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead') 'use "torchvision://" instead')
...@@ -167,6 +160,28 @@ def load_checkpoint(model, ...@@ -167,6 +160,28 @@ def load_checkpoint(model,
if not osp.isfile(filename): if not osp.isfile(filename):
raise IOError('{} is not a checkpoint file'.format(filename)) raise IOError('{} is not a checkpoint file'.format(filename))
checkpoint = torch.load(filename, map_location=map_location) checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
def load_checkpoint(model,
filename,
map_location=None,
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = _load_checkpoint(filename, map_location)
# get state_dict from checkpoint # get state_dict from checkpoint
if isinstance(checkpoint, OrderedDict): if isinstance(checkpoint, OrderedDict):
state_dict = checkpoint state_dict = checkpoint
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment