Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
d6ce3a2e
Unverified
Commit
d6ce3a2e
authored
Feb 03, 2020
by
Kai Chen
Committed by
GitHub
Feb 03, 2020
Browse files
add _load_checkpoint() to handle the loading only (#178)
parent
712651aa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
20 deletions
+35
-20
mmcv/runner/__init__.py
mmcv/runner/__init__.py
+6
-6
mmcv/runner/checkpoint.py
mmcv/runner/checkpoint.py
+29
-14
No files found.
mmcv/runner/__init__.py
View file @
d6ce3a2e
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
.checkpoint
import
(
load_checkpoint
,
load_state_dict
,
save_checkpoint
,
from
.checkpoint
import
(
_load_checkpoint
,
load_checkpoint
,
load_state_dict
,
weights_to_cpu
)
save_checkpoint
,
weights_to_cpu
)
from
.dist_utils
import
get_dist_info
,
init_dist
,
master_only
from
.dist_utils
import
get_dist_info
,
init_dist
,
master_only
from
.hooks
import
(
CheckpointHook
,
ClosureHook
,
DistSamplerSeedHook
,
Hook
,
from
.hooks
import
(
CheckpointHook
,
ClosureHook
,
DistSamplerSeedHook
,
Hook
,
IterTimerHook
,
LoggerHook
,
LrUpdaterHook
,
OptimizerHook
,
IterTimerHook
,
LoggerHook
,
LrUpdaterHook
,
OptimizerHook
,
...
@@ -16,8 +16,8 @@ __all__ = [
...
@@ -16,8 +16,8 @@ __all__ = [
'Runner'
,
'LogBuffer'
,
'Hook'
,
'CheckpointHook'
,
'ClosureHook'
,
'Runner'
,
'LogBuffer'
,
'Hook'
,
'CheckpointHook'
,
'ClosureHook'
,
'LrUpdaterHook'
,
'OptimizerHook'
,
'IterTimerHook'
,
'DistSamplerSeedHook'
,
'LrUpdaterHook'
,
'OptimizerHook'
,
'IterTimerHook'
,
'DistSamplerSeedHook'
,
'LoggerHook'
,
'TextLoggerHook'
,
'PaviLoggerHook'
,
'TensorboardLoggerHook'
,
'LoggerHook'
,
'TextLoggerHook'
,
'PaviLoggerHook'
,
'TensorboardLoggerHook'
,
'WandbLoggerHook'
,
'
load_state_dict'
,
'
load_checkpoint'
,
'
weights_to_cpu
'
,
'WandbLoggerHook'
,
'
_
load_checkpoint'
,
'
load_state_dict
'
,
'
save
_checkpoint'
,
'
parallel_test'
,
'Priority'
,
'get_priority
'
,
'
load
_checkpoint'
,
'
weights_to_cpu'
,
'save_checkpoint'
,
'parallel_test
'
,
'
get_host_info'
,
'get_time_str'
,
'obj_from_dict'
,
'init_di
st'
,
'
Priority'
,
'get_priority'
,
'get_host_info'
,
'get_time_
st
r
'
,
'get_dist_info'
,
'master_only'
'obj_from_dict'
,
'init_dist'
,
'get_dist_info'
,
'master_only'
]
]
mmcv/runner/checkpoint.py
View file @
d6ce3a2e
...
@@ -129,25 +129,18 @@ def get_torchvision_models():
...
@@ -129,25 +129,18 @@ def get_torchvision_models():
return
model_urls
return
model_urls
def
load_checkpoint
(
model
,
def
_load_checkpoint
(
filename
,
map_location
=
None
):
filename
,
"""Load checkpoint from somewhere (modelzoo, file, url).
map_location
=
None
,
strict
=
False
,
logger
=
None
):
"""Load checkpoint from a file or URI.
Args:
Args:
model (Module): Module to load checkpoint.
filename (str): Either a filepath or URI.
filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
map_location (str | None): Same as :func:`torch.load`. Default: None.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
Returns:
dict or OrderedDict: The loaded checkpoint.
dict | OrderedDict: The loaded checkpoint. It can be either an
OrderedDict storing model weights or a dict containing other
information, which depends on the checkpoint.
"""
"""
# load checkpoint from modelzoo or file or url
if
filename
.
startswith
(
'modelzoo://'
):
if
filename
.
startswith
(
'modelzoo://'
):
warnings
.
warn
(
'The URL scheme of "modelzoo://" is deprecated, please '
warnings
.
warn
(
'The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead'
)
'use "torchvision://" instead'
)
...
@@ -167,6 +160,28 @@ def load_checkpoint(model,
...
@@ -167,6 +160,28 @@ def load_checkpoint(model,
if
not
osp
.
isfile
(
filename
):
if
not
osp
.
isfile
(
filename
):
raise
IOError
(
'{} is not a checkpoint file'
.
format
(
filename
))
raise
IOError
(
'{} is not a checkpoint file'
.
format
(
filename
))
checkpoint
=
torch
.
load
(
filename
,
map_location
=
map_location
)
checkpoint
=
torch
.
load
(
filename
,
map_location
=
map_location
)
return
checkpoint
def
load_checkpoint
(
model
,
filename
,
map_location
=
None
,
strict
=
False
,
logger
=
None
):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint
=
_load_checkpoint
(
filename
,
map_location
)
# get state_dict from checkpoint
# get state_dict from checkpoint
if
isinstance
(
checkpoint
,
OrderedDict
):
if
isinstance
(
checkpoint
,
OrderedDict
):
state_dict
=
checkpoint
state_dict
=
checkpoint
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment