Unverified Commit f48241a6 authored by Wang Xinjiang's avatar Wang Xinjiang Committed by GitHub
Browse files

load_checkpoint support normal dict checkpoints (#351)

* load_checkpoint support normal dict checkpoints

* comments
parent 97f9efd8
...@@ -222,14 +222,15 @@ def load_checkpoint(model, ...@@ -222,14 +222,15 @@ def load_checkpoint(model,
dict or OrderedDict: The loaded checkpoint. dict or OrderedDict: The loaded checkpoint.
""" """
checkpoint = _load_checkpoint(filename, map_location) checkpoint = _load_checkpoint(filename, map_location)
# OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict):
raise RuntimeError(
f'No state_dict found in checkpoint file {filename}')
# get state_dict from checkpoint # get state_dict from checkpoint
if isinstance(checkpoint, OrderedDict): if 'state_dict' in checkpoint:
state_dict = checkpoint
elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict'] state_dict = checkpoint['state_dict']
else: else:
raise RuntimeError( state_dict = checkpoint
f'No state_dict found in checkpoint file {filename}')
# strip prefix of state_dict # strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'): if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()} state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment