Commit bd6e5c4f authored by Naman Goyal, committed by Facebook Github Bot

bug_fixes and small changes to masked lm (#721)

Summary:
1) Made the model compatible with using either `masked_lm_dataset` or `monolingual_dataset`.
2) Fixed setting of default args per task (`bert` vs `masked_lm`). myleott, should we keep both?
3) Fixed a bug in setting the default value of `sentence_class_num` (see the sketch below).
4) Fixed a bug in the padding mask under `fp16` (see the sketch below).
Pull Request resolved: https://github.com/pytorch/fairseq/pull/721

Differential Revision: D15259885

fbshipit-source-id: 9dbf7fb8192992c1251670287bed719e41c08fcc
parent f2563c21
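Two short sketches below illustrate fixes 3) and 4). Both are standalone illustrations written for this description, not fairseq code; the variable names and tensor shapes are assumptions.

The `sentence_class_num` bug: argparse stores `--sentence-class-num` as the attribute `sentence_class_num`, so a lookup under the hyphenated name can never find a user-supplied value and silently falls back to the default. A minimal sketch, assuming the value was passed on the command line:

```python
# Minimal sketch of the sentence_class_num default bug (not fairseq code).
from argparse import Namespace

args = Namespace(sentence_class_num=5)  # pretend the user passed --sentence-class-num 5

# Before: the hyphenated key is never a valid attribute name, so the user's value is ignored.
broken = getattr(args, 'sentence-class-num', 2)   # always 2
# After: the underscored key matches what argparse actually stores.
fixed = getattr(args, 'sentence_class_num', 2)    # 5

print(broken, fixed)  # -> 2 5
```

The `fp16` padding-mask bug: `.float()` always produces an fp32 mask, so when the model runs in half precision the in-place multiply mixes fp16 activations with an fp32 mask and fails with a dtype error (the exact message depends on the PyTorch version). `.type_as(x)` keeps the mask in whatever dtype `x` already has. A minimal sketch with assumed shapes:

```python
# Minimal sketch of the fp16 padding-mask fix (assumed shapes, not fairseq code).
import torch

B, T, C = 2, 4, 8
x = torch.randn(B, T, C, dtype=torch.float16)        # fp16 activations, as under --fp16
padding_mask = torch.tensor([[0, 0, 1, 1],
                             [0, 0, 0, 1]], dtype=torch.bool)

# Before (breaks under fp16): an fp32 mask cannot be multiplied in place into an fp16 tensor.
#   x *= (1 - padding_mask.unsqueeze(-1).float())

# After: the mask is cast to x's own dtype, so the same line works for fp16 and fp32 alike.
x *= (1 - padding_mask.unsqueeze(-1).type_as(x))
```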
@@ -260,7 +260,7 @@ class MaskedLMDataset(FairseqDataset):
             "id": torch.LongTensor([s["id"] for s in samples]),
             "ntokens": sum(len(s["source"]) for s in samples),
             "net_input": {
-                "tokens": merge("source"),
+                "src_tokens": merge("source"),
                 "segment_labels": merge("segment_labels"),
             },
             "lm_target": merge("lm_target"),
@@ -99,8 +99,8 @@ class MaskedLMModel(BaseFairseqModel):
                             help='Use gelu activation function in encoder'
                                  ' layer')

-    def forward(self, tokens, segment_labels):
-        return self.encoder(tokens, segment_labels)
+    def forward(self, src_tokens, segment_labels, **kwargs):
+        return self.encoder(src_tokens, segment_labels, **kwargs)

     def max_positions(self):
         return self.encoder.max_positions
@@ -109,10 +109,8 @@ class MaskedLMModel(BaseFairseqModel):
     def build_model(cls, args, task):
         """Build a new model instance."""
-        if args.task == 'bert':
-            base_bert_architecture(args)
-        else:
-            xlm_architecture(args)
+        # make sure all arguments are present in older models
+        base_architecture(args)

         if not hasattr(args, 'max_positions'):
             args.max_positions = args.tokens_per_sample
@@ -178,7 +176,7 @@ class MaskedLMEncoder(FairseqEncoder):
             bias=False
         )

-    def forward(self, tokens, segment_labels, **unused):
+    def forward(self, src_tokens, segment_labels, **unused):
         """
         Forward pass for Masked LM encoder. This first computes the token
         embedding using the token embedding matrix, position embeddings (if
@@ -188,7 +186,7 @@ class MaskedLMEncoder(FairseqEncoder):
         output of the classification_token (see bert_task or cross_lingual_lm
         task for more details).
         Args:
-            - tokens: B x T matrix representing sentences
+            - src_tokens: B x T matrix representing sentences
            - segment_labels: B x T matrix representing segment label for tokens
         Returns:
             - a tuple of the following:
@@ -202,7 +200,7 @@ class MaskedLMEncoder(FairseqEncoder):
             this is specified in the input arguments.
         """
-        inner_states, sentence_rep = self.sentence_encoder(tokens, segment_labels)
+        inner_states, sentence_rep = self.sentence_encoder(src_tokens, segment_labels)
         x = inner_states[-1].transpose(0, 1)

         # project back to size of vocabulary
@@ -269,7 +267,7 @@ def base_architecture(args):
 @register_model_architecture('masked_lm', 'bert_base')
-def base_bert_architecture(args):
+def bert_base_architecture(args):
     args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
     args.share_encoder_input_output_embed = getattr(
         args, 'share_encoder_input_output_embed', True)
@@ -285,7 +283,7 @@ def base_bert_architecture(args):
     args.no_bias_kv = getattr(args, 'no_bias_kv', True)
     args.sent_loss = getattr(args, 'sent_loss', True)
-    args.sentence_class_num = getattr(args, 'sentence-class-num', 2)
+    args.sentence_class_num = getattr(args, 'sentence_class_num', 2)
     args.apply_bert_init = getattr(args, 'apply_bert_init', True)
@@ -303,7 +301,7 @@ def bert_large_architecture(args):
     args.encoder_layers = getattr(args, 'encoder_layers', 24)
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
     args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
-    base_bert_architecture(args)
+    bert_base_architecture(args)

 @register_model_architecture('masked_lm', 'xlm_base')
@@ -171,7 +171,7 @@ class TransformerSentenceEncoder(nn.Module):
         # account for padding while computing the representation
         if padding_mask is not None:
-            x *= (1 - padding_mask.unsqueeze(-1).float())
+            x *= (1 - padding_mask.unsqueeze(-1).type_as(x))

         # B x T x C -> T x B x C
         x = x.transpose(0, 1)