Commit 4c6b689e authored by Halil Akin, committed by Facebook Github Bot

Remove in_proj_weight/in_proj_bias in multihead attention and fix the failing tests instead (#898)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/898

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1333

Pull Request resolved: https://github.com/fairinternal/fairspeq/pull/11

The in_proj_weight and in_proj_bias properties are not the right way to provide backward compatibility, and they cause further incompatibilities with the new Dynamic Quantization API. So let's remove them and properly fix the failing tests instead.
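For context, here is a minimal sketch (not part of this commit) of the Dynamic Quantization flow referred to above. TinyAttentionModel is a hypothetical stand-in for a model whose attention is built from separate q/k/v projection nn.Linear modules, as in fairseq's MultiheadAttention:

import torch
import torch.nn as nn

# Hypothetical stand-in for a module with separate q/k/v projections.
class TinyAttentionModel(nn.Module):
    def __init__(self, embed_dim=16):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

model = TinyAttentionModel()

# Dynamic quantization replaces nn.Linear modules with quantized versions
# whose parameters are stored in packed form; a property that torch.cat()s
# the original float weights on the fly no longer matches the quantized
# modules, which is the incompatibility the summary refers to.
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)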

Reviewed By: myleott

Differential Revision: D18264129

fbshipit-source-id: fc1838657a60d914ca83c4e0f6add5ed8206ac54
parent 99c524c5
@@ -63,14 +63,6 @@ class MultiheadAttention(nn.Module):
         else:
             self.enable_torch_version = False
 
-    @property
-    def in_proj_weight(self):
-        return torch.cat((self.q_proj.weight, self.k_proj.weight, self.v_proj.weight))
-
-    @property
-    def in_proj_bias(self):
-        return torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias))
-
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
 
@@ -132,7 +124,8 @@ class MultiheadAttention(nn.Module):
             return F.multi_head_attention_forward(query, key, value,
                                                   self.embed_dim, self.num_heads,
                                                   torch.empty([0]),
-                                                  self.in_proj_bias, self.bias_k, self.bias_v,
+                                                  torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+                                                  self.bias_k, self.bias_v,
                                                   self.add_zero_attn, self.dropout,
                                                   self.out_proj.weight, self.out_proj.bias,
                                                   self.training, key_padding_mask, need_weights,
...
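Callers that depended on the removed properties can rebuild the concatenated tensors themselves, exactly as the old properties did. A minimal sketch, assuming MultiheadAttention is importable from fairseq.modules and constructed with bias enabled (the default):

import torch
from fairseq.modules import MultiheadAttention  # import path assumed

# Build a small attention module; after this change it no longer exposes
# in_proj_weight or in_proj_bias attributes.
attn = MultiheadAttention(embed_dim=16, num_heads=4)

# Equivalent of the removed properties, reconstructed at the call site by
# concatenating the separate q/k/v projection parameters.
in_proj_weight = torch.cat((attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight))
in_proj_bias = torch.cat((attn.q_proj.bias, attn.k_proj.bias, attn.v_proj.bias))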