Commit c6f2caea authored by CoinCheung, committed by Kai Chen
Browse files

resubmit parameter shape check (#100)

* resubmit parameter shape check

* check lint

* print a table of mismatched keys
parent d1abbca1
......@@ -8,6 +8,7 @@ from importlib import import_module
import torch
import torchvision
from terminaltables import AsciiTable
from torch.utils import model_zoo
import mmcv
......@@ -55,6 +56,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
message. If not specified, print function will be used.
"""
unexpected_keys = []
shape_mismatch_pairs = []
own_state = module.state_dict()
for name, param in state_dict.items():
if name not in own_state:
......@@ -63,16 +66,18 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
if isinstance(param, torch.nn.Parameter):
# backwards compatibility for serialized parameters
param = param.data
if param.size() != own_state[name].size():
shape_mismatch_pairs.append(
[name, own_state[name].size(),
param.size()])
continue
own_state[name].copy_(param)
try:
own_state[name].copy_(param)
except Exception:
raise RuntimeError(
'While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'.format(
name, own_state[name].size(), param.size()))
missing_keys = set(own_state.keys()) - set(state_dict.keys())
all_missing_keys = set(own_state.keys()) - set(state_dict.keys())
# ignore "num_batches_tracked" of BN layers
missing_keys = [
key for key in all_missing_keys if 'num_batches_tracked' not in key
]
err_msg = []
if unexpected_keys:
......@@ -81,8 +86,17 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
if missing_keys:
err_msg.append('missing keys in source state_dict: {}\n'.format(
', '.join(missing_keys)))
err_msg = '\n'.join(err_msg)
if err_msg:
if shape_mismatch_pairs:
mismatch_info = 'these keys have mismatched shape:\n'
header = ['key', 'expected shape', 'loaded shape']
table_data = [header] + shape_mismatch_pairs
table = AsciiTable(table_data)
err_msg.append(mismatch_info + table.table)
if len(err_msg) > 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg)
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment