"docs/source/api/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "98c1117d00edd38d72610d6a87c0c8d706873863"
Commit e265c239 authored by Kartikay Khandelwal, committed by Facebook Github Bot

Make Fairseq compatible with pre-computed position tensors (#570)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/570

Pull Request resolved: https://github.com/pytorch/fairseq/pull/731

Currently the LearnedPositionalEmbedding module computes the position tensor from the input data. However, this doesn't work for XLM, where the behavior differs between the Masked LM and the Translation LM. This diff keeps the same default behavior for LearnedPositionalEmbedding as before, but adds the ability for these models to work with pre-computed position tensors.

Reviewed By: myleott

Differential Revision: D15305474

fbshipit-source-id: de7d908245a2a620b58d36055211600a08f2d1dc
parent ba989ed1
@@ -11,26 +11,41 @@ from fairseq import utils

 class LearnedPositionalEmbedding(nn.Embedding):
-    """This module learns positional embeddings up to a fixed maximum size.
-    Padding symbols are ignored.
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    Padding ids are ignored by either offsetting based on padding_idx
+    or by setting padding_idx to None and ensuring that the appropriate
+    position ids are passed to the forward function.
     """

-    def __init__(self, num_embeddings, embedding_dim, padding_idx):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int,
+    ):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
         self.onnx_trace = False

-    def forward(self, input, incremental_state=None):
+    def forward(self, input, incremental_state=None, positions=None):
         """Input is expected to be of size [bsz x seqlen]."""
-        if incremental_state is not None:
-            # positions is the same for every token when decoding a single step
-            positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
-        else:
-            positions = utils.make_positions(
-                input.data, self.padding_idx, onnx_trace=self.onnx_trace,
-            )
+        assert (
+            (positions is None) or (self.padding_idx is None)
+        ), "If positions is pre-computed then padding_idx should not be set."
+
+        if positions is None:
+            if incremental_state is not None:
+                # positions is the same for every token when decoding a single step
+                positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
+            else:
+                positions = utils.make_positions(
+                    input.data, self.padding_idx, onnx_trace=self.onnx_trace,
+                )
         return super().forward(positions)

     def max_positions(self):
         """Maximum number of supported positions."""
-        return self.num_embeddings - self.padding_idx - 1
+        if self.padding_idx is not None:
+            return self.num_embeddings - self.padding_idx - 1
+        else:
+            return self.num_embeddings
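
For illustration, a minimal usage sketch of the new calling convention; the import path and the concrete shapes below are assumptions made for the example, not part of the diff:

import torch
from fairseq.modules import LearnedPositionalEmbedding  # assumed import path

embed_dim, max_positions = 16, 512

# With padding_idx=None and pre-computed positions, the module skips
# utils.make_positions() and simply looks up the supplied position ids.
pos_embed = LearnedPositionalEmbedding(max_positions, embed_dim, padding_idx=None)

tokens = torch.randint(4, 100, (2, 10))                   # [bsz x seqlen]; token values unused here
positions = torch.arange(10).unsqueeze(0).expand(2, 10)   # pre-computed position ids

out = pos_embed(tokens, positions=positions)              # [bsz x seqlen x embed_dim]
assert out.shape == (2, 10, embed_dim)

Callers that keep a padding_idx and omit positions get the old behavior: positions are derived from the input (or from the incremental decoding step), so existing models are unaffected. Passing positions while padding_idx is set trips the new assertion.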
@@ -11,13 +11,23 @@ from .learned_positional_embedding import LearnedPositionalEmbedding
 from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding

-def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, learned=False):
+def PositionalEmbedding(
+    num_embeddings: int,
+    embedding_dim: int,
+    padding_idx: int,
+    learned: bool = False,
+):
     if learned:
-        m = LearnedPositionalEmbedding(
-            num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
-        )
+        # if padding_idx is specified then offset the embedding ids by
+        # this index and adjust num_embeddings appropriately
+        # TODO: The right place for this offset would be inside
+        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
+        if padding_idx is not None:
+            num_embeddings = num_embeddings + padding_idx + 1
+        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
         nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
-        nn.init.constant_(m.weight[padding_idx], 0)
+        if padding_idx is not None:
+            nn.init.constant_(m.weight[padding_idx], 0)
     else:
         m = SinusoidalPositionalEmbedding(
             embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
...
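
A sketch of the factory's two modes, again with an assumed import path and arbitrary sizes:

import torch
from fairseq.modules import PositionalEmbedding  # assumed import path

# With a padding index, ids are offset by padding_idx, the table grows by
# padding_idx + 1 rows, and the padding row is zeroed (same as before this change).
m_offset = PositionalEmbedding(num_embeddings=128, embedding_dim=16, padding_idx=1, learned=True)
assert m_offset.weight.shape[0] == 128 + 1 + 1
assert torch.all(m_offset.weight[1] == 0)

# With padding_idx=None there is no offset; the caller is expected to pass valid
# pre-computed position ids to forward().
m_plain = PositionalEmbedding(num_embeddings=128, embedding_dim=16, padding_idx=None, learned=True)
assert m_plain.weight.shape[0] == 128

Note that only the learned branch handles padding_idx=None here; the sinusoidal branch still computes init_size=num_embeddings + padding_idx + 1 and so continues to require a padding index.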
@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from typing import Tuple
+from typing import Tuple, Optional

 import torch
 import torch.nn as nn

@@ -81,6 +81,7 @@ class TransformerSentenceEncoder(nn.Module):
         max_seq_len: int = 256,
         num_segments: int = 2,
         use_position_embeddings: bool = True,
+        offset_positions_by_padding: bool = True,
         encoder_normalize_before: bool = False,
         apply_bert_init: bool = False,
         activation_fn: str = 'relu',

@@ -116,7 +117,10 @@ class TransformerSentenceEncoder(nn.Module):
             PositionalEmbedding(
                 self.max_seq_len,
                 self.embedding_dim,
-                self.padding_idx,
+                padding_idx=(
+                    self.padding_idx if offset_positions_by_padding
+                    else None
+                ),
                 learned=self.learned_pos_embedding,
             )
             if self.use_position_embeddings

@@ -154,6 +158,7 @@ class TransformerSentenceEncoder(nn.Module):
         tokens: torch.Tensor,
         segment_labels: torch.Tensor,
         last_state_only: bool = False,
+        positions: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:

         # compute padding mask. This is needed for multi-head attention

@@ -162,11 +167,12 @@ class TransformerSentenceEncoder(nn.Module):
             padding_mask = None

         x = self.embed_tokens(tokens)

         if self.embed_scale is not None:
             x *= self.embed_scale

         if self.embed_positions is not None:
-            x += self.embed_positions(tokens)
+            x += self.embed_positions(tokens, positions=positions)

         if self.segment_embeddings is not None and segment_labels is not None:
             x += self.segment_embeddings(segment_labels)
...
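
Finally, a hypothetical end-to-end sketch of the encoder-level flag. Only offset_positions_by_padding and the positions keyword come from this diff; the remaining constructor arguments and the import path are assumptions about the surrounding TransformerSentenceEncoder signature:

import torch
from fairseq.modules import TransformerSentenceEncoder  # assumed import path

# offset_positions_by_padding=False builds the positional embedding with
# padding_idx=None, so forward() accepts caller-supplied position ids.
encoder = TransformerSentenceEncoder(
    padding_idx=1,                      # this and the other kwargs below are assumed
    vocab_size=1000,
    num_encoder_layers=2,
    embedding_dim=64,
    ffn_embedding_dim=128,
    num_attention_heads=4,
    max_seq_len=512,
    offset_positions_by_padding=False,  # added in this diff
    learned_pos_embedding=True,
)

tokens = torch.randint(4, 1000, (2, 12))    # [bsz x seqlen]
segment_labels = torch.zeros_like(tokens)

# e.g. XLM Translation LM-style positions that restart at 0 for the second sentence
positions = torch.cat([torch.arange(6), torch.arange(6)]).unsqueeze(0).expand(2, 12)

inner_states, sentence_rep = encoder(tokens, segment_labels, positions=positions)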