Commit 9a1038f6 authored by Naman Goyal's avatar Naman Goyal Committed by Facebook Github Bot
Browse files

fixed reloading from checkpoint (#811)

Summary:
Tested by starting training from (a) `roberta.large`, (b) `roberta.large.mnli`, (c) `checkpoints/checkpoint_last.pt`
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/811

Reviewed By: myleott

Differential Revision: D16689528

Pulled By: myleott

fbshipit-source-id: 849d72ede9d526c34b4753c1bffd689554d1f837
parent a9eda736
......@@ -128,11 +128,15 @@ class RobertaModel(FairseqLanguageModel):
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + '.' if name != '' else ''
current_head_names = [] if not hasattr(self, 'classification_heads') else \
self.classification_heads.keys()
# recreate any classification heads present in the state dict
keys_to_delete = []
# Delete any heads present in state_dict, that are not in current constructed model.
for k in state_dict.keys():
if not k.startswith(prefix + 'classification_heads.'):
continue
head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
num_classes = state_dict[
prefix + 'classification_heads.' + head_name + '.out_proj.weight'
......@@ -140,7 +144,19 @@ class RobertaModel(FairseqLanguageModel):
inner_dim = state_dict[
prefix + 'classification_heads.' + head_name + '.dense.weight'
].size(0)
self.register_classification_head(head_name, num_classes, inner_dim)
if head_name not in current_head_names:
print("WARNING: deleting classification head ({}) from checkpoint not present in current model: {}".format(head_name, k))
keys_to_delete.append(k)
elif (
num_classes != self.classification_heads[head_name].out_proj.out_features
or inner_dim != self.classification_heads[head_name].dense.out_features
):
print("WARNING: deleting classification head ({}) from checkpoint with different dimensions than current model: {}".format(head_name, k))
keys_to_delete.append(k)
for k in keys_to_delete:
del state_dict[k]
# Copy any newly-added classification heads into the state dict
# with their current weights.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment