Commit d1d3a581 authored by Naman Goyal, committed by Facebook Github Bot

added missing dense layers in masked lm model (#581)

Summary:
1) Added `pooled_output` for sentence classification, computed as `Tanh(Linear(x))` (see the sketch below).
2) Added `lm_head_transform` as `LayerNorm(GeLU(Linear(x)))`, applied before projecting back to the vocabulary.
3) Changed the default `act_dropout` from 0.1 to 0.0.
4) Introduced a learned bias `lm_output_learned_bias` on the LM output logits.
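
A minimal standalone PyTorch sketch of the pieces added in (1), (2) and (4); the module and dimension names here are illustrative only, and the actual change lives in `MaskedLMEncoder` in the diff below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLMHeads(nn.Module):
    """Illustrative condensation of the layers this commit adds to the masked LM model."""

    def __init__(self, embed_dim: int, vocab_size: int):
        super().__init__()
        # (1) pooler: Tanh(Linear(x)) over the classification-token representation
        self.masked_lm_pooler = nn.Linear(embed_dim, embed_dim)
        # (2) LM head transform: LayerNorm(GeLU(Linear(x))) before the vocab projection
        self.lm_head_transform_weight = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        # vocab projection plus (4) the learned output bias
        self.embed_out = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_output_learned_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, x: torch.Tensor, sentence_rep: torch.Tensor):
        # x: B x T x C token states; sentence_rep: B x C classification-token state
        pooled_output = torch.tanh(self.masked_lm_pooler(sentence_rep))
        x = self.layer_norm(F.gelu(self.lm_head_transform_weight(x)))
        logits = self.embed_out(x) + self.lm_output_learned_bias
        return logits, pooled_output
```

In the actual diff the hidden activation is looked up via `utils.get_activation_fn(args.activation_fn)` (GeLU for the BERT architectures, ReLU by default) and the pooler activation via the new `--pooler-activation-fn` option, which defaults to `tanh`.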
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/581

Reviewed By: borguz

Differential Revision: D15353575

Pulled By: borguz

fbshipit-source-id: 4ff64c6ceed23f3e99348f73d189546f1d84452e
parent dffb1674
@@ -9,6 +9,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
BaseFairseqModel,
FairseqEncoder,
@@ -16,6 +17,7 @@ from fairseq.models import (
register_model_architecture,
)
from fairseq.modules import (
LayerNorm,
SinusoidalPositionalEmbedding,
TransformerSentenceEncoder,
)
@@ -92,6 +94,9 @@ class MaskedLMModel(BaseFairseqModel):
# misc params
parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
help='Which activation function to use')
parser.add_argument('--pooler-activation-fn',
choices=['relu', 'gelu', 'gelu_accurate', 'tanh'],
help='Which activation function to use for pooler layer.')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
@@ -158,7 +163,18 @@ class MaskedLMEncoder(FairseqEncoder):
# Remove head is set to true during fine-tuning
self.load_softmax = not getattr(args, 'remove_head', False)
self.masked_lm_pooler = nn.Linear(
args.encoder_embed_dim, args.encoder_embed_dim
)
self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn)
self.lm_head_transform_weight = nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim)
self.activation_fn = utils.get_activation_fn(args.activation_fn)
self.layer_norm = LayerNorm(args.encoder_embed_dim)
if self.load_softmax:
self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size))
if not self.share_input_output_embed:
self.embed_out = nn.Linear(
args.encoder_embed_dim,
@@ -189,7 +205,7 @@ class MaskedLMEncoder(FairseqEncoder):
- a tuple of the following:
- logits for predictions in format B x T x C to be used in
softmax afterwards
- a dictionary of additional data, where 'sentence_rep' contains
- a dictionary of additional data, where 'pooled_output' contains
the representation for classification_token and 'inner_states'
is a list of internal model states used to compute the
predictions (similar to ELMo). 'sentence_logits'
@@ -198,7 +214,11 @@
"""
inner_states, sentence_rep = self.sentence_encoder(src_tokens, segment_labels)
x = inner_states[-1].transpose(0, 1)
x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))
pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep))
# project back to size of vocabulary
if self.share_input_output_embed \
@@ -206,13 +226,15 @@ class MaskedLMEncoder(FairseqEncoder):
x = F.linear(x, self.sentence_encoder.embed_tokens.weight)
elif self.embed_out is not None:
x = self.embed_out(x)
x = x + self.lm_output_learned_bias
sentence_logits = None
if self.sentence_projection_layer:
sentence_logits = self.sentence_projection_layer(sentence_rep)
sentence_logits = self.sentence_projection_layer(pooled_output)
return x, {
'inner_states': inner_states,
'sentence_rep': sentence_rep,
'pooled_output': pooled_output,
'sentence_logits': sentence_logits
}
@@ -239,7 +261,7 @@ class MaskedLMEncoder(FairseqEncoder):
def base_architecture(args):
args.dropout = getattr(args, 'dropout', 0.1)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.act_dropout = getattr(args, 'act_dropout', 0.1)
args.act_dropout = getattr(args, 'act_dropout', 0.0)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.encoder_layers = getattr(args, 'encoder_layers', 6)
@@ -259,6 +281,7 @@ def base_architecture(args):
args.apply_bert_init = getattr(args, 'apply_bert_init', False)
args.activation_fn = getattr(args, 'activation_fn', 'relu')
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
@@ -285,6 +308,7 @@ def bert_base_architecture(args):
args.apply_bert_init = getattr(args, 'apply_bert_init', True)
args.activation_fn = getattr(args, 'activation_fn', 'gelu')
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
base_architecture(args)
@@ -319,4 +343,5 @@ def xlm_architecture(args):
args.activation_fn = getattr(args, 'activation_fn', 'gelu')
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
base_architecture(args)
@@ -310,5 +310,7 @@ def get_activation_fn(activation: str) -> Callable:
return gelu_accurate
elif activation == 'gelu_accurate':
return gelu_accurate
elif activation == 'tanh':
return F.tanh
else:
raise RuntimeError(f"--activation-fn {activation} not supported")
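
A quick, purely illustrative usage check of the new `tanh` branch in `get_activation_fn` (assumes fairseq at this revision is importable):

```python
import torch
from fairseq import utils

# Look up the pooler activation the same way MaskedLMEncoder now does.
pooler_activation = utils.get_activation_fn('tanh')   # hits the branch added above
print(pooler_activation(torch.zeros(2, 3)))           # tanh(0) == 0 everywhere
```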