Commit 97b58b46 authored by Myle Ott

Add Transformer model

parent 6a7c8d0d
@@ -371,6 +371,8 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
# sort by sizes
indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
max_sizes = np.maximum(dst.sizes[indices], src.sizes[indices])
indices = indices[np.argsort(max_sizes, kind='mergesort')]
batches = list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import LanguagePairDataset
from fairseq.modules import (
LayerNorm, LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding,
)
from . import (
FairseqDecoder, FairseqEncoder, FairseqModel,
register_model, register_model_architecture,
)
@register_model('transformer')
class TransformerModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
help='dropout probability after ReLU in FFN')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', default=False, action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', default=False, action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', default=False, action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', default=False, action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
@classmethod
def build_model(cls, args, src_dict, dst_dict):
"""Build a new model instance."""
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
if args.share_all_embeddings:
if src_dict != dst_dict:
raise RuntimeError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
decoder_embed_tokens = build_embedding(dst_dict, args.decoder_embed_dim)
encoder = TransformerEncoder(
src_dict,
encoder_embed_tokens,
ffn_inner_dim=args.encoder_ffn_embed_dim,
num_layers=args.encoder_layers,
num_attn_heads=args.encoder_attention_heads,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
relu_dropout=args.relu_dropout,
normalize_before=args.encoder_normalize_before,
learned_pos_embed=args.encoder_learned_pos,
)
decoder = TransformerDecoder(
dst_dict,
decoder_embed_tokens,
ffn_inner_dim=args.decoder_ffn_embed_dim,
num_layers=args.decoder_layers,
num_attn_heads=args.decoder_attention_heads,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
relu_dropout=args.relu_dropout,
normalize_before=args.decoder_normalize_before,
learned_pos_embed=args.decoder_learned_pos,
share_input_output_embed=args.share_decoder_input_output_embed,
)
return TransformerModel(encoder, decoder)
class TransformerEncoder(FairseqEncoder):
"""Transformer encoder."""
def __init__(
self, dictionary, embed_tokens, ffn_inner_dim=2048,
num_layers=6, num_attn_heads=8, dropout=0.1, attention_dropout=0.,
relu_dropout=0., normalize_before=False, learned_pos_embed=False,
):
super().__init__(dictionary)
self.dropout = dropout
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, self.padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
learned=learned_pos_embed,
)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(
embed_dim, ffn_inner_dim, num_attn_heads, dropout=dropout,
attention_dropout=attention_dropout, relu_dropout=relu_dropout,
normalize_before=normalize_before,
)
for i in range(num_layers)
])
self.reset_parameters()
def reset_parameters(self):
for name, p in self.named_parameters():
if name.endswith('weight'):
nn.init.xavier_uniform(p.data)
elif name.endswith('bias'):
p.data.zero_()
def forward(self, src_tokens, src_lengths):
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(src_tokens)
x += self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
return {
'encoder_out': x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
return self.embed_positions.max_positions()
class TransformerDecoder(FairseqDecoder):
"""Transformer decoder."""
def __init__(
self, dictionary, embed_tokens, ffn_inner_dim=2048,
num_layers=6, num_attn_heads=8, dropout=0.1, attention_dropout=0.,
relu_dropout=0., normalize_before=False, learned_pos_embed=False,
share_input_output_embed=False,
):
super().__init__(dictionary)
self.dropout = dropout
self.share_input_output_embed = share_input_output_embed
embed_dim = embed_tokens.embedding_dim
padding_idx = embed_tokens.padding_idx
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
learned=learned_pos_embed,
)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerDecoderLayer(
embed_dim, ffn_inner_dim, num_attn_heads, dropout=dropout,
attention_dropout=attention_dropout, relu_dropout=relu_dropout,
normalize_before=normalize_before,
)
for i in range(num_layers)
])
if not share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
self.reset_parameters()
def reset_parameters(self):
for name, p in self.named_parameters():
if name.endswith('weight'):
nn.init.xavier_uniform(p.data)
elif name.endswith('bias'):
p.data.zero_()
def forward(self, prev_output_tokens, encoder_out):
# embed positions
positions = self.embed_positions(prev_output_tokens)
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# decoder layers
for layer in self.layers:
x, attn = layer(x, encoder_out['encoder_out'], encoder_out['encoder_padding_mask'])
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# project back to size of vocabulary
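# (uses the shared input embedding or the dedicated output matrix; the result
# is batch x tgt_len x vocab_size logits)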
if self.share_input_output_embed:
x = F.linear(x, self.embed_tokens.weight)
else:
x = F.linear(x, self.embed_out)
return x, attn
def reorder_incremental_state(self, new_order):
"""Reorder buffered internal state (for incremental generation)."""
super().reorder_incremental_state(new_order)
def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.embed_positions.max_positions()
class TransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: dropout -> add residual -> layernorm.
In the tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
dropout -> add residual.
We default to the approach in the paper, but the tensor2tensor approach can
be enabled by setting `normalize_before=True`.
"""
def __init__(
self, embed_dim, ffn_inner_dim, num_attn_heads, dropout=0.1,
attention_dropout=0., relu_dropout=0., normalize_before=False,
):
super().__init__()
self.embed_dim = embed_dim
self.self_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
self.dropout = dropout
self.relu_dropout = relu_dropout
self.normalize_before = normalize_before
self.fc1 = nn.Linear(embed_dim, ffn_inner_dim)
self.fc2 = nn.Linear(ffn_inner_dim, embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(embed_dim) for i in range(2)])
def forward(self, x, encoder_padding_mask):
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
return x
def maybe_layer_norm(self, i, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return self.layer_norms[i](x)
else:
return x
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block."""
def __init__(
self, embed_dim, ffn_inner_dim, num_attn_heads, dropout=0.1,
attention_dropout=0., relu_dropout=0., normalize_before=False,
):
super().__init__()
self.embed_dim = embed_dim
self.self_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
self.dropout = dropout
self.relu_dropout = relu_dropout
self.normalize_before = normalize_before
self.encoder_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
self.fc1 = nn.Linear(embed_dim, ffn_inner_dim)
self.fc2 = nn.Linear(ffn_inner_dim, embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(embed_dim) for i in range(3)])
def forward(self, x, encoder_out, encoder_padding_mask):
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x, _ = self.self_attn(query=x, key=x, value=x, mask_future_timesteps=True)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x, attn = self.encoder_attn(query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
residual = x
x = self.maybe_layer_norm(2, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(2, x, after=True)
return x, attn
def maybe_layer_norm(self, i, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return self.layer_norms[i](x)
else:
return x
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
m.weight.data.normal_(mean=0, std=embedding_dim**-0.5)
return m
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
if learned:
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
m.weight.data.normal_(0, 0.1)
else:
m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
return m
@register_model_architecture('transformer', 'transformer')
def base_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
args.encoder_layers = getattr(args, 'encoder_layers', 6)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
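# The named variants below only override hyperparameters of the base
# architecture; which one is used is selected at training time via the --arch
# flag (e.g. --arch transformer_iwslt_de_en). The flag itself belongs to the
# surrounding fairseq CLI, not to this file.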
@register_model_architecture('transformer', 'transformer_iwslt_de_en')
def transformer_iwslt_de_en(args):
base_architecture(args)
args.encoder_embed_dim = 256
args.encoder_ffn_embed_dim = 512
args.encoder_layers = 3
args.encoder_attention_heads = 4
args.decoder_embed_dim = 256
args.decoder_ffn_embed_dim = 512
args.decoder_layers = 3
args.decoder_attention_heads = 4
@register_model_architecture('transformer', 'transformer_wmt_en_de')
def transformer_wmt_en_de(args):
base_architecture(args)
args.encoder_embed_dim = 512
args.encoder_ffn_embed_dim = 2048
args.encoder_layers = 6
args.encoder_attention_heads = 8
args.decoder_embed_dim = 512
args.decoder_ffn_embed_dim = 2048
args.decoder_layers = 6
args.decoder_attention_heads = 8
@register_model_architecture('transformer', 'transformer_wmt_en_de_big')
def transformer_wmt_en_de_big(args):
base_architecture(args)
args.encoder_embed_dim = 1024
args.encoder_ffn_embed_dim = 4096
args.encoder_layers = 6
args.encoder_attention_heads = 16
args.decoder_embed_dim = 1024
args.decoder_ffn_embed_dim = 4096
args.decoder_layers = 6
args.decoder_attention_heads = 16
@register_model_architecture('transformer', 'transformer_wmt_en_de_big_t2t')
def transformer_wmt_en_de_big_t2t(args):
base_architecture(args)
args.encoder_embed_dim = 1024
args.encoder_ffn_embed_dim = 4096
args.encoder_layers = 6
args.encoder_attention_heads = 16
args.encoder_normalize_before = True
args.decoder_embed_dim = 1024
args.decoder_ffn_embed_dim = 4096
args.decoder_layers = 6
args.decoder_attention_heads = 16
args.decoder_normalize_before = True
args.attention_dropout = 0.1
args.relu_dropout = 0.1
@@ -8,13 +8,19 @@
from .beamable_mm import BeamableMM
from .conv_tbc import ConvTBC
from .grad_multiply import GradMultiply
from .layer_norm import LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .linearized_convolution import LinearizedConvolution
from .multihead_attention import MultiheadAttention
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
__all__ = [
'BeamableMM',
'ConvTBC',
'GradMultiply',
'LayerNorm',
'LearnedPositionalEmbedding',
'LinearizedConvolution',
'MultiheadAttention',
'SinusoidalPositionalEmbedding',
]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
class LayerNorm(nn.Module):
"""Applies Layer Normalization over the last dimension."""
def __init__(self, features, eps=1e-5):
super().__init__()
self.features = features
self.eps = eps
self.gain = nn.Parameter(torch.ones(features))
self.bias = nn.Parameter(torch.zeros(features))
self.dummy = None
self.w = None
self.b = None
def forward(self, input):
shape = input.size()
# In order to force the cudnn path, everything needs to be
# contiguous. Hence the check here and reallocation below.
if not input.is_contiguous():
input = input.contiguous()
input = input.view(1, -1, shape[-1])
# Expand w and b buffers if necessary.
n = input.size(1)
cur = self.dummy.numel() if self.dummy is not None else 0
if cur == 0:
self.dummy = input.data.new(n)
self.w = input.data.new(n).fill_(1)
self.b = input.data.new(n).zero_()
elif n > cur:
self.dummy.resize_(n)
self.w.resize_(n)
self.w[cur:n].fill_(1)
self.b.resize_(n)
self.b[cur:n].zero_()
dummy = self.dummy[:n]
w = Variable(self.w[:n])
b = Variable(self.b[:n])
output = F.batch_norm(input, dummy, dummy, w, b, True, 0., self.eps)
return torch.addcmul(self.bias, 1, output.view(*shape), self.gain)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim
self.scaling = self.head_dim**-0.5
self._mask = None
self.in_proj_weight = Parameter(torch.Tensor(3*self.embed_dim, self.embed_dim))
if bias:
self.in_proj_bias = Parameter(torch.Tensor(3*self.embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.out_proj = nn.Linear(self.embed_dim, embed_dim, bias=bias)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform(self.in_proj_weight.data)
nn.init.xavier_uniform(self.out_proj.weight.data)
if self.in_proj_bias is not None:
self.in_proj_bias.data.zero_()
def forward(self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
src_len, bsz, embed_dim = key.size()
tgt_len = query.size(0)
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if query.data_ptr() == key.data_ptr() == value.data_ptr():
# self-attention
q, k, v = self.in_proj_qkv(query)
elif key.data_ptr() == value.data_ptr():
# encoder-decoder attention
q = self.in_proj_q(query)
k, v = self.in_proj_kv(key)
else:
q = self.in_proj_q(query)
k = self.in_proj_k(key)
v = self.in_proj_v(value)
q *= self.scaling
q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if mask_future_timesteps:
assert query.size() == key.size(), \
'mask_future_timesteps only applies to self-attention'
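# buffered_mask is an upper-triangular matrix of -inf (zero on and below
# the diagonal), so position t cannot attend to positions > t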
attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
if key_padding_mask is not None:
# don't attend to padding symbols
if key_padding_mask.max() > 0:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
-math.inf,
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn = torch.bmm(attn_weights, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
# average attention weights over heads
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.sum(dim=1) / self.num_heads
return attn, attn_weights
def in_proj_qkv(self, query):
return self._in_proj(query).chunk(3, dim=-1)
def in_proj_kv(self, key):
return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
def in_proj_q(self, query):
return self._in_proj(query, end=self.embed_dim)
def in_proj_k(self, key):
return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim)
def in_proj_v(self, value):
return self._in_proj(value, start=2*self.embed_dim)
def _in_proj(self, input, start=None, end=None):
weight = self.in_proj_weight
bias = self.in_proj_bias
if end is not None:
weight = weight[:end, :]
if bias is not None:
bias = bias[:end]
if start is not None:
weight = weight[start:, :]
if bias is not None:
bias = bias[start:]
return F.linear(input, weight, bias)
def buffered_mask(self, tensor):
dim = tensor.size(-1)
if self._mask is None:
self._mask = torch.triu(tensor.new(dim, dim).fill_(-math.inf), 1)
if self._mask.size(0) < dim:
self._mask = torch.triu(self._mask.resize_(dim, dim).fill_(-math.inf), 1)
return self._mask[:dim, :dim]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
from torch.autograd import Variable
import torch.nn as nn
from fairseq import utils
class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored, but it is necessary to specify whether padding
is added on the left side (left_pad=True) or right side (left_pad=False).
"""
def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.left_pad = left_pad
self.register_buffer('range_buf', None)
self._cache = {}
self.register_buffer(
'weights',
SinusoidalPositionalEmbedding.get_embedding(
init_size,
embedding_dim,
padding_idx,
),
)
@staticmethod
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim) * -emb)
emb = torch.arange(num_embeddings).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, input, incremental_state=None):
"""Input is expected to be of size [bsz x seqlen]."""
# recompute/expand embeddings if needed
bsz, seq_len = input.size()
max_pos = self.padding_idx + 1 + seq_len
if max_pos > self.weights.size(0):
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
self.padding_idx,
).type_as(self.weights)
weights = Variable(self.weights)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
return weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
positions = Variable(utils.make_positions(input.data, self.padding_idx, self.left_pad))
return weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)
def max_positions(self):
"""Maximum number of supported positions."""
return int(1e5) # an arbitrary large number