Commit b2374e52 authored by Alexei Baevski, committed by Myle Ott

caching v3 (cache keys, values, process only last time step) (#241)

- process only last time step during generation
- cache keys and values
- don't apply masking during generation
parent 81b47e7e
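
The diff below threads an incremental_state dictionary through the decoder and attention modules so that, at generation time, each forward pass embeds only the newest target token and reuses the keys and values computed at earlier steps. The following toy sketch (not fairseq's API; every name in it is illustrative) shows the core idea: a per-module cache that grows by one time step per call, which also makes the future-timestep mask unnecessary because only past steps are ever present in the cache.

import torch

def attend_incrementally(q_step, k_step, v_step, cache):
    """q_step/k_step/v_step: (1, batch, dim) tensors for the newest time step."""
    if 'k' in cache:
        k = torch.cat((cache['k'], k_step), dim=0)   # reuse cached keys
        v = torch.cat((cache['v'], v_step), dim=0)   # reuse cached values
    else:
        k, v = k_step, v_step
    cache['k'], cache['v'] = k, v                    # store for the next step
    # attend from the single new query over every step cached so far;
    # no future mask is needed because the cache only holds past steps
    scores = torch.einsum('qbd,kbd->bqk', q_step, k) / k.size(-1) ** 0.5
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum('bqk,kbd->qbd', probs, v)

cache = {}
for t in range(3):
    step = torch.randn(1, 2, 8)                      # (time=1, batch=2, dim=8)
    out = attend_incrementally(step, step, step, cache)
    print(t, out.shape, cache['k'].shape)            # cache grows by one step
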
@@ -15,9 +15,10 @@ from fairseq.modules import (
     LayerNorm, LearnedPositionalEmbedding, MultiheadAttention,
     SinusoidalPositionalEmbedding,
 )
+from fairseq import utils
 from . import (
-    FairseqDecoder, FairseqEncoder, FairseqModel,
+    FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
     register_model, register_model_architecture,
 )
@@ -159,7 +160,7 @@ class TransformerEncoder(FairseqEncoder):
         return state_dict


-class TransformerDecoder(FairseqDecoder):
+class TransformerDecoder(FairseqIncrementalDecoder):
     """Transformer decoder."""

     def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
@@ -195,9 +196,16 @@ class TransformerDecoder(FairseqDecoder):
             elif name.endswith('bias'):
                 p.data.zero_()

-    def forward(self, prev_output_tokens, encoder_out):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         # embed positions
-        positions = self.embed_positions(prev_output_tokens)
+        positions = self.embed_positions(
+            prev_output_tokens,
+            incremental_state=incremental_state,
+        )
+
+        if incremental_state is not None:
+            prev_output_tokens = prev_output_tokens[:, -1:]
+            positions = positions[:, -1:]

         # embed tokens and positions
         x = self.embed_scale * self.embed_tokens(prev_output_tokens)
@@ -209,7 +217,12 @@ class TransformerDecoder(FairseqDecoder):

         # decoder layers
         for layer in self.layers:
-            x, attn = layer(x, encoder_out['encoder_out'], encoder_out['encoder_padding_mask'])
+            x, attn = layer(
+                x,
+                encoder_out['encoder_out'],
+                encoder_out['encoder_padding_mask'],
+                incremental_state,
+            )

         # T x B x C -> B x T x C
         x = x.transpose(0, 1)
@@ -222,10 +235,6 @@ class TransformerDecoder(FairseqDecoder):

         return x, attn

-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        super().reorder_incremental_state(new_order)
-
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
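
With TransformerDecoder now a FairseqIncrementalDecoder and its forward accepting incremental_state, a generation loop can call the decoder once per output token instead of re-running it over the whole prefix. A hedged sketch of such a driver follows (fairseq's real driver is its sequence generator; model, bos, eos, and the greedy argmax are assumptions made purely for illustration):

import torch

def greedy_decode(model, src_tokens, src_lengths, bos, eos, max_len=50):
    encoder_out = model.encoder(src_tokens, src_lengths)
    incremental_state = {}                    # shared cache for all attention layers
    tokens = src_tokens.new_full((src_tokens.size(0), 1), bos)
    for _ in range(max_len):
        # with incremental_state set, the decoder embeds and processes
        # only the last time step of `tokens`
        logits, _ = model.decoder(tokens, encoder_out, incremental_state)
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if (next_tok == eos).all():
            break
    return tokens
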
@@ -310,17 +319,34 @@ class TransformerDecoderLayer(nn.Module):
         self.fc2 = nn.Linear(args.decoder_ffn_embed_dim, self.embed_dim)
         self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])

-    def forward(self, x, encoder_out, encoder_padding_mask):
+    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
         residual = x
         x = self.maybe_layer_norm(0, x, before=True)
-        x, _ = self.self_attn(query=x, key=x, value=x, mask_future_timesteps=True)
+        x, _ = self.self_attn(
+            query=x,
+            key=x,
+            value=x,
+            mask_future_timesteps=True,
+            incremental_state=incremental_state,
+            need_weights=False,
+        )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.maybe_layer_norm(0, x, after=True)

         residual = x
         x = self.maybe_layer_norm(1, x, before=True)
-        x, attn = self.encoder_attn(query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask)
+        x, attn = self.encoder_attn(
+            query=x,
+            key=encoder_out,
+            value=encoder_out,
+            key_padding_mask=encoder_padding_mask,
+            incremental_state=incremental_state,
+            static_kv=True,
+        )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.maybe_layer_norm(1, x, after=True)
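
Note the two different cache behaviours in the layer above: self-attention appends the newest key/value to the cache at every step, while encoder-decoder attention passes static_kv=True because the encoder's keys and values never change during generation and can be projected once and reused. A simplified sketch of that distinction (an assumption-level illustration, not the MultiheadAttention internals):

import torch

def cached_kv(cache, name, new_k, new_v, static_kv):
    if name in cache:
        if static_kv:
            return cache[name]                         # encoder K/V never change
        k = torch.cat((cache[name][0], new_k), dim=0)  # grow decoder-side cache
        v = torch.cat((cache[name][1], new_v), dim=0)
    else:
        k, v = new_k, new_v
    cache[name] = (k, v)
    return k, v

cache = {}
enc_k = enc_v = torch.randn(7, 2, 8)                   # projected once, reused
for t in range(3):
    step_k = step_v = torch.randn(1, 2, 8)
    k_self, _ = cached_kv(cache, 'self_attn', step_k, step_v, static_kv=False)
    k_enc, _ = cached_kv(cache, 'encoder_attn', enc_k, enc_v, static_kv=True)
    print(t, k_self.size(0), k_enc.size(0))            # grows 1, 2, 3 vs. fixed 7
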
@@ -46,7 +46,8 @@ class MultiheadAttention(nn.Module):
             self.in_proj_bias.data.zero_()

     def forward(self, query, key, value, mask_future_timesteps=False,
-                key_padding_mask=None):
+                key_padding_mask=None, incremental_state=None,
+                need_weights=True, static_kv=False):
         """Input shape: Time x Batch x Channel

         Self-attention can be implemented by passing in the same arguments for
@@ -55,22 +56,38 @@ class MultiheadAttention(nn.Module):
         the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
         batch x src_len, where padding elements are indicated by 1s.
         """
-        src_len, bsz, embed_dim = key.size()
-        tgt_len = query.size(0)
+        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
+        kv_same = key.data_ptr() == value.data_ptr()
+
+        tgt_len, bsz, embed_dim = query.size()
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
         assert key.size() == value.size()

-        if key_padding_mask is not None:
-            assert key_padding_mask.size(0) == bsz
-            assert key_padding_mask.size(1) == src_len
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_key' in saved_state:
+                # previous time steps are cached - no need to recompute
+                # key and value if they are static
+                if static_kv:
+                    assert kv_same and not qkv_same
+                    key = value = None
+        else:
+            saved_state = None

-        if query.data_ptr() == key.data_ptr() == value.data_ptr():
+        if qkv_same:
             # self-attention
             q, k, v = self.in_proj_qkv(query)
-        elif key.data_ptr() == value.data_ptr():
+        elif kv_same:
             # encoder-decoder attention
             q = self.in_proj_q(query)
-            k, v = self.in_proj_kv(key)
+            if key is None:
+                assert value is None
+                # this will allow us to concat it with previous value and get
+                # just get the previous value
+                k = v = q.new(0)
+            else:
+                k, v = self.in_proj_kv(key)
         else:
             q = self.in_proj_q(query)
@@ -78,17 +95,34 @@ class MultiheadAttention(nn.Module):
             v = self.in_proj_v(value)
         q *= self.scaling

+        if saved_state is not None:
+            if 'prev_key' in saved_state:
+                k = torch.cat((saved_state['prev_key'], k), dim=0)
+            if 'prev_value' in saved_state:
+                v = torch.cat((saved_state['prev_value'], v), dim=0)
+            saved_state['prev_key'] = k
+            saved_state['prev_value'] = v
+            self._set_input_buffer(incremental_state, saved_state)
+
+        src_len = k.size(0)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz
+            assert key_padding_mask.size(1) == src_len
+
         q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
         k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
         v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)

         attn_weights = torch.bmm(q, k.transpose(1, 2))
         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

-        if mask_future_timesteps:
+        # only apply masking at training time (when incremental state is None)
+        if mask_future_timesteps and incremental_state is None:
             assert query.size() == key.size(), \
                 'mask_future_timesteps only applies to self-attention'
             attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
-        if key_padding_mask is not None:
+        if key_padding_mask is not None and incremental_state is None:
             # don't attend to padding symbols
             if utils.item(key_padding_mask.max()) > 0:
                 attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
@@ -105,9 +139,12 @@ class MultiheadAttention(nn.Module):
         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         attn = self.out_proj(attn)

-        # average attention weights over heads
-        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-        attn_weights = attn_weights.sum(dim=1) / self.num_heads
+        if need_weights:
+            # average attention weights over heads
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.sum(dim=1) / self.num_heads
+        else:
+            attn_weights = None

         return attn, attn_weights
@@ -146,3 +183,27 @@ class MultiheadAttention(nn.Module):
         if self._mask.size(0) < dim:
             self._mask = torch.triu(self._mask.resize_(dim, dim).fill_(-math.inf), 1)
         return self._mask[:dim, :dim]
+
+    def reorder_incremental_state(self, incremental_state, new_order):
+        """Reorder buffered internal state (for incremental generation)."""
+        input_buffer = self._get_input_buffer(incremental_state)
+        if input_buffer is not None:
+            for k in input_buffer.keys():
+                input_buffer[k] = input_buffer[k].index_select(1, new_order)
+            self._set_input_buffer(incremental_state, input_buffer)
+
+    def _get_input_buffer(self, incremental_state):
+        return utils.get_incremental_state(
+            self,
+            incremental_state,
+            'attn_state',
+        ) or {}
+
+    def _set_input_buffer(self, incremental_state, buffer):
+        utils.set_incremental_state(
+            self,
+            incremental_state,
+            'attn_state',
+            buffer,
+        )
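
The new reorder_incremental_state hook exists because beam search reorders and prunes hypotheses between steps; the cached tensors are stored time-major (seq_len x batch x channels), so the batch dimension (dim 1) of every cached key/value must be permuted with the same index tensor. A toy illustration (shapes and index values are made up):

import torch

cache = {'prev_key': torch.arange(24, dtype=torch.float).view(2, 4, 3)}  # (T=2, B=4, C=3)
new_order = torch.tensor([2, 2, 0, 1])       # e.g. which hypotheses survive this step
for name in cache:
    cache[name] = cache[name].index_select(1, new_order)
print(cache['prev_key'].shape)               # still (2, 4, 3), batch rows re-bucketed
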