Unverified Commit e076c8b0 authored by Ycr's avatar Ycr Committed by GitHub
Browse files

[Feature] Support revise_keys in load_checkpoint(). (#829)

* Simplified the code.

* Improved chkpt compatibility.

* One may modify the checkpoint via adding keywords.

* Tiny.

* Following reviewer's suggestion.

* Added unit_test.

* Fixed.

* Modify the state_dict with construction.

* Added test.

* Modified.

* Minimalised the modification.

* Added the docstring.

* Format.

* Improved.

* Tiny.

* Temp file.

* Added assertion.

* Doc string.

* Fixed.
parent 34b552b8
...@@ -310,6 +310,7 @@ class PretrainedInit(object): ...@@ -310,6 +310,7 @@ class PretrainedInit(object):
initialize. For example, if we would like to only load the initialize. For example, if we would like to only load the
backbone of a detector model, we can set ``prefix='backbone.'``. backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None. Defaults to None.
map_location (str): map tensors into proper locations.
""" """
def __init__(self, checkpoint, prefix=None, map_location=None): def __init__(self, checkpoint, prefix=None, map_location=None):
......
...@@ -103,7 +103,6 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -103,7 +103,6 @@ class BaseRunner(metaclass=ABCMeta):
self.optimizer = optimizer self.optimizer = optimizer
self.logger = logger self.logger = logger
self.meta = meta self.meta = meta
# create work_dir # create work_dir
if mmcv.is_str(work_dir): if mmcv.is_str(work_dir):
self.work_dir = osp.abspath(work_dir) self.work_dir = osp.abspath(work_dir)
...@@ -307,10 +306,20 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -307,10 +306,20 @@ class BaseRunner(metaclass=ABCMeta):
for hook in self._hooks: for hook in self._hooks:
getattr(hook, fn_name)(self) getattr(hook, fn_name)(self)
def load_checkpoint(self, filename, map_location='cpu', strict=False): def load_checkpoint(self,
filename,
map_location='cpu',
strict=False,
revise_keys=[(r'^module.', '')]):
self.logger.info('load checkpoint from %s', filename) self.logger.info('load checkpoint from %s', filename)
return load_checkpoint(self.model, filename, map_location, strict, return load_checkpoint(
self.logger) self.model,
filename,
map_location,
strict,
self.logger,
revise_keys=revise_keys)
def resume(self, def resume(self,
checkpoint, checkpoint,
......
...@@ -3,6 +3,7 @@ import io ...@@ -3,6 +3,7 @@ import io
import os import os
import os.path as osp import os.path as osp
import pkgutil import pkgutil
import re
import time import time
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
...@@ -503,7 +504,8 @@ def load_checkpoint(model, ...@@ -503,7 +504,8 @@ def load_checkpoint(model,
filename, filename,
map_location=None, map_location=None,
strict=False, strict=False,
logger=None): logger=None,
revise_keys=[(r'^module\.', '')]):
"""Load checkpoint from a file or URI. """Load checkpoint from a file or URI.
Args: Args:
...@@ -515,6 +517,11 @@ def load_checkpoint(model, ...@@ -515,6 +517,11 @@ def load_checkpoint(model,
strict (bool): Whether to allow different params for the model and strict (bool): Whether to allow different params for the model and
checkpoint. checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message. logger (:mod:`logging.Logger` or None): The logger for error message.
revise_keys (list): A list of customized keywords to modify the
state_dict in checkpoint. Each item is a (pattern, replacement)
pair of the regular expression operations. Default: strip
the prefix 'module.' by [(r'^module\\.', '')].
Returns: Returns:
dict or OrderedDict: The loaded checkpoint. dict or OrderedDict: The loaded checkpoint.
...@@ -530,8 +537,8 @@ def load_checkpoint(model, ...@@ -530,8 +537,8 @@ def load_checkpoint(model,
else: else:
state_dict = checkpoint state_dict = checkpoint
# strip prefix of state_dict # strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'): for p, r in revise_keys:
state_dict = {k[7:]: v for k, v in state_dict.items()} state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}
# load state_dict # load state_dict
load_state_dict(model, state_dict, strict, logger) load_state_dict(model, state_dict, strict, logger)
return checkpoint return checkpoint
......
...@@ -10,7 +10,8 @@ from torch.nn.parallel import DataParallel ...@@ -10,7 +10,8 @@ from torch.nn.parallel import DataParallel
from mmcv.parallel.registry import MODULE_WRAPPERS from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix, from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
get_state_dict, load_from_pavi) get_state_dict, load_checkpoint,
load_from_pavi)
@MODULE_WRAPPERS.register_module() @MODULE_WRAPPERS.register_module()
...@@ -188,6 +189,38 @@ def test_load_checkpoint_with_prefix(): ...@@ -188,6 +189,38 @@ def test_load_checkpoint_with_prefix():
_load_checkpoint_with_prefix(prefix, 'model.pth') _load_checkpoint_with_prefix(prefix, 'model.pth')
def test_load_checkpoint():
    """Test that ``load_checkpoint`` applies ``revise_keys`` patterns.

    Covers both directions:

    * adding a prefix: keys of a plain ``Model`` checkpoint are rewritten
      with ``'backbone.'`` so they load into ``PrefixModel.backbone``;
    * stripping a prefix: keys of a ``PrefixModel`` checkpoint have
      ``'backbone.'`` removed so they load into a plain ``Model``.
    """
    import os
    import re
    import tempfile

    class PrefixModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.backbone = Model()

    pmodel = PrefixModel()
    model = Model()
    # A private temporary directory avoids filename collisions between
    # concurrent test runs and guarantees cleanup even if an assert fails
    # (the original fixed path in gettempdir() leaked on failure).
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, 'checkpoint.pth')

        # add prefix: plain keys -> 'backbone.'-prefixed keys
        torch.save(model.state_dict(), checkpoint_path)
        state_dict = load_checkpoint(
            pmodel, checkpoint_path, revise_keys=[(r'^', 'backbone.')])
        for key in pmodel.backbone.state_dict().keys():
            assert torch.equal(pmodel.backbone.state_dict()[key],
                               state_dict[key])

        # strip prefix: 'backbone.'-prefixed keys -> plain keys
        torch.save(pmodel.state_dict(), checkpoint_path)
        state_dict = load_checkpoint(
            model, checkpoint_path, revise_keys=[(r'^backbone\.', '')])
        for key in state_dict.keys():
            key_stripped = re.sub(r'^backbone\.', '', key)
            assert torch.equal(model.state_dict()[key_stripped],
                               state_dict[key])
def test_load_classes_name(): def test_load_classes_name():
import os import os
......
...@@ -6,6 +6,7 @@ CommandLine: ...@@ -6,6 +6,7 @@ CommandLine:
""" """
import logging import logging
import os.path as osp import os.path as osp
import re
import shutil import shutil
import sys import sys
import tempfile import tempfile
...@@ -415,3 +416,42 @@ def _build_demo_runner(runner_type='EpochBasedRunner', ...@@ -415,3 +416,42 @@ def _build_demo_runner(runner_type='EpochBasedRunner',
runner.register_checkpoint_hook(dict(interval=1)) runner.register_checkpoint_hook(dict(interval=1))
runner.register_logger_hooks(log_config) runner.register_logger_hooks(log_config)
return runner return runner
def test_runner_with_revise_keys():
    """Test ``BaseRunner.load_checkpoint`` forwarding of ``revise_keys``.

    Builds a demo runner, then checks both directions of key rewriting:

    * adding a prefix: a plain ``Model`` checkpoint loads into
      ``PrefixModel`` via ``revise_keys=[(r'^', 'backbone.')]``;
    * stripping a prefix: a ``PrefixModel`` checkpoint loads into a plain
      ``Model`` via ``revise_keys=[(r'^backbone\\.', '')]``.
    """
    import os

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

    class PrefixModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.backbone = Model()

    pmodel = PrefixModel()
    model = Model()
    # A private temporary directory avoids filename collisions between
    # concurrent test runs and guarantees cleanup even if an assert fails
    # (the original fixed path in gettempdir() leaked on failure).
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, 'checkpoint.pth')

        # add prefix: plain keys -> 'backbone.'-prefixed keys
        torch.save(model.state_dict(), checkpoint_path)
        runner = _build_demo_runner(runner_type='EpochBasedRunner')
        runner.model = pmodel
        state_dict = runner.load_checkpoint(
            checkpoint_path, revise_keys=[(r'^', 'backbone.')])
        for key in pmodel.backbone.state_dict().keys():
            assert torch.equal(pmodel.backbone.state_dict()[key],
                               state_dict[key])

        # strip prefix: 'backbone.'-prefixed keys -> plain keys
        torch.save(pmodel.state_dict(), checkpoint_path)
        runner.model = model
        state_dict = runner.load_checkpoint(
            checkpoint_path, revise_keys=[(r'^backbone\.', '')])
        for key in state_dict.keys():
            key_stripped = re.sub(r'^backbone\.', '', key)
            assert torch.equal(model.state_dict()[key_stripped],
                               state_dict[key])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment