Commit c07362c6 authored by Halil Akin, committed by Facebook Github Bot

Convert matmuls to quantizable nn.Linear modules (#1304)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1304

Pull Request resolved: https://github.com/pytorch/translate/pull/657

Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1065

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/889

We are converting matmuls to quantizable nn.Linear modules in this diff. As a first step, let's profile the model after the diff to see how the low-level operations change.

Reviewed By: jmp84, edunov, lly-zero-one, jhcross

Differential Revision: D17964796

fbshipit-source-id: 3ddd3ff81fa1ea5864dded98e993f4fe3b71fe5e
parent fdf4c3e9
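For context: the reason nn.Linear matters here is that PyTorch's dynamic quantization replaces nn.Linear modules with quantized equivalents, but it cannot see a bare Parameter that is only applied through F.linear or torch.matmul. The standalone sketch below is not part of the commit; the class names and toy dimensions are invented. It only illustrates the before/after pattern and how quantize_dynamic picks up the module-based version.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MatmulProjection(nn.Module):
    # Pre-diff style: a raw weight Parameter applied with F.linear.
    # torch.quantization.quantize_dynamic leaves this untouched.
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_dim, in_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        return F.linear(x, self.weight)


class LinearProjection(nn.Module):
    # Post-diff style: the same projection expressed as nn.Linear,
    # which quantize_dynamic can swap for a dynamically quantized linear.
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, x):
        return self.proj(x)


if __name__ == '__main__':
    model = LinearProjection(16, 16)
    qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    print(qmodel)  # the inner proj is now a dynamically quantized Linear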
@@ -39,14 +39,9 @@ class MultiheadAttention(nn.Module):
         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                              'value to be of the same size'
-        self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
-        self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
-        self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
-
-        if bias:
-            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
-        else:
-            self.register_parameter('in_proj_bias', None)
+        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -71,25 +66,30 @@ class MultiheadAttention(nn.Module):
     @property
     def in_proj_weight(self):
         # TODO: Remove this backward compatibility code (in_proj_weight)
-        return torch.cat((self.q_proj_weight, self.k_proj_weight, self.v_proj_weight))
+        return torch.cat((self.q_proj.weight, self.k_proj.weight, self.v_proj.weight))
+
+    @property
+    def in_proj_bias(self):
+        # TODO: Remove this backward compatibility code (in_proj_bias)
+        return torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias))
 
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
 
     def reset_parameters(self):
         if self.qkv_same_dim:
-            nn.init.xavier_uniform_(self.k_proj_weight, gain=1/math.sqrt(2))
-            nn.init.xavier_uniform_(self.v_proj_weight, gain=1/math.sqrt(2))
-            nn.init.xavier_uniform_(self.q_proj_weight, gain=1/math.sqrt(2))
+            # Empirically observed the convergence to be much better with
+            # the scaled initialization
+            nn.init.xavier_uniform_(self.k_proj.weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.v_proj.weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.q_proj.weight, gain=1/math.sqrt(2))
         else:
-            nn.init.xavier_uniform_(self.k_proj_weight)
-            nn.init.xavier_uniform_(self.v_proj_weight)
-            nn.init.xavier_uniform_(self.q_proj_weight)
+            nn.init.xavier_uniform_(self.k_proj.weight)
+            nn.init.xavier_uniform_(self.v_proj.weight)
+            nn.init.xavier_uniform_(self.q_proj.weight)
 
         nn.init.xavier_uniform_(self.out_proj.weight)
-        if self.in_proj_bias is not None:
-            nn.init.constant_(self.in_proj_bias, 0.)
-            nn.init.constant_(self.out_proj.bias, 0.)
+        nn.init.constant_(self.out_proj.bias, 0.)
         if self.bias_k is not None:
             nn.init.xavier_normal_(self.bias_k)
         if self.bias_v is not None:
@@ -139,9 +139,9 @@ class MultiheadAttention(nn.Module):
                 self.out_proj.weight, self.out_proj.bias,
                 self.training, key_padding_mask, need_weights,
                 attn_mask, use_separate_proj_weight=True,
-                q_proj_weight=self.q_proj_weight,
-                k_proj_weight=self.k_proj_weight,
-                v_proj_weight=self.v_proj_weight)
+                q_proj_weight=self.q_proj.weight,
+                k_proj_weight=self.k_proj.weight,
+                v_proj_weight=self.v_proj.weight)
 
         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
@@ -155,23 +155,23 @@ class MultiheadAttention(nn.Module):
             saved_state = None
 
         if self.self_attention:
-            q = self.in_proj_q(query)
-            k = self.in_proj_k(query)
-            v = self.in_proj_v(query)
+            q = self.q_proj(query)
+            k = self.k_proj(query)
+            v = self.v_proj(query)
         elif self.encoder_decoder_attention:
             # encoder-decoder attention
-            q = self.in_proj_q(query)
+            q = self.q_proj(query)
             if key is None:
                 assert value is None
                 k = v = None
             else:
-                k = self.in_proj_k(key)
-                v = self.in_proj_v(key)
+                k = self.k_proj(key)
+                v = self.v_proj(key)
         else:
-            q = self.in_proj_q(query)
-            k = self.in_proj_k(key)
-            v = self.in_proj_v(value)
+            q = self.q_proj(query)
+            k = self.k_proj(key)
+            v = self.v_proj(value)
         q *= self.scaling
 
         if self.bias_k is not None:
@@ -284,26 +284,6 @@ class MultiheadAttention(nn.Module):
         return attn, attn_weights
 
-    def in_proj_q(self, query):
-        bias = self.in_proj_bias
-        if bias is not None:
-            bias = bias[:self.embed_dim]
-        return F.linear(query, self.q_proj_weight, bias)
-
-    def in_proj_k(self, key):
-        weight = self.k_proj_weight
-        bias = self.in_proj_bias
-        if bias is not None:
-            bias = bias[self.embed_dim:2 * self.embed_dim]
-        return F.linear(key, weight, bias)
-
-    def in_proj_v(self, value):
-        weight = self.v_proj_weight
-        bias = self.in_proj_bias
-        if bias is not None:
-            bias = bias[2 * self.embed_dim:]
-        return F.linear(value, weight, bias)
-
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
@@ -341,12 +321,21 @@ class MultiheadAttention(nn.Module):
             if k.endswith(prefix + 'in_proj_weight'):
                 # in_proj_weight used to be q + k + v with same dimensions
                 dim = int(state_dict[k].shape[0] / 3)
-                items_to_add[prefix + 'q_proj_weight'] = state_dict[k][:dim]
-                items_to_add[prefix + 'k_proj_weight'] = state_dict[k][dim:2*dim]
-                items_to_add[prefix + 'v_proj_weight'] = state_dict[k][2*dim:]
+                items_to_add[prefix + 'q_proj.weight'] = state_dict[k][:dim]
+                items_to_add[prefix + 'k_proj.weight'] = state_dict[k][dim:2*dim]
+                items_to_add[prefix + 'v_proj.weight'] = state_dict[k][2*dim:]
 
                 keys_to_remove.append(k)
 
+                k_bias = prefix + 'in_proj_bias'
+                if k_bias in state_dict.keys():
+                    dim = int(state_dict[k].shape[0] / 3)
+                    items_to_add[prefix + 'q_proj.bias'] = state_dict[k_bias][:dim]
+                    items_to_add[prefix + 'k_proj.bias'] = state_dict[k_bias][dim:2*dim]
+                    items_to_add[prefix + 'v_proj.bias'] = state_dict[k_bias][2*dim:]
+
+                    keys_to_remove.append(prefix + 'in_proj_bias')
+
         for k in keys_to_remove:
             del state_dict[k]
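The last hunk keeps old checkpoints loadable: models saved before this change store a single fused in_proj_weight (and possibly in_proj_bias), which the upgrade hook splits into separate q/k/v projection keys. Below is a standalone sketch of that idea with a hypothetical helper name and toy tensors; it mirrors the logic above but is not the fairseq API itself.

import torch


def split_fused_in_proj(state_dict, prefix=''):
    # Split the fused q+k+v weight/bias into per-projection nn.Linear keys,
    # mirroring the upgrade logic in the diff above.
    key_w = prefix + 'in_proj_weight'
    key_b = prefix + 'in_proj_bias'
    if key_w in state_dict:
        dim = state_dict[key_w].shape[0] // 3
        for i, name in enumerate(('q_proj', 'k_proj', 'v_proj')):
            state_dict[prefix + name + '.weight'] = state_dict[key_w][i * dim:(i + 1) * dim]
            if key_b in state_dict:
                state_dict[prefix + name + '.bias'] = state_dict[key_b][i * dim:(i + 1) * dim]
        del state_dict[key_w]
        state_dict.pop(key_b, None)
    return state_dict


# Toy example: a fused 3*embed_dim weight is split into three embed_dim slices.
old_sd = {'in_proj_weight': torch.randn(3 * 8, 8), 'in_proj_bias': torch.zeros(3 * 8)}
new_sd = split_fused_in_proj(old_sd)
print(sorted(new_sd.keys()))
# ['k_proj.bias', 'k_proj.weight', 'q_proj.bias', 'q_proj.weight', 'v_proj.bias', 'v_proj.weight']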