Commit c21a6e29 authored by Myle Ott

Move positional embeddings into LearnedPositionalEmbedding module

parent 185a0df5
fairseq/models/fconv.py
@@ -13,26 +13,11 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from fairseq.data import LanguagePairDataset
-from fairseq.modules import BeamableMM, GradMultiply, LinearizedConvolution
+from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution
 from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel


-def make_positions(tokens, padding_idx, left_pad, offset=0):
-    seqlen = tokens.size(1)
-    if not hasattr(make_positions, 'range'):
-        make_positions.range = tokens.new()
-    if make_positions.range.numel() < offset + seqlen:
-        # offset positions by the padding index
-        torch.arange(padding_idx + 1, padding_idx + 1 + offset + seqlen,
-                     out=make_positions.range)
-    mask = tokens.ne(padding_idx)
-    positions = make_positions.range[offset:offset+seqlen].expand_as(tokens)
-    if left_pad:
-        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
-    return tokens.clone().masked_scatter_(mask, positions[mask])
-
-
 class FConvModel(FairseqModel):
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)
@@ -51,7 +36,8 @@ class FConvEncoder(FairseqEncoder):
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-        self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
+        self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
+                                                   left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)

         in_channels = convolutions[0][0]
         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
@@ -68,11 +54,8 @@ class FConvEncoder(FairseqEncoder):
         self.fc2 = Linear(in_channels, embed_dim)

     def forward(self, src_tokens):
-        positions = Variable(make_positions(src_tokens.data, self.dictionary.pad(),
-                                            left_pad=LanguagePairDataset.LEFT_PAD_SOURCE))
-
         # embed tokens and positions
-        x = self.embed_tokens(src_tokens) + self.embed_positions(positions)
+        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
         x = F.dropout(x, p=self.dropout, training=self.training)
         input_embedding = x
@@ -106,7 +89,7 @@ class FConvEncoder(FairseqEncoder):
     def max_positions(self):
         """Maximum input length supported by the encoder."""
-        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
+        return self.embed_positions.max_positions()


 class AttentionLayer(nn.Module):
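Note (not part of the commit): the encoder's and decoder's max_positions() compute the same bound as before, just from inside the new module. Positions are numbered starting at padding_idx + 1, so an embedding table with num_embeddings rows can address at most num_embeddings - padding_idx - 1 positions; e.g. with num_embeddings = 1024 and padding_idx = 1, sequences of up to 1022 tokens are supported.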
@@ -170,7 +153,8 @@ class FConvDecoder(FairseqIncrementalDecoder):
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-        self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
+        self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
+                                                   left_pad=LanguagePairDataset.LEFT_PAD_TARGET)

         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
         self.projections = nn.ModuleList()
@@ -190,32 +174,18 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

     def forward(self, input_tokens, encoder_out):
-        if self._is_incremental_eval:
-            return self.incremental_forward(input_tokens, encoder_out)
-        else:
-            return self.batch_forward(input_tokens, encoder_out)
-
-    def batch_forward(self, input_tokens, encoder_out):
-        """Forward pass for decoding multiple time steps in batch mode."""
-        positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(),
-                                            left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
-        return self._forward(input_tokens, positions, encoder_out)
-
-    def incremental_forward(self, input_tokens, encoder_out):
-        """Forward pass for one time step."""
-        # positions is the same for every token when decoding a single step
-        positions = Variable(input_tokens.data.new(1, 1).fill_(
-            self.dictionary.pad() + input_tokens.size(1)))
-        # keep only the last token for incremental forward pass
-        return self._forward(input_tokens[:, -1:], positions, encoder_out)
-
-    def _forward(self, input_tokens, positions, encoder_out):
         # split and transpose encoder outputs
         encoder_a, encoder_b = self._split_encoder_out(encoder_out)

+        # embed positions
+        positions = self.embed_positions(input_tokens)
+
+        if self._is_incremental_eval:
+            # keep only the last token for incremental forward pass
+            input_tokens = input_tokens[:, -1:]
+
         # embed tokens and positions
-        x = self.embed_tokens(input_tokens) + self.embed_positions(positions)
+        x = self.embed_tokens(input_tokens) + positions
         x = F.dropout(x, p=self.dropout, training=self.training)
         target_embedding = x
@@ -268,7 +238,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
     def max_positions(self):
         """Maximum output length supported by the decoder."""
-        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
+        return self.embed_positions.max_positions()

     def upgrade_state_dict(self, state_dict):
         if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
@@ -308,6 +278,12 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m


+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
+    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
+    m.weight.data.normal_(0, 0.1)
+    return m
+
+
 def Linear(in_features, out_features, dropout=0):
     """Weight-normalized Linear layer (input: N x T x C)"""
     m = nn.Linear(in_features, out_features)
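The decoder's incremental path relies on an invariant worth spelling out: during single-step decoding the already-generated prefix contains no padding, so every sentence in the batch shares the same position numbers. A minimal sketch (not part of the commit; the tensor names below are made up for illustration):

import torch

padding_idx = 1
generated = torch.tensor([[4, 9, 7],   # 3 steps decoded so far,
                          [5, 5, 8]])  # no padding anywhere
seqlen = generated.size(1)

# Full position rows would be [[2, 3, 4], [2, 3, 4]]; the newest column
# is padding_idx + seqlen for every batch element, so a (1, 1) tensor
# broadcasts correctly against the (bsz, 1, embed_dim) token embedding.
last_position = torch.full((1, 1), padding_idx + seqlen, dtype=torch.long)
print(last_position)  # tensor([[4]])

This is exactly the value the module's incremental branch fills in with input.data.new(1, 1).fill_(self.padding_idx + input.size(1)).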
fairseq/modules/__init__.py
@@ -9,11 +9,13 @@
 from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
 from .grad_multiply import GradMultiply
+from .learned_positional_embedding import LearnedPositionalEmbedding
 from .linearized_convolution import LinearizedConvolution

 __all__ = [
     'BeamableMM',
     'ConvTBC',
     'GradMultiply',
+    'LearnedPositionalEmbedding',
     'LinearizedConvolution',
 ]
fairseq/modules/learned_positional_embedding.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class LearnedPositionalEmbedding(nn.Embedding):
    """This module learns positional embeddings up to a fixed maximum size.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.left_pad = left_pad
        self.register_buffer('range_buf', None)
        self._is_incremental_eval = False

    def incremental_eval(self, mode=True):
        self._is_incremental_eval = mode

    def forward(self, input):
        """Input is expected to be of size [bsz x seqlen]."""
        if self._is_incremental_eval:
            # positions is the same for every token when decoding a single step
            positions = Variable(
                input.data.new(1, 1).fill_(self.padding_idx + input.size(1)))
        else:
            positions = Variable(self.make_positions(input.data))
        return super().forward(positions)

    def max_positions(self):
        """Maximum number of supported positions."""
        return self.num_embeddings - self.padding_idx - 1

    def make_positions(self, input):
        """Replace non-padding symbols with their position numbers."""
        if self.range_buf is None:
            self.range_buf = input.new()
        seqlen = input.size(1)
        if self.range_buf.numel() < seqlen:
            # offset positions by the padding index
            torch.arange(self.padding_idx + 1, self.padding_idx + 1 + seqlen,
                         out=self.range_buf)
        mask = input.ne(self.padding_idx)
        positions = self.range_buf[:seqlen].expand_as(input)
        if self.left_pad:
            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
        return input.clone().masked_scatter_(mask, positions[mask])
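For concreteness, a standalone sketch (not part of the commit) that re-derives make_positions on plain tensors; make_positions_sketch and the toy batches are hypothetical names, and padding_idx = 1 matches fairseq's usual dictionary convention:

import torch

def make_positions_sketch(tokens, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers,
    counting from padding_idx + 1."""
    seqlen = tokens.size(1)
    range_buf = torch.arange(padding_idx + 1, padding_idx + 1 + seqlen)
    mask = tokens.ne(padding_idx)
    positions = range_buf.unsqueeze(0).expand_as(tokens)
    if left_pad:
        # shift so the first non-pad token still gets padding_idx + 1
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tokens.clone().masked_scatter_(mask, positions[mask])

pad = 1
right_padded = torch.tensor([[5, 6, 7, pad, pad]])
left_padded = torch.tensor([[pad, pad, 5, 6, 7]])
print(make_positions_sketch(right_padded, pad, left_pad=False))  # [[2, 3, 4, 1, 1]]
print(make_positions_sketch(left_padded, pad, left_pad=True))    # [[1, 1, 2, 3, 4]]

Padding slots keep the padding_idx value, so they always look up the same (padding) embedding row regardless of where they appear in the sequence.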