Commit dc028c52 authored by Haoran Li's avatar Haoran Li Committed by Facebook Github Bot
Browse files

fix masked_lm for loading in pytext

Summary: lm_output_learned_bias doesn't exist when loading the model for fine-tuning

Reviewed By: jingfeidu

Differential Revision: D15579190

fbshipit-source-id: 45e8e193399943c89b77cc553d3d6d49b056e55a
parent a2aed890
...@@ -174,6 +174,7 @@ class MaskedLMEncoder(FairseqEncoder): ...@@ -174,6 +174,7 @@ class MaskedLMEncoder(FairseqEncoder):
self.activation_fn = utils.get_activation_fn(args.activation_fn) self.activation_fn = utils.get_activation_fn(args.activation_fn)
self.layer_norm = LayerNorm(args.encoder_embed_dim) self.layer_norm = LayerNorm(args.encoder_embed_dim)
self.lm_output_learned_bias = None
if self.load_softmax: if self.load_softmax:
self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size)) self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size))
...@@ -229,6 +230,7 @@ class MaskedLMEncoder(FairseqEncoder): ...@@ -229,6 +230,7 @@ class MaskedLMEncoder(FairseqEncoder):
elif self.embed_out is not None: elif self.embed_out is not None:
x = self.embed_out(x) x = self.embed_out(x)
if self.lm_output_learned_bias is not None:
x = x + self.lm_output_learned_bias x = x + self.lm_output_learned_bias
sentence_logits = None sentence_logits = None
if self.sentence_projection_layer: if self.sentence_projection_layer:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment