Commit f040158a authored by Kartikay Khandelwal's avatar Kartikay Khandelwal Committed by Facebook Github Bot

Add Transformer Sentence Encoder for BERT and XLM Pre-training in PyText (#621)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/621

In this commit, I add the modules to Fairseq needed to set up BERT/XLM style pre-training.

Reviewed By: borguz

Differential Revision: D14719663

fbshipit-source-id: 1c5c36b6b2cde1c9bcd3c9e9ac853d2b7ae64102
parent 3658fa32
...@@ -8,6 +8,7 @@
from .adaptive_input import AdaptiveInput
from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM
from .bert_layer_norm import BertLayerNorm
from .character_token_embedder import CharacterTokenEmbedder
from .conv_tbc import ConvTBC
from .downsampled_multihead_attention import DownsampledMultiHeadAttention
...@@ -23,12 +24,15 @@ from .mean_pool_gating_network import MeanPoolGatingNetwork
from .multihead_attention import MultiheadAttention
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
from .transformer_sentence_encoder import TransformerSentenceEncoder
from .unfold import unfold1d

__all__ = [
    'AdaptiveInput',
    'AdaptiveSoftmax',
    'BeamableMM',
    'BertLayerNorm',
    'CharacterTokenEmbedder',
    'ConvTBC',
    'DownsampledMultiHeadAttention',
...@@ -44,5 +48,7 @@ __all__ = [
    'MultiheadAttention',
    'ScalarBias',
    'SinusoidalPositionalEmbedding',
    'TransformerSentenceEncoderLayer',
    'TransformerSentenceEncoder',
    'unfold1d',
]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn as nn
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""
Construct a layernorm module in the TF style used with BERT
(epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
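
# Hypothetical usage sketch (not part of this commit): normalize a batch of
# hidden states and check the output statistics. Tensor sizes are illustrative.
if __name__ == "__main__":
    layer_norm = BertLayerNorm(hidden_size=768)
    hidden = torch.randn(2, 16, 768)                  # B x T x C activations
    normalized = layer_norm(hidden)
    assert normalized.shape == hidden.shape
    # At initialization weight=1 and bias=0, so each position is standardized.
    print(normalized.mean(-1).abs().max(), normalized.std(-1).mean())
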
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
from fairseq.modules import (
MultiheadAttention, LearnedPositionalEmbedding, TransformerSentenceEncoderLayer
)
def init_bert_params(module):
"""
Initialize the weights specific to the BERT Model.
This overrides the default initializations depending on the specified arguments.
1. If normal_init_linear_weights is set then weights of linear
layer will be initialized using the normal distribution and
bais will be set to the specified value.
2. If normal_init_embed_weights is set then weights of embedding
layer will be initialized using the normal distribution.
3. If normal_init_proj_weights is set then weights of
in_project_weight for MultiHeadAttention initialized using
the normal distribution (to be validated).
"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, MultiheadAttention):
module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
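
# Hypothetical usage sketch (not part of this commit): the initializer is meant
# to be applied recursively via nn.Module.apply, which is what
# TransformerSentenceEncoder does when apply_bert_init=True.
if __name__ == "__main__":
    mlp = nn.Sequential(nn.Linear(768, 3072), nn.Linear(3072, 768))
    mlp.apply(init_bert_params)          # re-initializes every nn.Linear child
    print(mlp[0].weight.std())           # roughly 0.02
    print(mlp[0].bias.abs().sum())       # 0, biases are zeroed
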
def PositionalEmbedding(
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
left_pad: bool
) -> nn.Embedding:
m = LearnedPositionalEmbedding(
num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
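
# Hypothetical usage sketch (not part of this commit): the helper allocates
# padding_idx + 1 extra rows so real positions are numbered past the padding
# index, and the row at padding_idx itself is zeroed. Sizes are illustrative.
if __name__ == "__main__":
    pos_embed = PositionalEmbedding(
        num_embeddings=256, embedding_dim=768, padding_idx=1, left_pad=False
    )
    assert pos_embed.weight.shape == (256 + 1 + 1, 768)
    assert pos_embed.weight[1].abs().sum() == 0      # padding row stays zero
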
class TransformerSentenceEncoder(nn.Module):
"""
Implementation for a Bi-directional Transformer based Sentence Encoder used
in BERT/XLM style pre-trained models.
This first embeds the input tokens using the token embedding matrix and
adds position embeddings (if enabled) and segment embeddings
(if enabled). After applying the specified number of
TransformerEncoderLayers, it outputs all the internal states of the
encoder as well as the final representation associated with the first
token (usually CLS token).
Input:
- tokens: B x T matrix representing sentences
- segment_labels: B x T matrix representing segment label for tokens
Output:
- a tuple of the following:
- a list of internal model states used to compute the
predictions, where each tensor has shape T x B x C
- sentence representation associated with first input token
in format B x C.
"""
def __init__(
self,
padding_idx: int,
vocab_size: int,
num_encoder_layers: int = 6,
embedding_dim: int = 768,
ffn_embedding_dim: int = 3072,
num_attention_heads: int = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
max_seq_len: int = 256,
num_segments: int = 2,
use_position_embeddings: bool = True,
encoder_normalize_before: bool = False,
use_bert_layer_norm: bool = False,
use_gelu: bool = True,
apply_bert_init: bool = False,
) -> None:
super().__init__()
self.padding_idx = padding_idx
self.vocab_size = vocab_size
self.dropout = dropout
self.max_seq_len = max_seq_len
self.embedding_dim = embedding_dim
self.num_segments = num_segments
self.use_position_embeddings = use_position_embeddings
self.apply_bert_init = apply_bert_init
self.token_embeddings = nn.Embedding(
self.vocab_size, self.embedding_dim, self.padding_idx
)
self.segment_embeddings = (
nn.Embedding(self.num_segments, self.embedding_dim, self.padding_idx)
if self.num_segments > 0
else None
)
self.position_embeddings = (
PositionalEmbedding(
self.max_seq_len,
self.embedding_dim,
self.padding_idx,
left_pad=False,
)
if self.use_position_embeddings
else None
)
self.layers = nn.ModuleList(
[
TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=ffn_embedding_dim,
num_attention_heads=num_attention_heads,
dropout=self.dropout,
attention_dropout=attention_dropout,
activation_dropout=activation_dropout,
encoder_normalize_before=encoder_normalize_before,
use_bert_layer_norm=use_bert_layer_norm,
use_gelu=use_gelu,
)
for _ in range(num_encoder_layers)
]
)
# Apply initialization of model params after building the model
if self.apply_bert_init:
self.apply(init_bert_params)
def forward(
self,
tokens: torch.Tensor,
segment_labels: torch.Tensor
) -> Tuple[List[torch.Tensor], torch.Tensor]:
# compute padding mask. This is needed for multi-head attention
padding_mask = tokens.eq(self.padding_idx)
if not padding_mask.any():
padding_mask = None
# embed positions
positions = (
self.position_embeddings(tokens)
if self.position_embeddings is not None else None
)
# embed segments
segments = (
self.segment_embeddings(segment_labels)
if self.segment_embeddings is not None
else None
)
x = self.token_embeddings(tokens)
if positions is not None:
x += positions
if segments is not None:
x += segments
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
inner_states = [x]
for layer in self.layers:
x, _ = layer(
x,
self_attn_padding_mask=padding_mask,
)
inner_states.append(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
sentence_rep = x[:, 0, :]
return inner_states, sentence_rep
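
# Hypothetical usage sketch (not part of this commit): encode a small batch of
# token ids with made-up sizes and inspect the two outputs described in the
# class docstring.
if __name__ == "__main__":
    encoder = TransformerSentenceEncoder(
        padding_idx=1,
        vocab_size=1000,
        num_encoder_layers=2,
        embedding_dim=64,
        ffn_embedding_dim=128,
        num_attention_heads=4,
    )
    tokens = torch.randint(2, 1000, (8, 16))          # B x T token ids, no padding
    segment_labels = torch.zeros_like(tokens)         # single segment
    inner_states, sentence_rep = encoder(tokens, segment_labels)
    print(len(inner_states))                          # num_encoder_layers + 1 states
    print(inner_states[-1].shape)                     # T x B x C
    print(sentence_rep.shape)                         # B x C
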
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules import MultiheadAttention, BertLayerNorm
def gelu(x: torch.Tensor) -> torch.Tensor:
"""
Implementation of the GELU activation function (exact erf-based form).
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
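
# Hypothetical usage sketch (not part of this commit): the erf-based expression
# above is the exact GELU; two quick sanity checks on small tensors.
if __name__ == "__main__":
    assert torch.allclose(gelu(torch.zeros(3)), torch.zeros(3))
    x = torch.tensor([10.0])
    assert torch.allclose(gelu(x), x)    # GELU approaches identity for large x
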
class TransformerSentenceEncoderLayer(nn.Module):
"""
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
models.
If the flag use_bert_layer_norm is set then we use the custom
BertLayerNorm module instead of nn.LayerNorm.
"""
def __init__(
self,
embedding_dim: int = 768,
ffn_embedding_dim: int = 3072,
num_attention_heads: int = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
encoder_normalize_before: bool = True,
use_bert_layer_norm: bool = True,
use_gelu: bool = True,
) -> None:
super().__init__()
# Initialize parameters
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
self.normalize_before = encoder_normalize_before
# Initialize blocks
self.activation_fn = gelu if use_gelu else F.relu
self.self_attention = MultiheadAttention(
self.embedding_dim, num_attention_heads, dropout=attention_dropout
)
# layer norm associated with the self attention layer
self.self_attn_layer_norm = (
BertLayerNorm(self.embedding_dim)
if use_bert_layer_norm
else nn.LayerNorm(self.embedding_dim, eps=1e-12)
)
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = (
BertLayerNorm(self.embedding_dim)
if use_bert_layer_norm
else nn.LayerNorm(self.embedding_dim, eps=1e-12)
)
def _maybe_layer_norm(
self,
layer_norm: nn.Module,
x: torch.Tensor,
before: bool = False,
after: bool = False,
):
# Exactly one of `before`/`after` may be set; the norm is applied before the
# sub-layer when normalize_before is True and after it otherwise.
assert before ^ after
if after ^ self.normalize_before:
return layer_norm(x)
else:
return x
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
modules, similar to the original Transformer implementation.
"""
residual = x
x = self._maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
x, attn = self.self_attention(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
attn_mask=self_attn_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self._maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
residual = x
x = self._maybe_layer_norm(self.final_layer_norm, x, before=True)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self._maybe_layer_norm(self.final_layer_norm, x, after=True)
return x, attn
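
# Hypothetical usage sketch (not part of this commit): the layer consumes
# sequence-first T x B x C input, matching how TransformerSentenceEncoder calls
# it after transposing the embedded tokens. Sizes are illustrative.
if __name__ == "__main__":
    layer = TransformerSentenceEncoderLayer(
        embedding_dim=64, ffn_embedding_dim=128, num_attention_heads=4
    )
    x = torch.randn(16, 8, 64)           # T x B x C
    out, attn = layer(x)                 # attn is None since need_weights=False
    assert out.shape == x.shape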