Commit c6f2caea, authored by CoinCheung, committed by Kai Chen
Browse files

resubmit parameter shape check (#100)

* resubmit parameter shape check

* check lint

* print a table of mismatched keys
parent d1abbca1
...@@ -8,6 +8,7 @@ from importlib import import_module ...@@ -8,6 +8,7 @@ from importlib import import_module
import torch import torch
import torchvision import torchvision
from terminaltables import AsciiTable
from torch.utils import model_zoo from torch.utils import model_zoo
import mmcv import mmcv
...@@ -55,6 +56,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None): ...@@ -55,6 +56,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
message. If not specified, print function will be used. message. If not specified, print function will be used.
""" """
unexpected_keys = [] unexpected_keys = []
shape_mismatch_pairs = []
own_state = module.state_dict() own_state = module.state_dict()
for name, param in state_dict.items(): for name, param in state_dict.items():
if name not in own_state: if name not in own_state:
...@@ -63,16 +66,18 @@ def load_state_dict(module, state_dict, strict=False, logger=None): ...@@ -63,16 +66,18 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
if isinstance(param, torch.nn.Parameter): if isinstance(param, torch.nn.Parameter):
# backwards compatibility for serialized parameters # backwards compatibility for serialized parameters
param = param.data param = param.data
if param.size() != own_state[name].size():
try: shape_mismatch_pairs.append(
[name, own_state[name].size(),
param.size()])
continue
own_state[name].copy_(param) own_state[name].copy_(param)
except Exception:
raise RuntimeError( all_missing_keys = set(own_state.keys()) - set(state_dict.keys())
'While copying the parameter named {}, ' # ignore "num_batches_tracked" of BN layers
'whose dimensions in the model are {} and ' missing_keys = [
'whose dimensions in the checkpoint are {}.'.format( key for key in all_missing_keys if 'num_batches_tracked' not in key
name, own_state[name].size(), param.size())) ]
missing_keys = set(own_state.keys()) - set(state_dict.keys())
err_msg = [] err_msg = []
if unexpected_keys: if unexpected_keys:
...@@ -81,8 +86,17 @@ def load_state_dict(module, state_dict, strict=False, logger=None): ...@@ -81,8 +86,17 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
if missing_keys: if missing_keys:
err_msg.append('missing keys in source state_dict: {}\n'.format( err_msg.append('missing keys in source state_dict: {}\n'.format(
', '.join(missing_keys))) ', '.join(missing_keys)))
if shape_mismatch_pairs:
mismatch_info = 'these keys have mismatched shape:\n'
header = ['key', 'expected shape', 'loaded shape']
table_data = [header] + shape_mismatch_pairs
table = AsciiTable(table_data)
err_msg.append(mismatch_info + table.table)
if len(err_msg) > 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg) err_msg = '\n'.join(err_msg)
if err_msg:
if strict: if strict:
raise RuntimeError(err_msg) raise RuntimeError(err_msg)
elif logger is not None: elif logger is not None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment