"docs/vscode:/vscode.git/clone" did not exist on "0b3ddec6540d7fc7fb59c1b6184a5e6c9e1d32e0"
Unverified Commit e076c8b0 authored by Ycr's avatar Ycr Committed by GitHub
Browse files

[Feature] Support revise_keys in load_checkpoint(). (#829)

* Simplified the code.

* Improved chkpt compatibility.

* One may modify the checkpoint via adding keywords.

* Tiny.

* Following reviewer's suggestion.

* Added unit_test.

* Fixed.

* Modify the state_dict  with  construction.

* Added test.

* Modified。

* Mimimalised the modification.

* Added the docstring.

* Format.

* Improved.

* Tiny.

* Temp file.

* Added assertion.

* Doc string.

* Fixed.
parent 34b552b8
......@@ -310,6 +310,7 @@ class PretrainedInit(object):
initialize. For example, if we would like to only load the
backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None.
map_location (str): map tensors into proper locations.
"""
def __init__(self, checkpoint, prefix=None, map_location=None):
......
......@@ -103,7 +103,6 @@ class BaseRunner(metaclass=ABCMeta):
self.optimizer = optimizer
self.logger = logger
self.meta = meta
# create work_dir
if mmcv.is_str(work_dir):
self.work_dir = osp.abspath(work_dir)
......@@ -307,10 +306,20 @@ class BaseRunner(metaclass=ABCMeta):
for hook in self._hooks:
getattr(hook, fn_name)(self)
def load_checkpoint(self, filename, map_location='cpu', strict=False):
def load_checkpoint(self,
filename,
map_location='cpu',
strict=False,
revise_keys=[(r'^module.', '')]):
self.logger.info('load checkpoint from %s', filename)
return load_checkpoint(self.model, filename, map_location, strict,
self.logger)
return load_checkpoint(
self.model,
filename,
map_location,
strict,
self.logger,
revise_keys=revise_keys)
def resume(self,
checkpoint,
......
......@@ -3,6 +3,7 @@ import io
import os
import os.path as osp
import pkgutil
import re
import time
import warnings
from collections import OrderedDict
......@@ -503,7 +504,8 @@ def load_checkpoint(model,
filename,
map_location=None,
strict=False,
logger=None):
logger=None,
revise_keys=[(r'^module\.', '')]):
"""Load checkpoint from a file or URI.
Args:
......@@ -515,6 +517,11 @@ def load_checkpoint(model,
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
revise_keys (list): A list of customized keywords to modify the
state_dict in checkpoint. Each item is a (pattern, replacement)
pair of the regular expression operations. Default: strip
the prefix 'module.' by [(r'^module\\.', '')].
Returns:
dict or OrderedDict: The loaded checkpoint.
......@@ -530,8 +537,8 @@ def load_checkpoint(model,
else:
state_dict = checkpoint
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
for p, r in revise_keys:
state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}
# load state_dict
load_state_dict(model, state_dict, strict, logger)
return checkpoint
......
......@@ -10,7 +10,8 @@ from torch.nn.parallel import DataParallel
from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
get_state_dict, load_from_pavi)
get_state_dict, load_checkpoint,
load_from_pavi)
@MODULE_WRAPPERS.register_module()
......@@ -188,6 +189,38 @@ def test_load_checkpoint_with_prefix():
_load_checkpoint_with_prefix(prefix, 'model.pth')
def test_load_checkpoint():
import os
import tempfile
import re
class PrefixModel(nn.Module):
def __init__(self):
super().__init__()
self.backbone = Model()
pmodel = PrefixModel()
model = Model()
checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
# add prefix
torch.save(model.state_dict(), checkpoint_path)
state_dict = load_checkpoint(
pmodel, checkpoint_path, revise_keys=[(r'^', 'backbone.')])
for key in pmodel.backbone.state_dict().keys():
assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key])
# strip prefix
torch.save(pmodel.state_dict(), checkpoint_path)
state_dict = load_checkpoint(
model, checkpoint_path, revise_keys=[(r'^backbone\.', '')])
for key in state_dict.keys():
key_stripped = re.sub(r'^backbone\.', '', key)
assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
os.remove(checkpoint_path)
def test_load_classes_name():
import os
......
......@@ -6,6 +6,7 @@ CommandLine:
"""
import logging
import os.path as osp
import re
import shutil
import sys
import tempfile
......@@ -415,3 +416,42 @@ def _build_demo_runner(runner_type='EpochBasedRunner',
runner.register_checkpoint_hook(dict(interval=1))
runner.register_logger_hooks(log_config)
return runner
def test_runner_with_revise_keys():
import os
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
class PrefixModel(nn.Module):
def __init__(self):
super().__init__()
self.backbone = Model()
pmodel = PrefixModel()
model = Model()
checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
# add prefix
torch.save(model.state_dict(), checkpoint_path)
runner = _build_demo_runner(runner_type='EpochBasedRunner')
runner.model = pmodel
state_dict = runner.load_checkpoint(
checkpoint_path, revise_keys=[(r'^', 'backbone.')])
for key in pmodel.backbone.state_dict().keys():
assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key])
# strip prefix
torch.save(pmodel.state_dict(), checkpoint_path)
runner.model = model
state_dict = runner.load_checkpoint(
checkpoint_path, revise_keys=[(r'^backbone\.', '')])
for key in state_dict.keys():
key_stripped = re.sub(r'^backbone\.', '', key)
assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
os.remove(checkpoint_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment