Unverified Commit 8eae7779 authored by nicheng0019's avatar nicheng0019 Committed by GitHub
Browse files

Update checkpoint.py (#722)

parent 3392a4ea
...@@ -256,7 +256,7 @@ def load_checkpoint(model, ...@@ -256,7 +256,7 @@ def load_checkpoint(model,
state_dict = checkpoint state_dict = checkpoint
# strip prefix of state_dict # strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'): if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()} state_dict = {k[7:]: v for k, v in state_dict.items()}
# load state_dict # load state_dict
load_state_dict(model, state_dict, strict, logger) load_state_dict(model, state_dict, strict, logger)
return checkpoint return checkpoint
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment