Commit c07362c6 authored by Halil Akin, committed by Facebook Github Bot

Convert matmuls to quantizable nn.Linear modules (#1304)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1304

Pull Request resolved: https://github.com/pytorch/translate/pull/657

Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1065

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/889

We are converting matmuls to quantizable nn.Linear modules in this diff. As a first step, let's profile the model after the diff to see how the low-level operations change.
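For background: PyTorch's post-training dynamic quantization swaps whole modules by type, so projections expressed as bare F.linear/torch.matmul calls are invisible to it, while nn.Linear submodules are eligible for replacement. A minimal sketch of that workflow (the TinyProj wrapper is hypothetical, purely for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical toy module standing in for one attention projection.
class TinyProj(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)  # a module the quantization API can swap

    def forward(self, x):
        return self.q_proj(x)

model = TinyProj()
# Dynamic quantization replaces every nn.Linear with a quantized equivalent;
# a bare F.linear(x, weight, bias) call would have been left untouched.
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
print(qmodel)
```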

Reviewed By: jmp84, edunov, lly-zero-one, jhcross

Differential Revision: D17964796

fbshipit-source-id: 3ddd3ff81fa1ea5864dded98e993f4fe3b71fe5e
parent fdf4c3e9
@@ -39,14 +39,9 @@ class MultiheadAttention(nn.Module):
         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                              'value to be of the same size'
-            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
-            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
-            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
-        if bias:
-            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
-        else:
-            self.register_parameter('in_proj_bias', None)
+        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
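Note that the replacement preserves parameter shapes: nn.Linear(in_features, out_features) stores its weight as (out_features, in_features), which matches the old Parameter(torch.Tensor(embed_dim, self.kdim)). A standalone check with made-up sizes:

```python
import torch.nn as nn

embed_dim, kdim = 16, 12  # illustrative sizes
k_proj = nn.Linear(kdim, embed_dim)
# nn.Linear stores weight as (out_features, in_features) = (embed_dim, kdim),
# the same shape as the removed Parameter(torch.Tensor(embed_dim, self.kdim)).
assert k_proj.weight.shape == (embed_dim, kdim)
```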
@@ -71,25 +66,30 @@ class MultiheadAttention(nn.Module):
     @property
     def in_proj_weight(self):
         # TODO: Remove this backward compatibility code (in_proj_weight)
-        return torch.cat((self.q_proj_weight, self.k_proj_weight, self.v_proj_weight))
+        return torch.cat((self.q_proj.weight, self.k_proj.weight, self.v_proj.weight))
+
+    @property
+    def in_proj_bias(self):
+        # TODO: Remove this backward compatibility code (in_proj_bias)
+        return torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias))

     def prepare_for_onnx_export_(self):
         self.onnx_trace = True

     def reset_parameters(self):
         if self.qkv_same_dim:
-            nn.init.xavier_uniform_(self.k_proj_weight, gain=1/math.sqrt(2))
-            nn.init.xavier_uniform_(self.v_proj_weight, gain=1/math.sqrt(2))
-            nn.init.xavier_uniform_(self.q_proj_weight, gain=1/math.sqrt(2))
+            # Empirically observed the convergence to be much better with
+            # the scaled initialization
+            nn.init.xavier_uniform_(self.k_proj.weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.v_proj.weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.q_proj.weight, gain=1/math.sqrt(2))
         else:
-            nn.init.xavier_uniform_(self.k_proj_weight)
-            nn.init.xavier_uniform_(self.v_proj_weight)
-            nn.init.xavier_uniform_(self.q_proj_weight)
+            nn.init.xavier_uniform_(self.k_proj.weight)
+            nn.init.xavier_uniform_(self.v_proj.weight)
+            nn.init.xavier_uniform_(self.q_proj.weight)

         nn.init.xavier_uniform_(self.out_proj.weight)
-        if self.in_proj_bias is not None:
-            nn.init.constant_(self.in_proj_bias, 0.)
-            nn.init.constant_(self.out_proj.bias, 0.)
+        nn.init.constant_(self.out_proj.bias, 0.)
         if self.bias_k is not None:
             nn.init.xavier_normal_(self.bias_k)
         if self.bias_v is not None:
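The two compatibility properties in the hunk above keep old call sites working: concatenating the per-projection weights along dim 0 rebuilds the combined (3 * embed_dim, embed_dim) tensor that in_proj_weight used to hold. A standalone sanity check, assuming equal q/k/v dimensions:

```python
import torch
import torch.nn as nn

embed_dim = 16  # illustrative size
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

# Same layout the removed fused parameters had: q rows first, then k, then v.
in_proj_weight = torch.cat((q_proj.weight, k_proj.weight, v_proj.weight))
in_proj_bias = torch.cat((q_proj.bias, k_proj.bias, v_proj.bias))
assert in_proj_weight.shape == (3 * embed_dim, embed_dim)
assert in_proj_bias.shape == (3 * embed_dim,)
```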
@@ -139,9 +139,9 @@ class MultiheadAttention(nn.Module):
                 self.out_proj.weight, self.out_proj.bias,
                 self.training, key_padding_mask, need_weights,
                 attn_mask, use_separate_proj_weight=True,
-                q_proj_weight=self.q_proj_weight,
-                k_proj_weight=self.k_proj_weight,
-                v_proj_weight=self.v_proj_weight)
+                q_proj_weight=self.q_proj.weight,
+                k_proj_weight=self.k_proj.weight,
+                v_proj_weight=self.v_proj.weight)

         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
@@ -155,23 +155,23 @@ class MultiheadAttention(nn.Module):
             saved_state = None

         if self.self_attention:
-            q = self.in_proj_q(query)
-            k = self.in_proj_k(query)
-            v = self.in_proj_v(query)
+            q = self.q_proj(query)
+            k = self.k_proj(query)
+            v = self.v_proj(query)
         elif self.encoder_decoder_attention:
             # encoder-decoder attention
-            q = self.in_proj_q(query)
+            q = self.q_proj(query)
             if key is None:
                 assert value is None
                 k = v = None
             else:
-                k = self.in_proj_k(key)
-                v = self.in_proj_v(key)
+                k = self.k_proj(key)
+                v = self.v_proj(key)
         else:
-            q = self.in_proj_q(query)
-            k = self.in_proj_k(key)
-            v = self.in_proj_v(value)
+            q = self.q_proj(query)
+            k = self.k_proj(key)
+            v = self.v_proj(value)

         q *= self.scaling

         if self.bias_k is not None:
@@ -284,26 +284,6 @@ class MultiheadAttention(nn.Module):
         return attn, attn_weights

-    def in_proj_q(self, query):
-        bias = self.in_proj_bias
-        if bias is not None:
-            bias = bias[:self.embed_dim]
-        return F.linear(query, self.q_proj_weight, bias)
-
-    def in_proj_k(self, key):
-        weight = self.k_proj_weight
-        bias = self.in_proj_bias
-        if bias is not None:
-            bias = bias[self.embed_dim:2 * self.embed_dim]
-        return F.linear(key, weight, bias)
-
-    def in_proj_v(self, value):
-        weight = self.v_proj_weight
-        bias = self.in_proj_bias
-        if bias is not None:
-            bias = bias[2 * self.embed_dim:]
-        return F.linear(value, weight, bias)
-
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
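The deleted in_proj_* helpers were just F.linear over a projection weight plus a slice of the shared bias, so calling the corresponding nn.Linear module computes the same thing once the split bias holds the matching slice. A small equivalence check under that assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim = 16  # illustrative size
q_proj_weight = torch.randn(embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)

# New-style module, loaded with the old q slices of weight and bias.
q_proj = nn.Linear(embed_dim, embed_dim)
with torch.no_grad():
    q_proj.weight.copy_(q_proj_weight)
    q_proj.bias.copy_(in_proj_bias[:embed_dim])

query = torch.randn(4, embed_dim)
old = F.linear(query, q_proj_weight, in_proj_bias[:embed_dim])  # what in_proj_q did
new = q_proj(query)
assert torch.allclose(old, new)
```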
@@ -341,12 +321,21 @@ class MultiheadAttention(nn.Module):
             if k.endswith(prefix + 'in_proj_weight'):
                 # in_proj_weight used to be q + k + v with same dimensions
                 dim = int(state_dict[k].shape[0] / 3)
-                items_to_add[prefix + 'q_proj_weight'] = state_dict[k][:dim]
-                items_to_add[prefix + 'k_proj_weight'] = state_dict[k][dim:2*dim]
-                items_to_add[prefix + 'v_proj_weight'] = state_dict[k][2*dim:]
+                items_to_add[prefix + 'q_proj.weight'] = state_dict[k][:dim]
+                items_to_add[prefix + 'k_proj.weight'] = state_dict[k][dim:2*dim]
+                items_to_add[prefix + 'v_proj.weight'] = state_dict[k][2*dim:]

                 keys_to_remove.append(k)
+
+                k_bias = prefix + 'in_proj_bias'
+                if k_bias in state_dict.keys():
+                    dim = int(state_dict[k].shape[0] / 3)
+                    items_to_add[prefix + 'q_proj.bias'] = state_dict[k_bias][:dim]
+                    items_to_add[prefix + 'k_proj.bias'] = state_dict[k_bias][dim:2*dim]
+                    items_to_add[prefix + 'v_proj.bias'] = state_dict[k_bias][2*dim:]
+
+                    keys_to_remove.append(prefix + 'in_proj_bias')

         for k in keys_to_remove:
             del state_dict[k]
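The upgrade path above splits a legacy checkpoint's fused (3 * embed_dim, embed_dim) weight, and (3 * embed_dim,) bias when present, into per-projection entries: q rows first, then k, then v. A stripped-down sketch of the same slicing on a toy state dict (key names here are made up):

```python
import torch

embed_dim = 16  # illustrative size
state_dict = {
    'attn.in_proj_weight': torch.randn(3 * embed_dim, embed_dim),
    'attn.in_proj_bias': torch.randn(3 * embed_dim),
}

# Slice the fused tensors into q/k/v entries, then drop the legacy keys.
dim = int(state_dict['attn.in_proj_weight'].shape[0] / 3)
for i, name in enumerate(('q_proj', 'k_proj', 'v_proj')):
    state_dict[f'attn.{name}.weight'] = state_dict['attn.in_proj_weight'][i * dim:(i + 1) * dim]
    state_dict[f'attn.{name}.bias'] = state_dict['attn.in_proj_bias'][i * dim:(i + 1) * dim]
del state_dict['attn.in_proj_weight'], state_dict['attn.in_proj_bias']

assert state_dict['attn.q_proj.weight'].shape == (embed_dim, embed_dim)
assert state_dict['attn.v_proj.bias'].shape == (embed_dim,)
```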