Commit d1d3a581 authored by Naman Goyal, committed by Facebook Github Bot

added missing dense layers in masked lm model (#581)

Summary:
1) Added `pooled_output` for sentence classification, computed as `Tanh(Linear(x))`.
2) Added `lm_head_transform` as `LayerNorm(GeLU(Linear(x)))`.
3) Changed the default `act_dropout` to `0.0`.
4) Added a learned bias `lm_output_learned_bias` to the LM output projection.
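Taken together, the new head looks roughly like the following. This is a minimal standalone sketch of the pieces described above, not the fairseq code itself; the module name `MaskedLMHeads` and the explicit `embed_tokens_weight` argument are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLMHeads(nn.Module):
    """Sketch of the heads added in this commit (names are illustrative)."""

    def __init__(self, embed_dim, vocab_size):
        super().__init__()
        # 2) lm_head_transform: LayerNorm(GeLU(Linear(x))), applied per token
        self.lm_head_transform_weight = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        # 1) pooler for sentence classification: Tanh(Linear(sentence_rep))
        self.masked_lm_pooler = nn.Linear(embed_dim, embed_dim)
        # 4) learned bias added to the vocabulary projection
        self.lm_output_learned_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, features, sentence_rep, embed_tokens_weight):
        # features: B x T x embed_dim, sentence_rep: B x embed_dim
        x = self.layer_norm(F.gelu(self.lm_head_transform_weight(features)))
        # reuse the input embedding matrix as the output projection, then add the bias
        logits = F.linear(x, embed_tokens_weight) + self.lm_output_learned_bias
        pooled_output = torch.tanh(self.masked_lm_pooler(sentence_rep))
        return logits, pooled_output
```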
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/581

Reviewed By: borguz

Differential Revision: D15353575

Pulled By: borguz

fbshipit-source-id: 4ff64c6ceed23f3e99348f73d189546f1d84452e
parent dffb1674
fairseq/models/masked_lm.py
@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from fairseq import utils
 from fairseq.models import (
     BaseFairseqModel,
     FairseqEncoder,
@@ -16,6 +17,7 @@ from fairseq.models import (
     register_model_architecture,
 )
 from fairseq.modules import (
+    LayerNorm,
     SinusoidalPositionalEmbedding,
     TransformerSentenceEncoder,
 )
@@ -92,6 +94,9 @@ class MaskedLMModel(BaseFairseqModel):
         # misc params
         parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
                             help='Which activation function to use')
+        parser.add_argument('--pooler-activation-fn',
+                            choices=['relu', 'gelu', 'gelu_accurate', 'tanh'],
+                            help='Which activation function to use for pooler layer.')
         parser.add_argument('--encoder-normalize-before', action='store_true',
                             help='apply layernorm before each encoder block')
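For reference, the new flag behaves like an ordinary argparse option. A small standalone sketch with a plain `ArgumentParser` (not fairseq's own option machinery), assuming the same choices and the `tanh` default set in the architecture functions below:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pooler-activation-fn',
                    choices=['relu', 'gelu', 'gelu_accurate', 'tanh'],
                    default='tanh',
                    help='Which activation function to use for pooler layer.')
args = parser.parse_args(['--pooler-activation-fn', 'tanh'])
assert args.pooler_activation_fn == 'tanh'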
@@ -158,7 +163,18 @@ class MaskedLMEncoder(FairseqEncoder):
         # Remove head is set to true during fine-tuning
         self.load_softmax = not getattr(args, 'remove_head', False)
+        self.masked_lm_pooler = nn.Linear(
+            args.encoder_embed_dim, args.encoder_embed_dim
+        )
+        self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn)
+
+        self.lm_head_transform_weight = nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim)
+        self.activation_fn = utils.get_activation_fn(args.activation_fn)
+        self.layer_norm = LayerNorm(args.encoder_embed_dim)
+
         if self.load_softmax:
+            self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size))
             if not self.share_input_output_embed:
                 self.embed_out = nn.Linear(
                     args.encoder_embed_dim,
@@ -189,7 +205,7 @@ class MaskedLMEncoder(FairseqEncoder):
                 - a tuple of the following:
                     - logits for predictions in format B x T x C to be used in
                       softmax afterwards
-                    - a dictionary of additional data, where 'sentence_rep' contains
+                    - a dictionary of additional data, where 'pooled_output' contains
                       the representation for classification_token and 'inner_states'
                       is a list of internal model states used to compute the
                       predictions (similar in ELMO). 'sentence_logits'
@@ -198,7 +214,11 @@ class MaskedLMEncoder(FairseqEncoder):
         """
         inner_states, sentence_rep = self.sentence_encoder(src_tokens, segment_labels)
         x = inner_states[-1].transpose(0, 1)
+        x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))
+
+        pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep))
+
         # project back to size of vocabulary
         if self.share_input_output_embed \
@@ -206,13 +226,15 @@ class MaskedLMEncoder(FairseqEncoder):
             x = F.linear(x, self.sentence_encoder.embed_tokens.weight)
         elif self.embed_out is not None:
             x = self.embed_out(x)
+        x = x + self.lm_output_learned_bias
+
         sentence_logits = None
         if self.sentence_projection_layer:
-            sentence_logits = self.sentence_projection_layer(sentence_rep)
+            sentence_logits = self.sentence_projection_layer(pooled_output)

         return x, {
             'inner_states': inner_states,
-            'sentence_rep': sentence_rep,
+            'pooled_output': pooled_output,
             'sentence_logits': sentence_logits
         }
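A hedged usage sketch of the new return value; the `encoder` instance, `src_tokens`, `segment_labels`, and the classifier head here are assumptions for illustration, not part of this diff:

```python
import torch.nn as nn

# encoder is assumed to be a MaskedLMEncoder; shapes follow the docstring above.
logits, extra = encoder(src_tokens, segment_labels)   # logits: B x T x C
pooled = extra['pooled_output']                        # B x embed_dim, Tanh(Linear(<cls> rep))
num_classes = 2                                        # illustrative
classifier = nn.Linear(pooled.size(-1), num_classes)
sentence_scores = classifier(pooled)                   # B x num_classes
```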
@@ -239,7 +261,7 @@ class MaskedLMEncoder(FairseqEncoder):
 def base_architecture(args):
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
-    args.act_dropout = getattr(args, 'act_dropout', 0.1)
+    args.act_dropout = getattr(args, 'act_dropout', 0.0)
     args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
     args.encoder_layers = getattr(args, 'encoder_layers', 6)
@@ -259,6 +281,7 @@ def base_architecture(args):
     args.apply_bert_init = getattr(args, 'apply_bert_init', False)
     args.activation_fn = getattr(args, 'activation_fn', 'relu')
+    args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
     args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
@@ -285,6 +308,7 @@ def bert_base_architecture(args):
     args.apply_bert_init = getattr(args, 'apply_bert_init', True)
     args.activation_fn = getattr(args, 'activation_fn', 'gelu')
+    args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
     args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
     base_architecture(args)
@@ -319,4 +343,5 @@ def xlm_architecture(args):
     args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
+    args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
     base_architecture(args)
fairseq/utils.py
@@ -310,5 +310,7 @@ def get_activation_fn(activation: str) -> Callable:
         return gelu_accurate
     elif activation == 'gelu_accurate':
         return gelu_accurate
+    elif activation == 'tanh':
+        return F.tanh
     else:
         raise RuntimeError(f"--activation-fn {activation} not supported")
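With this branch in place, the pooler can be wired up via `utils.get_activation_fn`, as the encoder `__init__` above does. A small usage sketch (the tensor shape is illustrative):

```python
import torch
from fairseq import utils

pooler_activation = utils.get_activation_fn('tanh')  # returns F.tanh after this change
out = pooler_activation(torch.zeros(2, 768))         # any float tensor works
```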