Unverified Commit f5c28a6e authored by Ma Zerun's avatar Ma Zerun Committed by GitHub
Browse files

[Fix] Fix missing `state_dict._metadata` when saving and loading checkpoints. (#1294)

* Fix missing `state_dict._metadata` when saving & loading checkpoints.

* Add unit tests.

* Fix default value and variable names in unit tests.
parent 9016fef5
......@@ -520,9 +520,16 @@ def load_checkpoint(model,
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
# strip prefix of state_dict
metadata = getattr(state_dict, '_metadata', OrderedDict())
for p, r in revise_keys:
state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}
state_dict = OrderedDict(
{re.sub(p, r, k): v
for k, v in state_dict.items()})
# Keep metadata in state_dict
state_dict._metadata = metadata
# load state_dict
load_state_dict(model, state_dict, strict, logger)
return checkpoint
......@@ -540,6 +547,8 @@ def weights_to_cpu(state_dict):
state_dict_cpu = OrderedDict()
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
# Keep metadata in state_dict
state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
return state_dict_cpu
......
......@@ -221,6 +221,76 @@ def test_load_checkpoint():
os.remove(checkpoint_path)
def test_load_checkpoint_metadata():
import os
import tempfile
from mmcv.runner import load_checkpoint, save_checkpoint
class ModelV1(nn.Module):
def __init__(self):
super().__init__()
self.block = Block()
self.conv1 = nn.Conv2d(3, 3, 1)
self.conv2 = nn.Conv2d(3, 3, 1)
nn.init.normal_(self.conv1.weight)
nn.init.normal_(self.conv2.weight)
class ModelV2(nn.Module):
_version = 2
def __init__(self):
super().__init__()
self.block = Block()
self.conv0 = nn.Conv2d(3, 3, 1)
self.conv1 = nn.Conv2d(3, 3, 1)
nn.init.normal_(self.conv0.weight)
nn.init.normal_(self.conv1.weight)
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
*args, **kwargs):
"""load checkpoints."""
# Names of some parameters in has been changed.
version = local_metadata.get('version', None)
if version is None or version < 2:
state_dict_keys = list(state_dict.keys())
convert_map = {'conv1': 'conv0', 'conv2': 'conv1'}
for k in state_dict_keys:
for ori_str, new_str in convert_map.items():
if k.startswith(prefix + ori_str):
new_key = k.replace(ori_str, new_str)
state_dict[new_key] = state_dict[k]
del state_dict[k]
super()._load_from_state_dict(state_dict, prefix, local_metadata,
*args, **kwargs)
model_v1 = ModelV1()
model_v1_conv0_weight = model_v1.conv1.weight.detach()
model_v1_conv1_weight = model_v1.conv2.weight.detach()
model_v2 = ModelV2()
model_v2_conv0_weight = model_v2.conv0.weight.detach()
model_v2_conv1_weight = model_v2.conv1.weight.detach()
ckpt_v1_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v1.pth')
ckpt_v2_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v2.pth')
# Save checkpoint
save_checkpoint(model_v1, ckpt_v1_path)
save_checkpoint(model_v2, ckpt_v2_path)
# test load v1 model
load_checkpoint(model_v2, ckpt_v1_path)
assert torch.allclose(model_v2.conv0.weight, model_v1_conv0_weight)
assert torch.allclose(model_v2.conv1.weight, model_v1_conv1_weight)
# test load v2 model
load_checkpoint(model_v2, ckpt_v2_path)
assert torch.allclose(model_v2.conv0.weight, model_v2_conv0_weight)
assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight)
def test_load_classes_name():
import os
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment