Commit bf106796 authored by Myle Ott, committed by Facebook Github Bot

Various fixes for Masked LM (#573)

Summary:
Various fixes for Masked LM

- use --activation-fn instead of --gelu
- use --dataset-impl instead of --lazy-load
- add embed_scale option to TransformerSentenceEncoder
- fix encoder_normalize_before to include a final layer norm
- delete BertLayerNorm
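
For context, a minimal sketch of how the updated `TransformerSentenceEncoder` options fit together is shown below. The sizes, the `sqrt(embedding_dim)` scale, and the dummy batch are illustrative assumptions rather than values taken from this commit. On the task side, the removed `--raw-text`/`--lazy-load` flags are assumed to be subsumed by the generic `--dataset-impl` option used elsewhere in fairseq.

```python
import math

import torch

from fairseq.modules import TransformerSentenceEncoder

# Illustrative sketch only: the sizes and hyperparameters below are
# assumptions, not values from this commit.
embedding_dim = 768
encoder = TransformerSentenceEncoder(
    padding_idx=1,
    vocab_size=32000,
    num_encoder_layers=12,
    embedding_dim=embedding_dim,
    ffn_embedding_dim=3072,
    num_attention_heads=12,
    num_segments=2,
    encoder_normalize_before=True,         # now also layer-norms the embeddings
    apply_bert_init=True,
    activation_fn='gelu',                  # replaces the old --gelu / use_gelu flag
    embed_scale=math.sqrt(embedding_dim),  # new optional multiplier on token embeddings
)

tokens = torch.randint(5, 32000, (2, 16))  # dummy batch of token ids (no padding)
segment_labels = torch.zeros_like(tokens)
inner_states, sentence_rep = encoder(tokens, segment_labels)
```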
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/573

Reviewed By: borguz

Differential Revision: D15317933

Pulled By: myleott

fbshipit-source-id: 8ecb46556ad43e76e92d41ed8f5a62e8516fd375
parent 7432130e
@@ -17,7 +17,7 @@ from fairseq.models import (
 )
 from fairseq.modules import (
     SinusoidalPositionalEmbedding,
-    TransformerSentenceEncoder
+    TransformerSentenceEncoder,
 )
 from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -89,16 +89,11 @@ class MaskedLMModel(BaseFairseqModel):
         parser.add_argument('--apply-bert-init', action='store_true',
                             help='use custom param initialization for BERT')
-        # layer norm layers
-        parser.add_argument('--bert-layer-norm', action='store_true',
-                            help='use custom Layer Norm module for BERT')
         # misc params
+        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
+                            help='Which activation function to use')
         parser.add_argument('--encoder-normalize-before', action='store_true',
                             help='apply layernorm before each encoder block')
-        parser.add_argument('--gelu', action='store_true',
-                            help='Use gelu activation function in encoder'
-                                 ' layer')

     def forward(self, src_tokens, segment_labels, **kwargs):
         return self.encoder(src_tokens, segment_labels, **kwargs)
@@ -148,9 +143,8 @@ class MaskedLMEncoder(FairseqEncoder):
             num_segments=args.num_segment,
             use_position_embeddings=not args.no_token_positional_embeddings,
             encoder_normalize_before=args.encoder_normalize_before,
-            use_bert_layer_norm=args.bert_layer_norm,
-            use_gelu=args.gelu,
             apply_bert_init=args.apply_bert_init,
+            activation_fn=args.activation_fn,
             learned_pos_embedding=args.encoder_learned_pos,
             add_bias_kv=args.bias_kv,
             add_zero_attn=args.zero_attn,
@@ -263,11 +257,9 @@ def base_architecture(args):
     args.sent_loss = getattr(args, 'sent_loss', False)
     args.apply_bert_init = getattr(args, 'apply_bert_init', False)
-    args.bert_layer_norm = getattr(args, 'bert_layer_norm', False)
-    args.encoder_normalize_before = getattr(
-        args, 'encoder_normalize_before', False)
-    args.gelu = getattr(args, 'gelu', False)
+    args.activation_fn = getattr(args, 'activation_fn', 'relu')
+    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)

 @register_model_architecture('masked_lm', 'bert_base')
@@ -287,16 +279,13 @@ def bert_base_architecture(args):
     args.bias_kv = getattr(args, 'bias_kv', False)
     args.zero_attn = getattr(args, 'zero_attn', False)
     args.sent_loss = getattr(args, 'sent_loss', True)
     args.sentence_class_num = getattr(args, 'sentence_class_num', 2)
     args.sent_loss = getattr(args, 'sent_loss', True)
     args.apply_bert_init = getattr(args, 'apply_bert_init', True)
-    # TODO: validate setups for layernorm
-    args.encoder_normalize_before = getattr(
-        args, 'encoder_normalize_before', True)
-    args.bert_layer_norm = getattr(args, 'bert_layer_norm', True)
-    args.gelu = getattr(args, 'gelu', True)
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
+    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
     base_architecture(args)
@@ -328,9 +317,6 @@ def xlm_architecture(args):
     args.sent_loss = getattr(args, 'sent_loss', False)
-    args.encoder_normalize_before = getattr(
-        args, 'encoder_normalize_before', False)
-    args.bert_layer_norm = getattr(args, 'bert_layer_norm', False)
-    args.gelu = getattr(args, 'gelu', True)
     args.apply_bert_init = getattr(args, 'apply_bert_init', True)
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
+    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
     base_architecture(args)
@@ -5,13 +5,17 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+from typing import Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Tuple
 from fairseq.modules import (
-    MultiheadAttention, PositionalEmbedding, TransformerSentenceEncoderLayer
+    LayerNorm,
+    MultiheadAttention,
+    PositionalEmbedding,
+    TransformerSentenceEncoderLayer,
 )
@@ -78,12 +82,12 @@ class TransformerSentenceEncoder(nn.Module):
         num_segments: int = 2,
         use_position_embeddings: bool = True,
         encoder_normalize_before: bool = False,
-        use_bert_layer_norm: bool = False,
-        use_gelu: bool = True,
         apply_bert_init: bool = False,
+        activation_fn: str = 'relu',
         learned_pos_embedding: bool = True,
         add_bias_kv: bool = False,
         add_zero_attn: bool = False,
+        embed_scale: float = None,
     ) -> None:

         super().__init__()
@@ -98,8 +102,9 @@ class TransformerSentenceEncoder(nn.Module):
         self.learned_pos_embedding = learned_pos_embedding
         self.embed_tokens = nn.Embedding(
-            self.vocab_size, self.embedding_dim, self.padding_idx
+            self.vocab_size, self.embedding_dim, self.padding_idx,
         )
+        self.embed_scale = embed_scale
         self.segment_embeddings = (
             nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None)
@@ -127,9 +132,7 @@ class TransformerSentenceEncoder(nn.Module):
                     dropout=self.dropout,
                     attention_dropout=attention_dropout,
                     activation_dropout=activation_dropout,
-                    encoder_normalize_before=encoder_normalize_before,
-                    use_bert_layer_norm=use_bert_layer_norm,
-                    use_gelu=use_gelu,
+                    activation_fn=activation_fn,
                     add_bias_kv=add_bias_kv,
                     add_zero_attn=add_zero_attn,
                 )
@@ -137,6 +140,11 @@ class TransformerSentenceEncoder(nn.Module):
             ]
         )
+        if encoder_normalize_before:
+            self.emb_layer_norm = LayerNorm(self.embedding_dim)
+        else:
+            self.emb_layer_norm = None
         # Apply initialization of model params after building the model
         if self.apply_bert_init:
             self.apply(init_bert_params)
@@ -152,30 +160,24 @@ class TransformerSentenceEncoder(nn.Module):
         if not padding_mask.any():
             padding_mask = None
-        # embed positions
-        positions = (
-            self.embed_positions(tokens)
-            if self.embed_positions is not None else None
-        )
+        x = self.embed_tokens(tokens)
+        if self.embed_scale is not None:
+            x *= self.embed_scale
-        # embed segments
-        segments = (
-            self.segment_embeddings(segment_labels)
-            if self.segment_embeddings is not None
-            else None
-        )
+        if self.embed_positions is not None:
+            x += self.embed_positions(tokens)
-        x = self.embed_tokens(tokens)
+        if self.segment_embeddings is not None and segment_labels is not None:
+            x += self.segment_embeddings(segment_labels)
+        if self.emb_layer_norm is not None:
+            x = self.emb_layer_norm(x)
-        if positions is not None:
-            x += positions
-        if segments is not None:
-            x += segments
         x = F.dropout(x, p=self.dropout, training=self.training)
         # account for padding while computing the representation
         if padding_mask is not None:
-            x *= (1 - padding_mask.unsqueeze(-1).type_as(x))
+            x *= (~padding_mask).unsqueeze(-1).type_as(x)
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
@@ -9,9 +9,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from fairseq import utils
 from fairseq.modules import (
-    BertLayerNorm,
-    gelu,
+    LayerNorm,
     MultiheadAttention,
 )
@@ -21,9 +20,6 @@ class TransformerSentenceEncoderLayer(nn.Module):
     """
     Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
     models.
-    If the flag use_bert_layer_norm is set then we use the custom
-    BertLayerNorm module instead of LayerNorm.
     """

     def __init__(
@@ -34,9 +30,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         dropout: float = 0.1,
         attention_dropout: float = 0.1,
         activation_dropout: float = 0.1,
-        encoder_normalize_before: bool = False,
-        use_bert_layer_norm: bool = False,
-        use_gelu: bool = True,
+        activation_fn: str = 'relu',
         add_bias_kv: bool = False,
         add_zero_attn: bool = False,
     ) -> None:
@@ -46,10 +40,9 @@ class TransformerSentenceEncoderLayer(nn.Module):
         self.embedding_dim = embedding_dim
         self.dropout = dropout
         self.activation_dropout = activation_dropout
-        self.normalize_before = encoder_normalize_before
         # Initialize blocks
-        self.activation_fn = gelu if use_gelu else F.relu
+        self.activation_fn = utils.get_activation_fn(activation_fn)
         self.self_attn = MultiheadAttention(
             self.embedding_dim,
             num_attention_heads,
@@ -59,33 +52,12 @@ class TransformerSentenceEncoderLayer(nn.Module):
         )
         # layer norm associated with the self attention layer
-        self.self_attn_layer_norm = (
-            BertLayerNorm(self.embedding_dim)
-            if use_bert_layer_norm
-            else LayerNorm(self.embedding_dim, eps=1e-12)
-        )
+        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
         self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
         self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
         # layer norm associated with the position wise feed-forward NN
-        self.final_layer_norm = (
-            BertLayerNorm(self.embedding_dim)
-            if use_bert_layer_norm
-            else LayerNorm(self.embedding_dim, eps=1e-12)
-        )
-
-    def _maybe_layer_norm(
-        self,
-        layer_norm: nn.Module,
-        x: torch.Tensor,
-        before: bool = False,
-        after: bool = False,
-    ):
-        assert before ^ after
-        if after ^ self.normalize_before:
-            return layer_norm(x)
-        else:
-            return x
+        self.final_layer_norm = LayerNorm(self.embedding_dim)

     def forward(
         self,
@@ -97,9 +69,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         LayerNorm is applied either before or after the self-attention/ffn
         modules similar to the original Transformer imlementation.
         """
         residual = x
-        x = self._maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
         x, attn = self.self_attn(
             query=x,
             key=x,
@@ -110,14 +80,13 @@ class TransformerSentenceEncoderLayer(nn.Module):
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self._maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
+        x = self.self_attn_layer_norm(x)

         residual = x
-        x = self._maybe_layer_norm(self.final_layer_norm, x, before=True)
         x = self.activation_fn(self.fc1(x))
         x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self._maybe_layer_norm(self.final_layer_norm, x, after=True)
+        x = self.final_layer_norm(x)
         return x, attn
@@ -12,9 +12,7 @@ import os
 from fairseq import tokenizer
 from fairseq.data import (
     ConcatDataset,
-    IndexedCachedDataset,
-    IndexedDataset,
-    IndexedRawTextDataset,
+    indexed_dataset,
     data_utils,
 )
@@ -42,10 +40,7 @@ class MaskedLMTask(FairseqTask):
         parser.add_argument('--tokens-per-sample', default=512, type=int,
                             help='max number of total tokens over all segments'
                                  ' per sample for BERT dataset')
-        parser.add_argument('--raw-text', default=False, action='store_true',
-                            help='load raw text dataset')
         parser.add_argument('--break-mode', default="doc", type=str, help='mode for breaking sentence')
-        parser.add_argument('--lazy-load', action='store_true', help='load the dataset lazily')

     def __init__(self, args, dictionary):
         super().__init__(args)
@@ -94,19 +89,19 @@ class MaskedLMTask(FairseqTask):
         for k in itertools.count():
             split_k = split + (str(k) if k > 0 else '')
             path = os.path.join(data_path, split_k)
+            ds = indexed_dataset.make_dataset(
+                path,
+                impl=self.args.dataset_impl,
+                fix_lua_indexing=True,
+                dictionary=self.dictionary,
+            )
-            if self.args.raw_text and IndexedRawTextDataset.exists(path):
-                ds = IndexedRawTextDataset(path, self.dictionary)
-            elif not self.args.raw_text and IndexedDataset.exists(path):
-                if self.args.lazy_load:
-                    ds = IndexedDataset(path, fix_lua_indexing=True)
-                else:
-                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
-            else:
+            if ds is None:
                 if k > 0:
                     break
                 else:
                     raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

             with data_utils.numpy_seed(self.seed + k):
                 loaded_datasets.append(
                     BlockPairDataset(
@@ -116,7 +111,8 @@
                         self.args.tokens_per_sample,
                         break_mode=self.args.break_mode,
                         doc_break_size=1,
-                    ))
+                    )
+                )

             print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))