"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "610a71d7d467e4cc892c824db882071ff0d282e1"
Commit fca32e05 authored by Jingfei Du, committed by Facebook Github Bot

fixed bugs of masked_lm for fine-tuning (#744)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/744

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/587

After we added additional prediction layers for language model predictions, fine-tuning broke for two reasons:
1. checkpoints could not be loaded because the state_dict key names were not updated
2. lm_output_learned_bias is not initialized when load_softmax is false
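
For context, here is a minimal sketch of the second failure mode and the fix. The class and dimensions below are hypothetical stand-ins, not fairseq's actual MaskedLMEncoder: if the bias attribute is only created when the softmax head is kept, building the model with the head removed and then touching the attribute raises AttributeError; defaulting it to None in __init__, as this patch does, avoids that.

```python
import torch
import torch.nn as nn

class TinyMaskedLMEncoder(nn.Module):
    """Hypothetical, stripped-down stand-in for the masked LM encoder."""

    def __init__(self, hidden_dim=16, vocab_size=100, load_softmax=True):
        super().__init__()
        self.encoder = nn.Linear(hidden_dim, hidden_dim)
        # The fix: always define the attribute so forward() can test it safely
        # even when the prediction head is removed for fine-tuning.
        self.lm_output_learned_bias = None
        self.embed_out = None
        if load_softmax:
            self.lm_output_learned_bias = nn.Parameter(torch.zeros(vocab_size))
            self.embed_out = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, x):
        x = self.encoder(x)
        if self.embed_out is not None:
            # Project to the vocabulary and add the learned output bias.
            x = self.embed_out(x) + self.lm_output_learned_bias
        return x

# Without `self.lm_output_learned_bias = None` in __init__, a model built with
# load_softmax=False would hit AttributeError as soon as the attribute is read.
fine_tune_model = TinyMaskedLMEncoder(load_softmax=False)
print(fine_tune_model(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```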

Reviewed By: myleott

Differential Revision: D15377380

fbshipit-source-id: d58544b1d2c549586abef42fec19ec8bf27a994a
parent e2a0b87d
@@ -159,6 +159,7 @@ class MaskedLMEncoder(FairseqEncoder):
         self.embed_out = None
         self.sentence_projection_layer = None
         self.sentence_out_dim = args.sentence_class_num
+        self.lm_output_learned_bias = None
         # Remove head is set to true during fine-tuning
         self.load_softmax = not getattr(args, 'remove_head', False)
@@ -252,7 +253,11 @@ class MaskedLMEncoder(FairseqEncoder):
             ] = torch.FloatTensor(1)
         if not self.load_softmax:
             for k in list(state_dict.keys()):
-                if "embed_out.weight" in k or "sentence_projection_layer.weight" in k:
+                if (
+                    "embed_out.weight" in k or
+                    "sentence_projection_layer.weight" in k or
+                    "lm_output_learned_bias" in k
+                ):
                     del state_dict[k]
         return state_dict
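
As a rough usage check of the first fix, illustrative only and written in plain PyTorch rather than fairseq's checkpoint loader, stripping the head-specific keys (as the patched upgrade_state_dict_named does) lets a checkpoint saved with the prediction head load into a head-less fine-tuning model without unexpected-key errors. The ModuleDict layout and sizes here are made up for the demonstration.

```python
import torch
import torch.nn as nn

# Model as pre-trained: an encoder plus an LM output head.
pretrained = nn.ModuleDict({
    'encoder': nn.Linear(4, 4),
    'embed_out': nn.Linear(4, 10, bias=False),
})
state_dict = pretrained.state_dict()

# Mirror the patched key filter: drop head-specific parameters before loading.
for k in list(state_dict.keys()):
    if 'embed_out.weight' in k or 'lm_output_learned_bias' in k:
        del state_dict[k]

# A fine-tuning model built without the head now loads strictly, with no
# "unexpected key" errors for the removed head parameters.
headless = nn.ModuleDict({'encoder': nn.Linear(4, 4)})
headless.load_state_dict(state_dict, strict=True)
```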