"vscode:/vscode.git/clone" did not exist on "12d2c7379fa4a2be2286d2afbbefb49d97374e27"
Commit 265f42b7 authored by James Cross, committed by Facebook GitHub Bot

multihead_attention: pre-transpose incremental state (#232)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/232

Though transpose operations are essentially free during PyTorch execution, they can result in costly operations when exported to Caffe2 inference nets via ONNX tracing, especially when applied repeatedly to large tensors.

For this reason, we update `MultiheadAttention` to store its incremental state with shape (bsz, num_heads, seq_len, head_dim), i.e. after transposing the projected input. This should yield non-trivially faster exported models without changing the semantics or the speed of PyTorch execution.
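As a rough illustration of that layout (a standalone sketch with made-up sizes and a hypothetical `append_to_cache` helper, not code from this diff): the cache keeps keys pre-transposed as (bsz, num_heads, seq_len, head_dim), so extending it at each decoding step only needs a view and a concatenation along the time dimension, never a transpose of the growing tensor.

import torch

def append_to_cache(prev_key, new_key, bsz, num_heads, head_dim):
    # Hypothetical helper (not from the diff): prev_key is the stored incremental
    # state, kept pre-transposed as (bsz, num_heads, prev_len, head_dim); new_key
    # is this step's projected key, already head-split and batch-first,
    # with shape (bsz * num_heads, 1, head_dim).
    prev = prev_key.view(bsz * num_heads, -1, head_dim)  # fold heads into the batch dim
    k = torch.cat((prev, new_key), dim=1)                # extend along the time axis
    new_cache = k.view(bsz, num_heads, -1, head_dim)     # store back pre-transposed
    return k, new_cache

bsz, num_heads, head_dim = 2, 4, 16
cache = torch.randn(bsz, num_heads, 3, head_dim)         # 3 cached timesteps
step_k = torch.randn(bsz * num_heads, 1, head_dim)       # the current timestep's key
k, cache = append_to_cache(cache, step_k, bsz, num_heads, head_dim)
print(k.shape)      # torch.Size([8, 4, 16])
print(cache.shape)  # torch.Size([2, 4, 4, 16])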

Reviewed By: myleott

Differential Revision: D10186506

fbshipit-source-id: 8a42712423ee767ea49ed88d2a4653f900d14fba
parent b9e29a47
@@ -108,23 +108,6 @@ class MultiheadAttention(nn.Module):
         v = self.in_proj_v(value)
         q *= self.scaling
-        if saved_state is not None:
-            if 'prev_key' in saved_state:
-                if static_kv:
-                    k = saved_state['prev_key']
-                else:
-                    k = torch.cat((saved_state['prev_key'], k), dim=0)
-            if 'prev_value' in saved_state:
-                if static_kv:
-                    v = saved_state['prev_value']
-                else:
-                    v = torch.cat((saved_state['prev_value'], v), dim=0)
-            saved_state['prev_key'] = k
-            saved_state['prev_value'] = v
-            self._set_input_buffer(incremental_state, saved_state)
         if self.bias_k is not None:
             assert self.bias_v is not None
             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
@@ -135,16 +118,37 @@ class MultiheadAttention(nn.Module):
                 key_padding_mask = torch.cat(
                     [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
-        src_len = k.size(0)
+        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        if k is not None:
+            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        if v is not None:
+            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+        if saved_state is not None:
+            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+            if 'prev_key' in saved_state:
+                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    k = prev_key
+                else:
+                    k = torch.cat((prev_key, k), dim=1)
+            if 'prev_value' in saved_state:
+                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    v = prev_value
+                else:
+                    v = torch.cat((prev_value, v), dim=1)
+            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+            self._set_input_buffer(incremental_state, saved_state)
+
+        src_len = k.size(1)
         if key_padding_mask is not None:
             assert key_padding_mask.size(0) == bsz
             assert key_padding_mask.size(1) == src_len
-        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-        k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-        v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
         if self.add_zero_attn:
             src_len += 1
             k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
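For context on the dim change in the cache concatenations above (a hedged, self-contained check with made-up sizes, not part of the diff): the old cache was extended along dim 0 of a time-major (seq_len, bsz, embed_dim) tensor, while the new one is extended along dim 1 of a head-split, batch-first (bsz * num_heads, seq_len, head_dim) tensor. Both orderings describe the same data:

import torch

bsz, num_heads, head_dim = 2, 4, 16
embed_dim = num_heads * head_dim
prev_len = 3

# Old layout: time-major projection output, extended on dim 0, reshaped afterwards.
prev_old = torch.randn(prev_len, bsz, embed_dim)
new_old = torch.randn(1, bsz, embed_dim)
k_old = torch.cat((prev_old, new_old), dim=0)                       # (prev_len + 1, bsz, embed_dim)
k_old_heads = k_old.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

# New layout: head-split and batch-first, extended on dim 1, no later transpose needed.
prev_new = prev_old.contiguous().view(prev_len, bsz * num_heads, head_dim).transpose(0, 1)
new_new = new_old.contiguous().view(1, bsz * num_heads, head_dim).transpose(0, 1)
k_new = torch.cat((prev_new, new_new), dim=1)                       # (bsz * num_heads, prev_len + 1, head_dim)

# Both paths produce the same attention-ready tensor.
print(torch.equal(k_old_heads, k_new))  # True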
@@ -213,7 +217,7 @@ class MultiheadAttention(nn.Module):
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
             for k in input_buffer.keys():
-                input_buffer[k] = input_buffer[k].index_select(1, new_order)
+                input_buffer[k] = input_buffer[k].index_select(0, new_order)
             self._set_input_buffer(incremental_state, input_buffer)

     def _get_input_buffer(self, incremental_state):
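The dim change in the reordering hunk above follows from the new storage layout: with the cache kept as (bsz, num_heads, seq_len, head_dim), batch/beam entries live along dim 0, so beam-search reordering now selects along dim 0 instead of dim 1. A small sketch with made-up sizes:

import torch

bsz, num_heads, seq_len, head_dim = 3, 4, 5, 16
prev_key = torch.randn(bsz, num_heads, seq_len, head_dim)  # one cached buffer entry

new_order = torch.tensor([2, 0, 1])                        # e.g. beams reordered by the search
reordered = prev_key.index_select(0, new_order)            # reorder along the batch dim

print(torch.equal(reordered[0], prev_key[2]))  # True: slot 0 now holds old entry 2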