Commit fdf4c3e9 authored by Halil Akin, committed by Facebook Github Bot

Simplify fairseq multihead attention (#888)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/888

We want to simplify multihead attention and get rid of the dynamic in_proj_weight logic. Sending the diff early for feedback; there will be further changes as I try to fix the breaking tests.

Reviewed By: edunov

Differential Revision: D17912661

fbshipit-source-id: 0e6319fc694d8ec5187d1c2fefe5839d9d522186
parent 5b086a0c
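The gist of the change: instead of switching at construction time between a single combined in_proj_weight of shape (3 * embed_dim, embed_dim) and three separate projection matrices, the module now always stores q_proj_weight, k_proj_weight and v_proj_weight, and keeps in_proj_weight only as a backward-compatibility property. A minimal standalone sketch of why the two layouts are interchangeable (illustrative only, not part of the diff):

    import torch
    import torch.nn.functional as F

    embed_dim = 8
    # old layout: one (3 * embed_dim, embed_dim) matrix with q, k, v stacked row-wise
    in_proj_weight = torch.randn(3 * embed_dim, embed_dim)

    # new layout: three separate (embed_dim, embed_dim) matrices
    q_w, k_w, v_w = in_proj_weight.chunk(3, dim=0)

    x = torch.randn(5, 2, embed_dim)  # (seq_len, batch, embed_dim)

    # projecting with a slice of the combined weight equals projecting with the split weight
    assert torch.allclose(F.linear(x, in_proj_weight[:embed_dim]), F.linear(x, q_w))
    # concatenating the split weights recovers the combined one (what the new property does)
    assert torch.equal(torch.cat((q_w, k_w, v_w)), in_proj_weight)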
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import torch
 from torch import nn
 from torch.nn import Parameter
@@ -38,12 +39,9 @@ class MultiheadAttention(nn.Module):
         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
             'value to be of the same size'
-        if self.qkv_same_dim:
-            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
-        else:
-            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
-            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
-            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+        self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+        self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+        self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
 
         if bias:
             self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
@@ -70,12 +68,19 @@ class MultiheadAttention(nn.Module):
         else:
             self.enable_torch_version = False
 
+    @property
+    def in_proj_weight(self):
+        # TODO: Remove this backward compatibility code (in_proj_weight)
+        return torch.cat((self.q_proj_weight, self.k_proj_weight, self.v_proj_weight))
+
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
 
     def reset_parameters(self):
         if self.qkv_same_dim:
-            nn.init.xavier_uniform_(self.in_proj_weight)
+            nn.init.xavier_uniform_(self.k_proj_weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.v_proj_weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.q_proj_weight, gain=1/math.sqrt(2))
         else:
             nn.init.xavier_uniform_(self.k_proj_weight)
             nn.init.xavier_uniform_(self.v_proj_weight)
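A note on the new gain=1/math.sqrt(2): this is presumably there to keep the effective initialization scale unchanged after the split. xavier_uniform_ draws with std gain * sqrt(2 / (fan_in + fan_out)); the old combined (3 * embed_dim, embed_dim) matrix had fan_in + fan_out = 4 * embed_dim, while each new (embed_dim, embed_dim) matrix has 2 * embed_dim, and a gain of 1/sqrt(2) makes the two standard deviations coincide. A quick arithmetic check (illustrative only):

    import math

    embed_dim = 512
    # xavier_uniform_ std is gain * sqrt(2 / (fan_in + fan_out))
    std_combined = 1.0 * math.sqrt(2.0 / (embed_dim + 3 * embed_dim))           # old single (3E, E) weight
    std_split = (1 / math.sqrt(2)) * math.sqrt(2.0 / (embed_dim + embed_dim))   # new per-projection (E, E) weight
    assert abs(std_combined - std_split) < 1e-12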
@@ -126,27 +131,17 @@ class MultiheadAttention(nn.Module):
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
 
         if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
-            if self.qkv_same_dim:
-                return F.multi_head_attention_forward(query, key, value,
-                                                      self.embed_dim, self.num_heads,
-                                                      self.in_proj_weight,
-                                                      self.in_proj_bias, self.bias_k, self.bias_v,
-                                                      self.add_zero_attn, self.dropout,
-                                                      self.out_proj.weight, self.out_proj.bias,
-                                                      self.training, key_padding_mask, need_weights,
-                                                      attn_mask)
-            else:
-                return F.multi_head_attention_forward(query, key, value,
-                                                      self.embed_dim, self.num_heads,
-                                                      torch.empty([0]),
-                                                      self.in_proj_bias, self.bias_k, self.bias_v,
-                                                      self.add_zero_attn, self.dropout,
-                                                      self.out_proj.weight, self.out_proj.bias,
-                                                      self.training, key_padding_mask, need_weights,
-                                                      attn_mask, use_separate_proj_weight=True,
-                                                      q_proj_weight=self.q_proj_weight,
-                                                      k_proj_weight=self.k_proj_weight,
-                                                      v_proj_weight=self.v_proj_weight)
+            return F.multi_head_attention_forward(query, key, value,
+                                                  self.embed_dim, self.num_heads,
+                                                  torch.empty([0]),
+                                                  self.in_proj_bias, self.bias_k, self.bias_v,
+                                                  self.add_zero_attn, self.dropout,
+                                                  self.out_proj.weight, self.out_proj.bias,
+                                                  self.training, key_padding_mask, need_weights,
+                                                  attn_mask, use_separate_proj_weight=True,
+                                                  q_proj_weight=self.q_proj_weight,
+                                                  k_proj_weight=self.k_proj_weight,
+                                                  v_proj_weight=self.v_proj_weight)
 
         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
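With the branch removed, the fast PyTorch path always goes through use_separate_proj_weight=True; torch.empty([0]) only fills the positional in_proj_weight argument, which F.multi_head_attention_forward ignores in that mode. A hedged equivalence check of the two calling styles, assuming a PyTorch version whose F.multi_head_attention_forward has this signature (illustrative only):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    E, H, L, B = 16, 4, 5, 2
    x = torch.randn(L, B, E)
    in_w, in_b = torch.randn(3 * E, E), torch.randn(3 * E)
    out_w, out_b = torch.randn(E, E), torch.zeros(E)

    # old-style call with the combined projection weight
    y_combined, _ = F.multi_head_attention_forward(
        x, x, x, E, H, in_w, in_b, None, None, False, 0.0, out_w, out_b, training=False)

    # new-style call with separate q/k/v weights and an empty placeholder
    q_w, k_w, v_w = in_w.chunk(3, dim=0)
    y_split, _ = F.multi_head_attention_forward(
        x, x, x, E, H, torch.empty([0]), in_b, None, None, False, 0.0, out_w, out_b,
        training=False, use_separate_proj_weight=True,
        q_proj_weight=q_w, k_proj_weight=k_w, v_proj_weight=v_w)

    assert torch.allclose(y_combined, y_split, atol=1e-5)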
@@ -160,8 +155,9 @@ class MultiheadAttention(nn.Module):
             saved_state = None
 
         if self.self_attention:
-            # self-attention
-            q, k, v = self.in_proj_qkv(query)
+            q = self.in_proj_q(query)
+            k = self.in_proj_k(query)
+            v = self.in_proj_v(query)
         elif self.encoder_decoder_attention:
             # encoder-decoder attention
             q = self.in_proj_q(query)
@@ -288,45 +284,25 @@ class MultiheadAttention(nn.Module):
         return attn, attn_weights
 
-    def in_proj_qkv(self, query):
-        return self._in_proj(query).chunk(3, dim=-1)
-
     def in_proj_q(self, query):
-        if self.qkv_same_dim:
-            return self._in_proj(query, end=self.embed_dim)
-        else:
-            bias = self.in_proj_bias
-            if bias is not None:
-                bias = bias[:self.embed_dim]
-            return F.linear(query, self.q_proj_weight, bias)
+        bias = self.in_proj_bias
+        if bias is not None:
+            bias = bias[:self.embed_dim]
+        return F.linear(query, self.q_proj_weight, bias)
 
     def in_proj_k(self, key):
-        if self.qkv_same_dim:
-            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
-        else:
-            weight = self.k_proj_weight
-            bias = self.in_proj_bias
-            if bias is not None:
-                bias = bias[self.embed_dim:2 * self.embed_dim]
-            return F.linear(key, weight, bias)
+        weight = self.k_proj_weight
+        bias = self.in_proj_bias
+        if bias is not None:
+            bias = bias[self.embed_dim:2 * self.embed_dim]
+        return F.linear(key, weight, bias)
 
     def in_proj_v(self, value):
-        if self.qkv_same_dim:
-            return self._in_proj(value, start=2 * self.embed_dim)
-        else:
-            weight = self.v_proj_weight
-            bias = self.in_proj_bias
-            if bias is not None:
-                bias = bias[2 * self.embed_dim:]
-            return F.linear(value, weight, bias)
-
-    def _in_proj(self, input, start=0, end=None):
-        weight = self.in_proj_weight
-        bias = self.in_proj_bias
-        weight = weight[start:end, :]
-        if bias is not None:
-            bias = bias[start:end]
-        return F.linear(input, weight, bias)
+        weight = self.v_proj_weight
+        bias = self.in_proj_bias
+        if bias is not None:
+            bias = bias[2 * self.embed_dim:]
+        return F.linear(value, weight, bias)
 
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
@@ -354,3 +330,27 @@ class MultiheadAttention(nn.Module):
 
     def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
         return attn_weights
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        # TODO: Remove this backward compatibility code (in_proj_weight)
+        # here, we convert in_proj_weight to individual q,k,v weights
+        prefix = name + '.' if name != '' else ''
+        items_to_add = {}
+        keys_to_remove = []
+        for k in state_dict.keys():
+            if k.endswith(prefix + 'in_proj_weight'):
+                # in_proj_weight used to be q + k + v with same dimensions
+                dim = int(state_dict[k].shape[0] / 3)
+                items_to_add[prefix + 'q_proj_weight'] = state_dict[k][:dim]
+                items_to_add[prefix + 'k_proj_weight'] = state_dict[k][dim:2*dim]
+                items_to_add[prefix + 'v_proj_weight'] = state_dict[k][2*dim:]
+
+                keys_to_remove.append(k)
+
+        for k in keys_to_remove:
+            del state_dict[k]
+
+        for key, value in items_to_add.items():
+            state_dict[key] = value
+
+        return state_dict
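A hedged usage sketch of the new upgrade_state_dict_named hook: it rewrites a pre-refactor checkpoint entry into the three new per-projection keys. The module construction and the name 'attn' below are illustrative assumptions, not taken from the diff, and it assumes fairseq at roughly this revision is importable:

    import torch
    from fairseq.modules import MultiheadAttention

    embed_dim, num_heads = 8, 2
    attn = MultiheadAttention(embed_dim, num_heads, self_attention=True)

    # simulate an old checkpoint that still stores the combined weight
    old_style = {
        'attn.in_proj_weight': torch.randn(3 * embed_dim, embed_dim),
        'attn.in_proj_bias': torch.zeros(3 * embed_dim),
    }

    upgraded = attn.upgrade_state_dict_named(old_style, 'attn')
    print(sorted(upgraded.keys()))
    # expected: ['attn.in_proj_bias', 'attn.k_proj_weight', 'attn.q_proj_weight', 'attn.v_proj_weight']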