Commit bd6e5c4f authored by Naman Goyal, committed by Facebook GitHub Bot

Bug fixes and small changes to masked LM (#721)

Summary:
1) Made the model compatible with either `masked_lm_dataset` or `monolingual_dataset`.
2) Fixed the default-args setup for the task (`bert` vs `masked_lm`). myleott: should we keep both?
3) Fixed a bug in setting the default value of `sentence_class_num`.
4) Fixed a bug in the padding mask under `fp16`.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/721

Differential Revision: D15259885

fbshipit-source-id: 9dbf7fb8192992c1251670287bed719e41c08fcc
parent f2563c21
@@ -260,7 +260,7 @@ class MaskedLMDataset(FairseqDataset):
             "id": torch.LongTensor([s["id"] for s in samples]),
             "ntokens": sum(len(s["source"]) for s in samples),
             "net_input": {
-                "tokens": merge("source"),
+                "src_tokens": merge("source"),
                 "segment_labels": merge("segment_labels"),
             },
             "lm_target": merge("lm_target"),
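The key rename matters because fairseq unpacks a batch's `net_input` dict directly into the model's forward pass as keyword arguments, so the dict keys must match the parameter names of `forward`. A minimal sketch of the mechanism (the function below is a stand-in, not fairseq's actual internals):

import torch

def forward(src_tokens, segment_labels, **kwargs):
    # stand-in for MaskedLMModel.forward
    return src_tokens.shape, segment_labels.shape

net_input = {
    "src_tokens": torch.zeros(2, 8, dtype=torch.long),
    "segment_labels": torch.zeros(2, 8, dtype=torch.long),
}
print(forward(**net_input))  # OK: keys line up with the signature
# A "tokens" key here would raise TypeError: unexpected keyword argument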
@@ -99,8 +99,8 @@ class MaskedLMModel(BaseFairseqModel):
                             help='Use gelu activation function in encoder'
                                  ' layer')

-    def forward(self, tokens, segment_labels):
-        return self.encoder(tokens, segment_labels)
+    def forward(self, src_tokens, segment_labels, **kwargs):
+        return self.encoder(src_tokens, segment_labels, **kwargs)

     def max_positions(self):
         return self.encoder.max_positions
@@ -109,10 +109,8 @@ class MaskedLMModel(BaseFairseqModel):
     def build_model(cls, args, task):
         """Build a new model instance."""
-        if args.task == 'bert':
-            base_bert_architecture(args)
-        else:
-            xlm_architecture(args)
+        # make sure all arguments are present in older models
+        base_architecture(args)

         if not hasattr(args, 'max_positions'):
             args.max_positions = args.tokens_per_sample
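Calling `base_architecture(args)` unconditionally is safe because architecture functions only fill in attributes that are missing, so values restored from an older checkpoint (or set by a more specific architecture) are never overwritten. A hedged sketch of the idiom, with illustrative attribute names and defaults:

from argparse import Namespace

def base_architecture(args):
    # fill in a default only when the attribute is absent
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)

args = Namespace(encoder_embed_dim=768)  # e.g. loaded from an old checkpoint
base_architecture(args)
print(args.encoder_embed_dim, args.encoder_layers)  # 768 6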
@@ -178,7 +176,7 @@ class MaskedLMEncoder(FairseqEncoder):
             bias=False
         )

-    def forward(self, tokens, segment_labels, **unused):
+    def forward(self, src_tokens, segment_labels, **unused):
         """
         Forward pass for Masked LM encoder. This first computes the token
         embedding using the token embedding matrix, position embeddings (if
@@ -188,7 +186,7 @@ class MaskedLMEncoder(FairseqEncoder):
         output of the classification_token (see bert_task or cross_lingual_lm
         task for more details).

         Args:
-            - tokens: B x T matrix representing sentences
+            - src_tokens: B x T matrix representing sentences
             - segment_labels: B x T matrix representing segment label for tokens

         Returns:
             - a tuple of the following:
@@ -202,7 +200,7 @@ class MaskedLMEncoder(FairseqEncoder):
                 this is specified in the input arguments.
         """
-        inner_states, sentence_rep = self.sentence_encoder(tokens, segment_labels)
+        inner_states, sentence_rep = self.sentence_encoder(src_tokens, segment_labels)

         x = inner_states[-1].transpose(0, 1)

         # project back to size of vocabulary
@@ -269,7 +267,7 @@ def base_architecture(args):

 @register_model_architecture('masked_lm', 'bert_base')
-def base_bert_architecture(args):
+def bert_base_architecture(args):
     args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
     args.share_encoder_input_output_embed = getattr(
         args, 'share_encoder_input_output_embed', True)
@@ -285,7 +283,7 @@ def base_bert_architecture(args):
     args.no_bias_kv = getattr(args, 'no_bias_kv', True)
     args.sent_loss = getattr(args, 'sent_loss', True)

-    args.sentence_class_num = getattr(args, 'sentence-class-num', 2)
+    args.sentence_class_num = getattr(args, 'sentence_class_num', 2)
     args.apply_bert_init = getattr(args, 'apply_bert_init', True)
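The old lookup could never find a user-supplied value: argparse stores `--sentence-class-num` on the args namespace as `sentence_class_num` (hyphens become underscores), so `getattr(args, 'sentence-class-num', 2)` always fell through to the default. A small sketch of the failure:

from argparse import Namespace

args = Namespace(sentence_class_num=5)         # as argparse would store it
print(getattr(args, 'sentence-class-num', 2))  # 2 -- user value silently ignored
print(getattr(args, 'sentence_class_num', 2))  # 5 -- fixed lookup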
@@ -303,7 +301,7 @@ def bert_large_architecture(args):
     args.encoder_layers = getattr(args, 'encoder_layers', 24)
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
     args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
-    base_bert_architecture(args)
+    bert_base_architecture(args)

 @register_model_architecture('masked_lm', 'xlm_base')
@@ -171,7 +171,7 @@ class TransformerSentenceEncoder(nn.Module):
         # account for padding while computing the representation
         if padding_mask is not None:
-            x *= (1 - padding_mask.unsqueeze(-1).float())
+            x *= (1 - padding_mask.unsqueeze(-1).type_as(x))

         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
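The `.float()` call built a float32 mask, and multiplying half-precision activations by it in place trips PyTorch's dtype check, since in-place ops cannot promote the output type; `type_as(x)` keeps the mask in whatever dtype `x` is running in. A minimal reproduction sketch:

import torch

x = torch.randn(2, 4, 8).half()                   # fp16 activations
padding_mask = torch.zeros(2, 4, dtype=torch.uint8)

# x *= (1 - padding_mask.unsqueeze(-1).float())   # float32 mask: fails in-place on fp16
x *= (1 - padding_mask.unsqueeze(-1).type_as(x))  # mask follows x's dtype
print(x.dtype)                                    # torch.float16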