Commit bf106796 authored by Myle Ott, committed by Facebook Github Bot

Various fixes for Masked LM (#573)

Summary:
Various fixes for Masked LM

- use --activation-fn instead of --gelu (see the usage sketch after this list)
- use --dataset-impl instead of --lazy-load
- add embed_scale option to TransformerSentenceEncoder
- fix encoder_normalize_before to include a final layer norm
- delete BertLayerNorm
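
In practice the migration is mostly a flag rename: --gelu becomes --activation-fn gelu, and --raw-text/--lazy-load are subsumed by --dataset-impl (presumably --dataset-impl raw / --dataset-impl lazy; those value names are not part of this diff). Below is a minimal construction sketch for the updated TransformerSentenceEncoder options; the sizes and the two positional parameter names are assumptions for illustration, not taken from this commit.

    import math

    from fairseq.modules import TransformerSentenceEncoder

    # Hypothetical sizes; only the keyword arguments below correspond to
    # options touched by this commit, and the first two parameter names are
    # assumed from the surrounding (unchanged) signature.
    encoder = TransformerSentenceEncoder(
        padding_idx=1,
        vocab_size=32000,
        activation_fn='gelu',            # replaces the old --gelu / use_gelu flags
        encoder_normalize_before=True,   # now also adds a LayerNorm over the embeddings
        embed_scale=math.sqrt(768),      # new optional scale on the token embeddings
        apply_bert_init=True,
    )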
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/573

Reviewed By: borguz

Differential Revision: D15317933

Pulled By: myleott

fbshipit-source-id: 8ecb46556ad43e76e92d41ed8f5a62e8516fd375
parent 7432130e
@@ -17,7 +17,7 @@ from fairseq.models import (
 )
 from fairseq.modules import (
     SinusoidalPositionalEmbedding,
-    TransformerSentenceEncoder
+    TransformerSentenceEncoder,
 )
 from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -89,16 +89,11 @@ class MaskedLMModel(BaseFairseqModel):
         parser.add_argument('--apply-bert-init', action='store_true',
                             help='use custom param initialization for BERT')
-        # layer norm layers
-        parser.add_argument('--bert-layer-norm', action='store_true',
-                            help='use custom Layer Norm module for BERT')
         # misc params
+        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
+                            help='Which activation function to use')
         parser.add_argument('--encoder-normalize-before', action='store_true',
                             help='apply layernorm before each encoder block')
-        parser.add_argument('--gelu', action='store_true',
-                            help='Use gelu activation function in encoder'
-                                 ' layer')

     def forward(self, src_tokens, segment_labels, **kwargs):
         return self.encoder(src_tokens, segment_labels, **kwargs)
@@ -148,9 +143,8 @@ class MaskedLMEncoder(FairseqEncoder):
             num_segments=args.num_segment,
             use_position_embeddings=not args.no_token_positional_embeddings,
             encoder_normalize_before=args.encoder_normalize_before,
-            use_bert_layer_norm=args.bert_layer_norm,
-            use_gelu=args.gelu,
             apply_bert_init=args.apply_bert_init,
+            activation_fn=args.activation_fn,
             learned_pos_embedding=args.encoder_learned_pos,
             add_bias_kv=args.bias_kv,
             add_zero_attn=args.zero_attn,
@@ -263,11 +257,9 @@ def base_architecture(args):
     args.sent_loss = getattr(args, 'sent_loss', False)
     args.apply_bert_init = getattr(args, 'apply_bert_init', False)
-    args.bert_layer_norm = getattr(args, 'bert_layer_norm', False)
-    args.encoder_normalize_before = getattr(
-        args, 'encoder_normalize_before', False)
-    args.gelu = getattr(args, 'gelu', False)
+    args.activation_fn = getattr(args, 'activation_fn', 'relu')
+    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)


 @register_model_architecture('masked_lm', 'bert_base')
@@ -287,16 +279,13 @@ def bert_base_architecture(args):
     args.bias_kv = getattr(args, 'bias_kv', False)
     args.zero_attn = getattr(args, 'zero_attn', False)
-    args.sent_loss = getattr(args, 'sent_loss', True)
     args.sentence_class_num = getattr(args, 'sentence_class_num', 2)
     args.sent_loss = getattr(args, 'sent_loss', True)
     args.apply_bert_init = getattr(args, 'apply_bert_init', True)
-    # TODO: validate setups for layernorm
-    args.encoder_normalize_before = getattr(
-        args, 'encoder_normalize_before', True)
-    args.bert_layer_norm = getattr(args, 'bert_layer_norm', True)
-    args.gelu = getattr(args, 'gelu', True)
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
+    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
     base_architecture(args)
@@ -328,9 +317,6 @@ def xlm_architecture(args):
     args.sent_loss = getattr(args, 'sent_loss', False)
-    args.encoder_normalize_before = getattr(
-        args, 'encoder_normalize_before', False)
-    args.bert_layer_norm = getattr(args, 'bert_layer_norm', False)
-    args.gelu = getattr(args, 'gelu', True)
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
+    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
     args.apply_bert_init = getattr(args, 'apply_bert_init', True)
     base_architecture(args)
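
Since the architecture functions above only fill in defaults via getattr, a value supplied on the command line still wins over the architecture default. A tiny self-contained illustration of that pattern (made-up namespace values):

    from argparse import Namespace

    args = Namespace(activation_fn='relu')   # pretend the user passed --activation-fn relu

    # what bert_base_architecture effectively does:
    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)

    print(args.activation_fn)             # 'relu' -- the user's choice is preserved
    print(args.encoder_normalize_before)  # True   -- architecture default filled in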
@@ -5,13 +5,17 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+from typing import Tuple
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Tuple

 from fairseq.modules import (
-    MultiheadAttention, PositionalEmbedding, TransformerSentenceEncoderLayer
+    LayerNorm,
+    MultiheadAttention,
+    PositionalEmbedding,
+    TransformerSentenceEncoderLayer,
 )
@@ -78,12 +82,12 @@ class TransformerSentenceEncoder(nn.Module):
         num_segments: int = 2,
         use_position_embeddings: bool = True,
         encoder_normalize_before: bool = False,
-        use_bert_layer_norm: bool = False,
-        use_gelu: bool = True,
         apply_bert_init: bool = False,
+        activation_fn: str = 'relu',
         learned_pos_embedding: bool = True,
         add_bias_kv: bool = False,
         add_zero_attn: bool = False,
+        embed_scale: float = None,
     ) -> None:
         super().__init__()
@@ -98,8 +102,9 @@ class TransformerSentenceEncoder(nn.Module):
         self.learned_pos_embedding = learned_pos_embedding

         self.embed_tokens = nn.Embedding(
-            self.vocab_size, self.embedding_dim, self.padding_idx
+            self.vocab_size, self.embedding_dim, self.padding_idx,
         )
+        self.embed_scale = embed_scale

         self.segment_embeddings = (
             nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None)
@@ -127,9 +132,7 @@ class TransformerSentenceEncoder(nn.Module):
                     dropout=self.dropout,
                     attention_dropout=attention_dropout,
                     activation_dropout=activation_dropout,
-                    encoder_normalize_before=encoder_normalize_before,
-                    use_bert_layer_norm=use_bert_layer_norm,
-                    use_gelu=use_gelu,
+                    activation_fn=activation_fn,
                     add_bias_kv=add_bias_kv,
                     add_zero_attn=add_zero_attn,
                 )
@@ -137,6 +140,11 @@ class TransformerSentenceEncoder(nn.Module):
             ]
         )

+        if encoder_normalize_before:
+            self.emb_layer_norm = LayerNorm(self.embedding_dim)
+        else:
+            self.emb_layer_norm = None
+
         # Apply initialization of model params after building the model
         if self.apply_bert_init:
             self.apply(init_bert_params)
@@ -152,30 +160,24 @@ class TransformerSentenceEncoder(nn.Module):
         if not padding_mask.any():
             padding_mask = None

-        # embed positions
-        positions = (
-            self.embed_positions(tokens)
-            if self.embed_positions is not None else None
-        )
-
-        # embed segments
-        segments = (
-            self.segment_embeddings(segment_labels)
-            if self.segment_embeddings is not None
-            else None
-        )
-
         x = self.embed_tokens(tokens)

-        if positions is not None:
-            x += positions
-        if segments is not None:
-            x += segments
+        if self.embed_scale is not None:
+            x *= self.embed_scale
+
+        if self.embed_positions is not None:
+            x += self.embed_positions(tokens)
+
+        if self.segment_embeddings is not None and segment_labels is not None:
+            x += self.segment_embeddings(segment_labels)
+
+        if self.emb_layer_norm is not None:
+            x = self.emb_layer_norm(x)

         x = F.dropout(x, p=self.dropout, training=self.training)

         # account for padding while computing the representation
         if padding_mask is not None:
-            x *= (1 - padding_mask.unsqueeze(-1).type_as(x))
+            x *= (~padding_mask).unsqueeze(-1).type_as(x)

         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
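
Two of the new pieces in this forward pass, the optional embedding scale and the padding-mask multiplication, can be exercised in isolation. A toy illustration with made-up shapes and values:

    import math

    import torch

    # Toy shapes: a batch of 2 sequences of length 4 with embedding dim 3.
    x = torch.randn(2, 4, 3)
    embed_scale = math.sqrt(3)   # e.g. sqrt(embedding_dim), now passable as embed_scale
    padding_mask = torch.tensor([[False, False, True, True],
                                 [False, False, False, True]])

    x = x * embed_scale                                # optional scaling of token embeddings
    x = x * (~padding_mask).unsqueeze(-1).type_as(x)   # zero out embeddings at padded positions
    print(x[0, 2:])                                    # the two padded timesteps are all zeros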
@@ -9,9 +9,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from fairseq import utils
 from fairseq.modules import (
-    BertLayerNorm,
-    gelu,
     LayerNorm,
     MultiheadAttention,
 )
@@ -21,9 +20,6 @@ class TransformerSentenceEncoderLayer(nn.Module):
     """
     Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
     models.
-
-    If the flag use_bert_layer_norm is set then we use the custom
-    BertLayerNorm module instead of LayerNorm.
     """

     def __init__(
@@ -34,9 +30,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         dropout: float = 0.1,
         attention_dropout: float = 0.1,
         activation_dropout: float = 0.1,
-        encoder_normalize_before: bool = False,
-        use_bert_layer_norm: bool = False,
-        use_gelu: bool = True,
+        activation_fn: str = 'relu',
         add_bias_kv: bool = False,
         add_zero_attn: bool = False,
     ) -> None:
@@ -46,10 +40,9 @@ class TransformerSentenceEncoderLayer(nn.Module):
         self.embedding_dim = embedding_dim
         self.dropout = dropout
         self.activation_dropout = activation_dropout
-        self.normalize_before = encoder_normalize_before

         # Initialize blocks
-        self.activation_fn = gelu if use_gelu else F.relu
+        self.activation_fn = utils.get_activation_fn(activation_fn)
         self.self_attn = MultiheadAttention(
             self.embedding_dim,
             num_attention_heads,
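
utils.get_activation_fn turns the --activation-fn string into a callable. A rough standalone sketch of such a lookup, for illustration only (the exact GELU variants fairseq uses may differ):

    import math

    import torch
    import torch.nn.functional as F

    def get_activation_fn(name: str):
        """Map an --activation-fn choice to a callable (illustrative sketch)."""
        if name == 'relu':
            return F.relu
        if name == 'gelu':
            return F.gelu  # exact (erf-based) GELU
        if name == 'gelu_accurate':
            # tanh approximation of GELU
            return lambda x: 0.5 * x * (1.0 + torch.tanh(
                math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3.0))
            ))
        raise RuntimeError('unsupported activation: {}'.format(name))

    act = get_activation_fn('gelu_accurate')
    y = act(torch.randn(3))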
@@ -59,33 +52,12 @@ class TransformerSentenceEncoderLayer(nn.Module):
         )

         # layer norm associated with the self attention layer
-        self.self_attn_layer_norm = (
-            BertLayerNorm(self.embedding_dim)
-            if use_bert_layer_norm
-            else LayerNorm(self.embedding_dim, eps=1e-12)
-        )
+        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

         self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
         self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

         # layer norm associated with the position wise feed-forward NN
-        self.final_layer_norm = (
-            BertLayerNorm(self.embedding_dim)
-            if use_bert_layer_norm
-            else LayerNorm(self.embedding_dim, eps=1e-12)
-        )
-
-    def _maybe_layer_norm(
-        self,
-        layer_norm: nn.Module,
-        x: torch.Tensor,
-        before: bool = False,
-        after: bool = False,
-    ):
-        assert before ^ after
-        if after ^ self.normalize_before:
-            return layer_norm(x)
-        else:
-            return x
+        self.final_layer_norm = LayerNorm(self.embedding_dim)

     def forward(
         self,
@@ -97,9 +69,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         LayerNorm is applied either before or after the self-attention/ffn
         modules similar to the original Transformer imlementation.
         """
         residual = x
-        x = self._maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
         x, attn = self.self_attn(
             query=x,
             key=x,
@@ -110,14 +80,13 @@ class TransformerSentenceEncoderLayer(nn.Module):
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self._maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
+        x = self.self_attn_layer_norm(x)

         residual = x
-        x = self._maybe_layer_norm(self.final_layer_norm, x, before=True)
         x = self.activation_fn(self.fc1(x))
         x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self._maybe_layer_norm(self.final_layer_norm, x, after=True)
+        x = self.final_layer_norm(x)
         return x, attn
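
With _maybe_layer_norm removed, the layer is always post-norm: each LayerNorm runs after the residual addition, and encoder_normalize_before now only controls the embedding-level LayerNorm added in TransformerSentenceEncoder. A generic self-contained sketch of that residual pattern (standard PyTorch, not code from this commit):

    import torch
    import torch.nn as nn

    class PostNormBlock(nn.Module):
        """Residual sub-block in the 'post-norm' arrangement used above."""

        def __init__(self, dim: int, sublayer: nn.Module, p: float = 0.1):
            super().__init__()
            self.sublayer = sublayer
            self.norm = nn.LayerNorm(dim)
            self.dropout = nn.Dropout(p)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            residual = x
            x = self.sublayer(x)   # e.g. self-attention or the FFN
            x = self.dropout(x)
            x = residual + x       # residual connection first...
            return self.norm(x)    # ...then LayerNorm (post-norm)

    # usage sketch
    block = PostNormBlock(dim=16, sublayer=nn.Linear(16, 16))
    y = block(torch.randn(2, 5, 16))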
@@ -12,9 +12,7 @@ import os
 from fairseq import tokenizer
 from fairseq.data import (
     ConcatDataset,
-    IndexedCachedDataset,
-    IndexedDataset,
-    IndexedRawTextDataset,
+    indexed_dataset,
     data_utils,
 )
@@ -42,10 +40,7 @@ class MaskedLMTask(FairseqTask):
         parser.add_argument('--tokens-per-sample', default=512, type=int,
                             help='max number of total tokens over all segments'
                                  ' per sample for BERT dataset')
-        parser.add_argument('--raw-text', default=False, action='store_true',
-                            help='load raw text dataset')
         parser.add_argument('--break-mode', default="doc", type=str, help='mode for breaking sentence')
-        parser.add_argument('--lazy-load', action='store_true', help='load the dataset lazily')

     def __init__(self, args, dictionary):
         super().__init__(args)
@@ -94,19 +89,19 @@ class MaskedLMTask(FairseqTask):
         for k in itertools.count():
             split_k = split + (str(k) if k > 0 else '')
             path = os.path.join(data_path, split_k)
-
-            if self.args.raw_text and IndexedRawTextDataset.exists(path):
-                ds = IndexedRawTextDataset(path, self.dictionary)
-            elif not self.args.raw_text and IndexedDataset.exists(path):
-                if self.args.lazy_load:
-                    ds = IndexedDataset(path, fix_lua_indexing=True)
-                else:
-                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
-            else:
+            ds = indexed_dataset.make_dataset(
+                path,
+                impl=self.args.dataset_impl,
+                fix_lua_indexing=True,
+                dictionary=self.dictionary,
+            )
+
+            if ds is None:
                 if k > 0:
                     break
                 else:
                     raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

             with data_utils.numpy_seed(self.seed + k):
                 loaded_datasets.append(
                     BlockPairDataset(
@@ -116,7 +111,8 @@ class MaskedLMTask(FairseqTask):
                         self.args.tokens_per_sample,
                         break_mode=self.args.break_mode,
                         doc_break_size=1,
-                    ))
+                    )
+                )

             print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))
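
Dataset format selection now goes through indexed_dataset.make_dataset keyed on --dataset-impl rather than the removed --raw-text/--lazy-load flags. A sketch of the equivalent call outside the task; the 'cached'/'lazy'/'raw' impl names and the file paths are assumptions, not part of this diff:

    from fairseq.data import Dictionary, indexed_dataset

    # Hypothetical paths; 'cached' roughly matches the old default behaviour,
    # while 'lazy' stands in for --lazy-load and 'raw' for --raw-text.
    dictionary = Dictionary.load('data-bin/dict.txt')
    ds = indexed_dataset.make_dataset(
        'data-bin/train',
        impl='cached',
        fix_lua_indexing=True,
        dictionary=dictionary,   # presumably only needed by the raw-text path
    )
    if ds is None:
        raise FileNotFoundError('Dataset not found: data-bin/train')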