Commit b2374e52 authored by Alexei Baevski, committed by Myle Ott

caching v3 (cache keys, values, process only last time step) (#241)

- process only last time step during generation
- cache keys and values
- don't apply masking during generation
parent 81b47e7e
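
For context, a minimal sketch of how a decoder with this kind of incremental_state cache is typically driven during generation; greedy_generate, bos, eos and the batch-of-one shapes are illustrative assumptions, not part of this commit:

import torch

def greedy_generate(decoder, encoder_out, bos, eos, max_len=100):
    # hypothetical helper, not fairseq's actual generator
    # incremental_state is a plain dict that each attention module uses to
    # stash cached keys/values between calls
    incremental_state = {}
    tokens = torch.full((1, 1), bos, dtype=torch.long)  # batch of 1, starts with <bos>
    for _ in range(max_len):
        # when incremental_state is not None, the decoder slices
        # prev_output_tokens down to the last step internally, so the
        # full prefix can be passed in unchanged
        logits, _ = decoder(tokens, encoder_out, incremental_state=incremental_state)
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if next_tok.item() == eos:
            break
    return tokens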
......@@ -15,9 +15,10 @@ from fairseq.modules import (
LayerNorm, LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding,
)
from fairseq import utils
from . import (
FairseqDecoder, FairseqEncoder, FairseqModel,
FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
register_model, register_model_architecture,
)
......@@ -159,7 +160,7 @@ class TransformerEncoder(FairseqEncoder):
return state_dict
class TransformerDecoder(FairseqDecoder):
class TransformerDecoder(FairseqIncrementalDecoder):
"""Transformer decoder."""
def __init__(self, args, dictionary, embed_tokens):
super().__init__(dictionary)
......@@ -195,9 +196,16 @@ class TransformerDecoder(FairseqDecoder):
elif name.endswith('bias'):
p.data.zero_()
def forward(self, prev_output_tokens, encoder_out):
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
# embed positions
positions = self.embed_positions(prev_output_tokens)
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
......@@ -209,7 +217,12 @@ class TransformerDecoder(FairseqDecoder):
# decoder layers
for layer in self.layers:
x, attn = layer(x, encoder_out['encoder_out'], encoder_out['encoder_padding_mask'])
x, attn = layer(
x,
encoder_out['encoder_out'],
encoder_out['encoder_padding_mask'],
incremental_state,
)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
......@@ -222,10 +235,6 @@ class TransformerDecoder(FairseqDecoder):
return x, attn
def reorder_incremental_state(self, new_order):
"""Reorder buffered internal state (for incremental generation)."""
super().reorder_incremental_state(new_order)
def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.embed_positions.max_positions()
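
A toy illustration (not code from this commit) of the "process only last time step" behavior above: with an incremental_state, only the newest token and its position embedding are run through the layers, so each generation step costs a constant amount of decoder work while the cached keys and values stand in for the earlier steps.

import torch

prev_output_tokens = torch.tensor([[2, 15, 37, 8]])  # full prefix, batch of 1
incremental_state = {}                               # non-None => incremental mode

if incremental_state is not None:
    # same slicing as in TransformerDecoder.forward above: only the newest
    # token (here, 8) is embedded and passed to the decoder layers
    prev_output_tokens = prev_output_tokens[:, -1:]  # shape (1, 1)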
......@@ -310,17 +319,34 @@ class TransformerDecoderLayer(nn.Module):
self.fc2 = nn.Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
def forward(self, x, encoder_out, encoder_padding_mask):
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x, _ = self.self_attn(query=x, key=x, value=x, mask_future_timesteps=True)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
mask_future_timesteps=True,
incremental_state=incremental_state,
need_weights=False,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x, attn = self.encoder_attn(query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
......
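
The static_kv=True flag passed to encoder_attn above reflects that the encoder output is fixed for the whole generation; a hedged sketch of that idea follows (proj_k, proj_v and the 512-dimensional shapes are assumptions for illustration, not the module's real projections):

import torch
import torch.nn as nn

proj_k = nn.Linear(512, 512)
proj_v = nn.Linear(512, 512)
encoder_out = torch.randn(20, 1, 512)  # src_len x batch x channel, fixed per sentence

cache = {}

def cross_attention_kv(cache, encoder_out):
    if 'prev_key' not in cache:
        # first decoding step: project the encoder output once and cache it
        cache['prev_key'] = proj_k(encoder_out)
        cache['prev_value'] = proj_v(encoder_out)
    # later steps reuse the cached projections instead of recomputing them
    return cache['prev_key'], cache['prev_value']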
......@@ -46,7 +46,8 @@ class MultiheadAttention(nn.Module):
self.in_proj_bias.data.zero_()
def forward(self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None):
key_padding_mask=None, incremental_state=None,
need_weights=True, static_kv=False):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
......@@ -55,40 +56,73 @@ class MultiheadAttention(nn.Module):
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
src_len, bsz, embed_dim = key.size()
tgt_len = query.size(0)
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
kv_same = key.data_ptr() == value.data_ptr()
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if 'prev_key' in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert kv_same and not qkv_same
key = value = None
else:
saved_state = None
if query.data_ptr() == key.data_ptr() == value.data_ptr():
if qkv_same:
# self-attention
q, k, v = self.in_proj_qkv(query)
elif key.data_ptr() == value.data_ptr():
elif kv_same:
# encoder-decoder attention
q = self.in_proj_q(query)
k, v = self.in_proj_kv(key)
if key is None:
assert value is None
# this will allow us to concat it with the previous value
# and just get the previous value
k = v = q.new(0)
else:
k, v = self.in_proj_kv(key)
else:
q = self.in_proj_q(query)
k = self.in_proj_k(key)
v = self.in_proj_v(value)
q *= self.scaling
if saved_state is not None:
if 'prev_key' in saved_state:
k = torch.cat((saved_state['prev_key'], k), dim=0)
if 'prev_value' in saved_state:
v = torch.cat((saved_state['prev_value'], v), dim=0)
saved_state['prev_key'] = k
saved_state['prev_value'] = v
self._set_input_buffer(incremental_state, saved_state)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if mask_future_timesteps:
# only apply masking at training time (when incremental state is None)
if mask_future_timesteps and incremental_state is None:
assert query.size() == key.size(), \
'mask_future_timesteps only applies to self-attention'
attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
if key_padding_mask is not None:
if key_padding_mask is not None and incremental_state is None:
# don't attend to padding symbols
if utils.item(key_padding_mask.max()) > 0:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
......@@ -105,9 +139,12 @@ class MultiheadAttention(nn.Module):
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
# average attention weights over heads
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.sum(dim=1) / self.num_heads
if need_weights:
# average attention weights over heads
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.sum(dim=1) / self.num_heads
else:
attn_weights = None
return attn, attn_weights
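
For reference, the future-timestep mask that buffered_mask caches (below) is an upper-triangular matrix of -inf added to the raw attention scores; a standalone illustration, with shapes chosen only for the example:

import math
import torch

def future_mask(dim):
    # -inf strictly above the diagonal, zeros elsewhere, so softmax assigns
    # zero probability to attending beyond the current position
    return torch.triu(torch.full((dim, dim), -math.inf), diagonal=1)

scores = torch.randn(1, 4, 4)                   # (batch*heads, tgt_len, src_len)
masked = scores + future_mask(4).unsqueeze(0)   # broadcast over the first dim
probs = torch.softmax(masked, dim=-1)           # rows sum to 1 over allowed positions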
......@@ -146,3 +183,27 @@ class MultiheadAttention(nn.Module):
if self._mask.size(0) < dim:
self._mask = torch.triu(self._mask.resize_(dim, dim).fill_(-math.inf), 1)
return self._mask[:dim, :dim]
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer[k] = input_buffer[k].index_select(1, new_order)
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
return utils.get_incremental_state(
self,
incremental_state,
'attn_state',
) or {}
def _set_input_buffer(self, incremental_state, buffer):
utils.set_incremental_state(
self,
incremental_state,
'attn_state',
buffer,
)
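
Putting the buffer helpers together, a stripped-down sketch of the caching pattern (append_kv, reorder and the shapes are assumptions, not the fairseq API): new key/value projections are concatenated onto the cached ones along the time dimension, and the whole cache is reordered along the batch dimension whenever the generator reorders hypotheses, e.g. during beam search.

import torch

cache = {}  # plays the role of the 'attn_state' input buffer above

def append_kv(cache, k_new, v_new):
    # k_new/v_new: (1, batch, channel) projections for the newest time step
    if 'prev_key' in cache:
        k_new = torch.cat((cache['prev_key'], k_new), dim=0)
        v_new = torch.cat((cache['prev_value'], v_new), dim=0)
    cache['prev_key'], cache['prev_value'] = k_new, v_new
    return k_new, v_new

def reorder(cache, new_order):
    # new_order indexes the batch dimension (dim=1), mirroring
    # reorder_incremental_state above
    for name in cache:
        cache[name] = cache[name].index_select(1, new_order)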