Commit a03fe6fa authored by Sara Hanson, committed by Facebook Github Bot

Implement sparse transformer fixed attention pattern (#804)

Summary:
Pull Request resolved: https://github.com/facebookresearch/pytext/pull/804

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/746

Pull Request resolved: https://github.com/pytorch/fairseq/pull/894

Adding an implementation of the sparse transformer to multi-head attention using the fixed attention pattern specified in https://arxiv.org/pdf/1904.10509.pdf. The sparse_mask masks out words with -inf; after softmax, those entries become 0, so the mask does not need to be re-calculated and re-applied when multiplying attn_weights and values.
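As a quick illustration of that point (standard PyTorch behavior, not part of this diff): positions set to -inf receive exactly zero probability after softmax, so masked positions contribute nothing when attn_weights is multiplied by the values.

import torch
logits = torch.tensor([1.0, 2.0, float('-inf')])
probs = torch.softmax(logits, dim=0)
# probs[2] == 0.0, so the -inf entry never needs to be masked again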

Four inputs are added to the config: sparse, is_bidirectional, stride, expressivity. If the sparse transformer is used, is_bidirectional, stride, and expressivity control the attention pattern (defaults are provided). If is_bidirectional is False, the mask is computed using the fixed attention pattern described in the paper. If is_bidirectional is True, subset one includes all values in the current stride window plus a summary from every stride window; all other values are masked. Stride (l in the paper) controls the window size and expressivity (c in the paper) controls the size of the summary.
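A minimal usage sketch of these knobs (mirroring the unit test added below; the shapes and values are illustrative, not the config wiring itself):

import torch
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention

# stride (l) = 4, expressivity (c) = 1, unidirectional fixed pattern
attn = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=False)
# buffered_sparse_mask returns a (tgt_len, src_len) tensor of 0 / -inf entries;
# note tgt_len must exceed stride
mask = attn.buffered_sparse_mask(torch.randn(1, 8, 8), 8, 8)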

Reviewed By: borguz

Differential Revision: D16042988

fbshipit-source-id: c59166dc7cfe89187a256e4076000c2458842fd5
parent e8d609a8
@@ -40,7 +40,6 @@ class MultiheadAttention(nn.Module):
assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
'value to be of the same size'
if self.qkv_same_dim:
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
else:
@@ -102,7 +101,6 @@ class MultiheadAttention(nn.Module):
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
@@ -217,6 +215,8 @@ class MultiheadAttention(nn.Module):
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
@@ -327,3 +327,6 @@ class MultiheadAttention(nn.Module):
'attn_state',
buffer,
)
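# No-op hook: returns attn_weights unchanged; overridden by SparseMultiheadAttention to add the sparse mask.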
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
return attn_weights
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
from .multihead_attention import MultiheadAttention
class SparseMultiheadAttention(MultiheadAttention):
""" Sparse Multi-Headed Attention.
"Generating Long Sequences with Sparse Transformers". Implements
fixed factorized self attention, where l=stride and c=expressivity.
A(1) includes all words in the stride window and A(2) takes a summary of c
words from the end of each stride window.
If is_bidirectional=False, we do not include any words past the current word,
as in the paper.
"""
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
add_bias_kv=False, add_zero_attn=False, self_attention=False,
encoder_decoder_attention=False, stride=32, expressivity=8, is_bidirectional=True):
super().__init__(
embed_dim, num_heads, kdim, vdim, dropout, bias, add_bias_kv,
add_zero_attn, self_attention, encoder_decoder_attention
)
self.is_bidirectional = is_bidirectional
self.stride = stride
self.expressivity = expressivity
assert(self.stride > 0 and self.stride >= self.expressivity)
# Used for Ai(2) calculations - beginning of [l-c, l] range
def compute_checkpoint(self, word_index):
if word_index % self.stride == 0 and word_index != 0:
checkpoint_index = word_index - self.expressivity
else:
checkpoint_index = (
math.floor(word_index / self.stride) * self.stride
+ self.stride - self.expressivity
)
return checkpoint_index
# Computes Ai(2)
def compute_subset_summaries(self, absolute_max):
checkpoint_index = self.compute_checkpoint(0)
subset_two = set()
while checkpoint_index <= absolute_max-1:
summary = set(range(checkpoint_index, min(
checkpoint_index+self.expressivity+1, absolute_max)
))
subset_two = subset_two.union(summary)
checkpoint_index = self.compute_checkpoint(checkpoint_index+self.stride)
return subset_two
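# Worked example (traced from the code above with stride=4, expressivity=1, absolute_max=8,
# as in the unit test): compute_checkpoint(0) = 3 -> summary {3, 4}; next checkpoint = 7 ->
# summary {7}; so subset two = {3, 4, 7}.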
# Sparse Transformer Fixed Attention Pattern: https://arxiv.org/pdf/1904.10509.pdf
def compute_fixed_attention_subset(self, word_index, tgt_len):
# +1s account for range function; [min, max) -> [min, max]
if not self.is_bidirectional:
absolute_max = word_index + 1
else:
absolute_max = tgt_len
# Subset 1 - whole window
rounded_index = math.floor((word_index + self.stride) / self.stride) * self.stride
if word_index % self.stride == 0 and word_index != 0:
subset_one = set(range(word_index-self.stride, min(absolute_max, word_index+1)))
else:
subset_one = set(range(max(0, rounded_index - self.stride), min(
absolute_max, rounded_index+1))
)
# Subset 2 - summary per window
# If bidirectional, subset 2 is the same for every index
subset_two = set()
if not self.is_bidirectional:
subset_two = self.compute_subset_summaries(absolute_max)
return subset_one.union(subset_two)
# Compute sparse mask - if bidirectional, can pre-compute and store
def buffered_sparse_mask(self, tensor, tgt_len, src_len):
assert(tgt_len > self.stride)
sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float('-inf'))
# If bidirectional, subset 2 is the same for every index
subset_summaries = set()
if self.is_bidirectional:
subset_summaries = self.compute_subset_summaries(tgt_len)
for i in range(tgt_len):
fixed_attention_subset = self.compute_fixed_attention_subset(i, tgt_len)
fixed_attention_subset = fixed_attention_subset.union(subset_summaries)
included_word_indices = torch.LongTensor(list(fixed_attention_subset))
sparse_mask[i].index_fill_(0, included_word_indices, 0)
return sparse_mask.type_as(tensor)
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len)
sparse_mask = sparse_mask.unsqueeze(0).expand(bsz * self.num_heads, tgt_len, src_len)
attn_weights += sparse_mask
return attn_weights
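# --- Hedged usage sketch (illustration only, not part of this diff) ---
# Assumes the base MultiheadAttention forward signature
# forward(query, key, value, ...) -> (attn, attn_weights).
import torch
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
tgt_len, bsz, embed_dim = 8, 2, 16
attn = SparseMultiheadAttention(embed_dim, 1, stride=4, expressivity=1,
self_attention=True, is_bidirectional=True)
x = torch.randn(tgt_len, bsz, embed_dim)  # fairseq tensors are time-first
out, weights = attn(x, x, x)  # masked positions get zero weight after the internal softmax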
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
from fairseq.modules import TransformerSentenceEncoder
from fairseq.modules.sparse_transformer_sentence_encoder_layer import SparseTransformerSentenceEncoderLayer
class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
"""
Sparse implementation of the TransformerSentenceEncoder
- see SparseMultiheadAttention
"""
def __init__(
self,
padding_idx: int,
vocab_size: int,
num_encoder_layers: int = 6,
embedding_dim: int = 768,
ffn_embedding_dim: int = 3072,
num_attention_heads: int = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
max_seq_len: int = 256,
num_segments: int = 2,
use_position_embeddings: bool = True,
offset_positions_by_padding: bool = True,
encoder_normalize_before: bool = False,
apply_bert_init: bool = False,
activation_fn: str = "relu",
learned_pos_embedding: bool = True,
add_bias_kv: bool = False,
add_zero_attn: bool = False,
embed_scale: float = None,
freeze_embeddings: bool = False,
n_trans_layers_to_freeze: int = 0,
export: bool = False,
is_bidirectional: bool = True,
stride: int = 32,
expressivity: int = 8,
) -> None:
super().__init__(
padding_idx, vocab_size, num_encoder_layers, embedding_dim,
ffn_embedding_dim, num_attention_heads, dropout, attention_dropout,
activation_dropout, max_seq_len, num_segments, use_position_embeddings,
offset_positions_by_padding, encoder_normalize_before, apply_bert_init,
activation_fn, learned_pos_embedding, add_bias_kv, add_zero_attn,
embed_scale, freeze_embeddings, n_trans_layers_to_freeze, export
)
self.layers = nn.ModuleList(
[
SparseTransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=ffn_embedding_dim,
num_attention_heads=num_attention_heads,
dropout=self.dropout,
attention_dropout=attention_dropout,
activation_dropout=activation_dropout,
activation_fn=activation_fn,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
export=export,
is_bidirectional=is_bidirectional,
stride=stride,
expressivity=expressivity,
)
for _ in range(num_encoder_layers)
]
)
def freeze_module_params(m):
if m is not None:
for p in m.parameters():
p.requires_grad = False
for layer in range(n_trans_layers_to_freeze):
freeze_module_params(self.layers[layer])
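# --- Hedged usage sketch (illustration only, not part of this diff) ---
# Assumes the base TransformerSentenceEncoder forward(tokens, segment_labels)
# returning (inner_states, sentence_rep); the module path is inferred from the
# layer import above.
import torch
from fairseq.modules.sparse_transformer_sentence_encoder import SparseTransformerSentenceEncoder
encoder = SparseTransformerSentenceEncoder(
padding_idx=0, vocab_size=1000, num_encoder_layers=2, embedding_dim=64,
ffn_embedding_dim=128, num_attention_heads=4,
is_bidirectional=True, stride=8, expressivity=2,
)
tokens = torch.randint(1, 1000, (2, 16))  # (batch, seq_len); seq_len must exceed stride
segment_labels = torch.zeros_like(tokens)
inner_states, sentence_rep = encoder(tokens, segment_labels)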
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.modules import TransformerSentenceEncoderLayer
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
"""
Implements a Sparse Transformer Encoder Layer (see SparseMultiheadAttention)
"""
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: float = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = 'relu',
add_bias_kv: bool = False,
add_zero_attn: bool = False,
export: bool = False,
is_bidirectional: bool = True,
stride: int = 32,
expressivity: int = 8,
) -> None:
super().__init__(
embedding_dim, ffn_embedding_dim, num_attention_heads, dropout,
attention_dropout, activation_dropout, activation_fn, add_bias_kv,
add_zero_attn, export
)
self.self_attn = SparseMultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=True,
is_bidirectional=is_bidirectional,
stride=stride,
expressivity=expressivity,
)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import unittest
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
class TestSparseMultiheadAttention(unittest.TestCase):
def test_sparse_multihead_attention(self):
attn_weights = torch.randn(1, 8, 8)
bidirectional_sparse_mask = torch.tensor([
[0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
[0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
[0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
[0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
[float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
[float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
[float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
[float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0]
])
bidirectional_attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=True)
bidirectional_attention_sparse_mask = bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
torch.all(torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask))
sparse_mask = torch.tensor([
[0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'),
float('-inf'), float('-inf')],
[0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
[0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
[0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
[0, 0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf')],
[float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, float('-inf'), float('-inf')],
[float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, float('-inf')],
[float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
])
attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=False)
attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)
torch.all(torch.eq(attention_sparse_mask, sparse_mask))
if __name__ == '__main__':
unittest.main()