Commit 4ac2c5f2 authored by Stephan Peitz, committed by Facebook Github Bot

Implementation of the WeCNLP abstract "Cross+Self-Attention for Transformer Models" (#1097)

Summary:
This PR implements a new attention module which combines cross-attention (encoder-decoder attention) and the decoder self-attention. This work was accepted as an abstract at WeCNLP 2019 (https://www.wecnlp.ai/wecnlp-2019).

Cross+Self-Attention reduces the number of parameters and increases inference speed without any degradation in translation quality.
More details can be found in the attached [abstract](https://github.com/pytorch/fairseq/files/3561282/paper.pdf).
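
Conceptually, the change replaces the decoder's separate self-attention and encoder-decoder attention with a single attention over the concatenation of the encoder outputs and the decoder states (in the diff below, the decoder layer builds `y = torch.cat((encoder_out, x), dim=0)` and uses it as key and value). The snippet below is a minimal, self-contained sketch of that idea written against `torch.nn.MultiheadAttention` rather than fairseq's own `MultiheadAttention`; the function name and toy shapes are illustrative assumptions, not part of this PR.

```python
import torch
import torch.nn as nn

def cross_self_attention(attn: nn.MultiheadAttention,
                         x: torch.Tensor,              # decoder states, (tgt_len, batch, embed_dim)
                         encoder_out: torch.Tensor,    # encoder states, (src_len, batch, embed_dim)
                         self_attn_mask: torch.Tensor  # (tgt_len, tgt_len) causal mask
                         ) -> torch.Tensor:
    # Keys/values are the encoder outputs concatenated with the decoder states
    # along the time axis, so a single attention covers source and target context.
    y = torch.cat((encoder_out, x), dim=0)
    # Encoder positions are always visible, so the causal mask is extended with
    # zeros over the source positions (mirroring the mask handling in the diff).
    full_mask = torch.cat(
        (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
    )
    out, _ = attn(query=x, key=y, value=y, attn_mask=full_mask)
    return out

# Toy usage with made-up sizes.
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2)
x = torch.randn(5, 3, 8)    # 5 target steps, batch of 3
enc = torch.randn(7, 3, 8)  # 7 source steps
causal = torch.triu(torch.full((5, 5), float('-inf')), diagonal=1)
print(cross_self_attention(attn, x, enc, causal).shape)  # torch.Size([5, 3, 8])
```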
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1097

Differential Revision: D17653168

Pulled By: myleott

fbshipit-source-id: deb834c2c78a229d7418ffbfea20ba3ce252991c
parent ea1a410d
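
The new behaviour is opt-in via the three flags added in this diff. A rough usage example follows; the data directory and the remaining optimizer/batching flags are placeholders, not part of this commit:

```
python train.py data-bin/iwslt14.tokenized.de-en \
    --arch transformer_iwslt_de_en \
    --no-cross-attention --cross-self-attention --layer-wise-attention \
    <usual optimizer, batching and checkpoint flags>
```

Here `--no-cross-attention` drops the separate encoder-decoder attention, `--cross-self-attention` lets the decoder self-attention attend over the concatenated encoder and decoder states, and `--layer-wise-attention` feeds each decoder layer the hidden state of the corresponding encoder layer instead of only the final encoder output.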
@@ -122,6 +122,13 @@ class TransformerModel(FairseqEncoderDecoderModel):
                                  'Must be used with adaptive_loss criterion'),
         parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                             help='sets adaptive softmax dropout for the tail projections')
+        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
+        parser.add_argument('--no-cross-attention', default=False, action='store_true',
+                            help='do not perform cross-attention')
+        parser.add_argument('--cross-self-attention', default=False, action='store_true',
+                            help='perform cross+self-attention')
+        parser.add_argument('--layer-wise-attention', default=False, action='store_true',
+                            help='perform layer-wise attention (cross-attention or cross+self-attention)')
         # fmt: on
     @classmethod

@@ -180,7 +187,12 @@ class TransformerModel(FairseqEncoderDecoderModel):
     @classmethod
     def build_decoder(cls, args, tgt_dict, embed_tokens):
-        return TransformerDecoder(args, tgt_dict, embed_tokens)
+        return TransformerDecoder(
+            args,
+            tgt_dict,
+            embed_tokens,
+            no_encoder_attn=getattr(args, 'no_cross_attention', False),
+        )
 class TransformerEncoder(FairseqEncoder):

@@ -211,6 +223,8 @@ class TransformerEncoder(FairseqEncoder):
             learned=args.encoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
+        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
         self.layers = nn.ModuleList([])
         self.layers.extend([
             TransformerEncoderLayer(args)

@@ -230,13 +244,15 @@
         x = F.dropout(x, p=self.dropout, training=self.training)
         return x, embed
-    def forward(self, src_tokens, src_lengths, cls_input=None):
+    def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
         """
         Args:
             src_tokens (LongTensor): tokens in the source language of shape
                 `(batch, src_len)`
             src_lengths (torch.LongTensor): lengths of each source sentence of
                 shape `(batch)`
+            return_all_hiddens (bool, optional): also return all of the
+                intermediate hidden states (default: False).
         Returns:
             dict:

@@ -244,7 +260,13 @@
                   shape `(src_len, batch, embed_dim)`
                 - **encoder_padding_mask** (ByteTensor): the positions of
                   padding elements of shape `(batch, src_len)`
+                - **encoder_states** (List[Tensor]): all intermediate
+                  hidden states of shape `(src_len, batch, embed_dim)`.
+                  Only populated if *return_all_hiddens* is True.
         """
+        if self.layer_wise_attention:
+            return_all_hiddens = True
         x, encoder_embedding = self.forward_embedding(src_tokens)
         # B x T x C -> T x B x C

@@ -255,17 +277,24 @@
         if not encoder_padding_mask.any():
             encoder_padding_mask = None
+        encoder_states = [] if return_all_hiddens else None
         # encoder layers
         for layer in self.layers:
             x = layer(x, encoder_padding_mask)
+            if return_all_hiddens:
+                encoder_states.append(x)
         if self.layer_norm:
             x = self.layer_norm(x)
+            if return_all_hiddens:
+                encoder_states[-1] = x
         return {
             'encoder_out': x,  # T x B x C
             'encoder_padding_mask': encoder_padding_mask,  # B x T
             'encoder_embedding': encoder_embedding,  # B x T x C
+            'encoder_states': encoder_states,  # List[T x B x C]
         }
     def reorder_encoder_out(self, encoder_out, new_order):

@@ -285,6 +314,9 @@ class TransformerEncoder(FairseqEncoder):
         if encoder_out['encoder_padding_mask'] is not None:
             encoder_out['encoder_padding_mask'] = \
                 encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        if encoder_out.get('encoder_states', None) is not None:
+            for idx, state in enumerate(encoder_out['encoder_states']):
+                encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
         return encoder_out
     def max_positions(self):

@@ -293,6 +325,14 @@
             return self.max_source_positions
         return min(self.max_source_positions, self.embed_positions.max_positions())
+    def buffered_future_mask(self, tensor):
+        dim = tensor.size(0)
+        if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
+        if self._future_mask.size(0) < dim:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
+        return self._future_mask[:dim, :dim]
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):

@@ -350,6 +390,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             learned=args.decoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
+        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
+        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
         self.layers = nn.ModuleList([])
         self.layers.extend([
             TransformerDecoderLayer(args, no_encoder_attn)

@@ -435,14 +478,26 @@
         inner_states = [x]
+        self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
+        if not self_attn_padding_mask.any() and not self.cross_self_attention:
+            self_attn_padding_mask = None
         # decoder layers
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
+            encoder_state = None
+            if encoder_out is not None:
+                if self.layer_wise_attention:
+                    encoder_state = encoder_out['encoder_states'][idx]
+                else:
+                    encoder_state = encoder_out['encoder_out']
             x, attn = layer(
                 x,
-                encoder_out['encoder_out'] if encoder_out is not None else None,
+                encoder_state,
                 encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
                 incremental_state,
                 self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
+                self_attn_padding_mask=self_attn_padding_mask,
             )
             inner_states.append(x)

@@ -553,6 +608,9 @@ def base_architecture(args):
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
     args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
     args.adaptive_input = getattr(args, 'adaptive_input', False)
+    args.no_cross_attention = getattr(args, 'no_cross_attention', False)
+    args.cross_self_attention = getattr(args, 'cross_self_attention', False)
+    args.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
     args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
     args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)

...
@@ -186,8 +186,15 @@ class MultiheadAttention(nn.Module):
                     v = prev_value
                 else:
                     v = torch.cat((prev_value, v), dim=1)
+            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+                prev_key_padding_mask = saved_state['prev_key_padding_mask']
+                if static_kv:
+                    key_padding_mask = prev_key_padding_mask
+                else:
+                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
             saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
             saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state['prev_key_padding_mask'] = key_padding_mask
             self._set_input_buffer(incremental_state, saved_state)

@@ -311,7 +318,8 @@ class MultiheadAttention(nn.Module):
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
             for k in input_buffer.keys():
-                input_buffer[k] = input_buffer[k].index_select(0, new_order)
+                if input_buffer[k] is not None:
+                    input_buffer[k] = input_buffer[k].index_select(0, new_order)
             self._set_input_buffer(incremental_state, input_buffer)
     def _get_input_buffer(self, incremental_state):

...
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from fairseq import utils

@@ -134,13 +135,14 @@ class TransformerDecoderLayer(nn.Module):
     def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
         super().__init__()
         self.embed_dim = args.decoder_embed_dim
+        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
         self.self_attn = MultiheadAttention(
             embed_dim=self.embed_dim,
             num_heads=args.decoder_attention_heads,
             dropout=args.attention_dropout,
             add_bias_kv=add_bias_kv,
             add_zero_attn=add_zero_attn,
-            self_attention=True
+            self_attention=not self.cross_self_attention,
         )
         self.dropout = args.dropout
         self.activation_fn = utils.get_activation_fn(

@@ -208,13 +210,27 @@
         if prev_self_attn_state is not None:
             if incremental_state is None:
                 incremental_state = {}
-            prev_key, prev_value = prev_self_attn_state
+            prev_key, prev_value = prev_self_attn_state[:2]
             saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+            if len(prev_self_attn_state) >= 3:
+                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
             self.self_attn._set_input_buffer(incremental_state, saved_state)
+        if self.cross_self_attention and not (incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)):
+            if self_attn_mask is not None:
+                self_attn_mask = torch.cat((x.new(x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1)
+            if self_attn_padding_mask is not None:
+                if encoder_padding_mask is None:
+                    encoder_padding_mask = self_attn_padding_mask.new(encoder_out.size(1), encoder_out.size(0)).zero_()
+                self_attn_padding_mask = torch.cat((encoder_padding_mask, self_attn_padding_mask), dim=1)
+            y = torch.cat((encoder_out, x), dim=0)
+        else:
+            y = x
         x, attn = self.self_attn(
             query=x,
-            key=x,
-            value=x,
+            key=y,
+            value=y,
             key_padding_mask=self_attn_padding_mask,
             incremental_state=incremental_state,
             need_weights=False,

@@ -230,9 +246,12 @@ class TransformerDecoderLayer(nn.Module):
             if prev_attn_state is not None:
                 if incremental_state is None:
                     incremental_state = {}
-                prev_key, prev_value = prev_attn_state
+                prev_key, prev_value = prev_attn_state[:2]
                 saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+                if len(prev_attn_state) >= 3:
+                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                 self.encoder_attn._set_input_buffer(incremental_state, saved_state)
             x, attn = self.encoder_attn(
                 query=x,
                 key=encoder_out,

@@ -256,7 +275,10 @@ class TransformerDecoderLayer(nn.Module):
         x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
         if self.onnx_trace and incremental_state is not None:
             saved_state = self.self_attn._get_input_buffer(incremental_state)
-            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
+            if self_attn_padding_mask is not None:
+                self_attn_state = saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"]
+            else:
+                self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
             return x, attn, self_attn_state
         return x, attn

...
@@ -154,6 +154,23 @@ class TestTranslation(unittest.TestCase):
             ], run_validation=True)
             generate_main(data_dir)
+    def test_transformer_cross_self_attention(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_transformer_cross_self_attention') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(data_dir, 'transformer_iwslt_de_en', [
+                    '--encoder-layers', '2',
+                    '--decoder-layers', '2',
+                    '--encoder-embed-dim', '8',
+                    '--decoder-embed-dim', '8',
+                    '--decoder-embed-dim', '8',
+                    '--no-cross-attention',
+                    '--cross-self-attention',
+                    '--layer-wise-attention',
+                ], run_validation=True)
+                generate_main(data_dir, extra_flags=[])
     def test_lightconv(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_lightconv') as data_dir:

@@ -543,6 +560,10 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'
 def generate_main(data_dir, extra_flags=None):
+    if extra_flags is None:
+        extra_flags = [
+            '--print-alignment',
+        ]
     generate_parser = options.get_generation_parser()
     generate_args = options.parse_args_and_arch(
         generate_parser,

@@ -554,7 +575,6 @@
             '--max-len-b', '5',
             '--gen-subset', 'valid',
             '--no-progress-bar',
-            '--print-alignment',
         ] + (extra_flags or []),
     )

...