Commit 6d2e0831 authored by Guanheng Zhang, committed by Facebook Github Bot

Integrate torch.nn and fairseq MultiheadAttention (#772)

Summary:
Integrate the torch.nn and fairseq MultiheadAttention modules. Going forward, both libraries will benefit from shared performance optimizations.

Under the following circumstances, the MultiheadAttention calculation still remains in fairseq:
1. onnx trace
2. incremental state
3. static kv

We plan to gradually migrate those capabilities to PyTorch's core library.
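Concretely, the torch.nn kernel is only used when none of those code paths are needed. The sketch below mirrors the dispatch guard added in this commit (shown in the diff at the bottom), with the module flags spelled out as plain booleans purely for illustration:

```python
def uses_torch_kernel(enable_torch_version, onnx_trace, has_incremental_state, static_kv):
    # Mirrors the guard added in this commit: fall back to the fairseq
    # implementation for ONNX tracing, incremental decoding, or static kv.
    return (enable_torch_version and not onnx_trace
            and not has_incremental_state and not static_kv)


assert uses_torch_kernel(True, False, False, False)      # plain forward pass -> torch.nn
assert not uses_torch_kernel(True, True, False, False)   # ONNX trace -> fairseq
assert not uses_torch_kernel(True, False, True, False)   # incremental decoding -> fairseq
assert not uses_torch_kernel(True, False, False, True)   # static kv -> fairseq
```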

Fairseq users can set the attribute self.enable_torch_version to force the calculation to run in either torch or fairseq. We use the following script to ensure both versions yield the same results.

------------------------------------------------------------------------------------
```
import torch
from fairseq.modules import MultiheadAttention
import time

embed_dim = 64
kv_embed_dim = 1208
num_heads = 16
src_len = 20
tgt_len = 30
bsz = 10

model = MultiheadAttention(embed_dim, num_heads, kdim=kv_embed_dim, vdim=kv_embed_dim,
                           bias=True, add_bias_kv=True, add_zero_attn=True)

# Inputs follow fairseq's (seq_len, batch, embed_dim) layout.
query = torch.rand((src_len, bsz, embed_dim))
key = torch.rand((src_len, bsz, kv_embed_dim))
value = torch.rand((src_len, bsz, kv_embed_dim))

# Additive attention mask: -inf blocks a position, 0.0 allows it.
attn_mask = torch.randint(0, 2, (src_len, src_len)).float()
attn_mask.masked_fill_(attn_mask == 0, float('-inf'))
attn_mask.masked_fill_(attn_mask > 0, float('0.0'))

# Boolean key padding mask of shape (bsz, src_len): True marks padded positions.
seq_mask = torch.randint(0, 2, (1, src_len))
key_padding_mask = seq_mask
for i in range(bsz - 1):
    key_padding_mask = torch.cat([key_padding_mask, seq_mask], dim=0)
key_padding_mask = key_padding_mask == 1

# Apply torch.nn version
model.enable_torch_version = True
torch_output, torch_weight = model(query, key, value, key_padding_mask=key_padding_mask, attn_mask=attn_mask)

# Apply fairseq version
model.enable_torch_version = False
fairseq_output, fairseq_weight = model(query, key, value, key_padding_mask=key_padding_mask, attn_mask=attn_mask)

print("torch and fairseq generate same results: outputs are same ? ",
      torch.allclose(torch_output, fairseq_output, atol=5e-6, rtol=1e-6),
      ", weights are same ? ",
      torch.allclose(torch_weight, fairseq_weight, atol=5e-6, rtol=1e-6)
)
```
------------------------------------------------------------------------------------
Expected results:
torch and fairseq generate same results: outputs are same ?  True , weights are same ?  True

------------------------------------------------------------------------------------
Similar performance is expected from both versions. Using the following setup, we obtained these initial benchmark results:

```
embed_dim = 32
kv_embed_dim = 32
num_heads = 4
src_len = 3
tgt_len = 2
bsz = 4
num_samples = 50000
```
| Implementation | CPU time (ms per iteration) | GPU time (ms per iteration) |
| --- | --- | --- |
| torch-version MultiheadAttention | 0.46589 | 0.82330 |
| fairseq-version MultiheadAttention | 0.47861 | 0.79410 |
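For reference, a timing loop along the following lines reproduces the comparison. This is only a sketch of such a harness (the use of eval mode, torch.no_grad, and the exact averaging loop are assumptions, not the script behind the numbers above), and it covers only the CPU case:

```python
import time

import torch
from fairseq.modules import MultiheadAttention

embed_dim = 32
kv_embed_dim = 32
num_heads = 4
src_len = 3
bsz = 4
num_samples = 50000

model = MultiheadAttention(embed_dim, num_heads, kdim=kv_embed_dim, vdim=kv_embed_dim,
                           bias=True, add_bias_kv=True, add_zero_attn=True)
model.eval()

query = torch.rand((src_len, bsz, embed_dim))
key = torch.rand((src_len, bsz, kv_embed_dim))
value = torch.rand((src_len, bsz, kv_embed_dim))


def time_forward(use_torch_version):
    # Toggle the dispatch flag and average the forward-pass time over num_samples runs.
    model.enable_torch_version = use_torch_version
    with torch.no_grad():
        start = time.time()
        for _ in range(num_samples):
            model(query, key, value)
        return (time.time() - start) / num_samples * 1000.0


print("torch-version MultiheadAttention cpu time: %.5f ms per iteration." % time_forward(True))
print("fairseq-version MultiheadAttention cpu time: %.5f ms per iteration." % time_forward(False))
```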
Pull Request resolved: https://github.com/pytorch/fairseq/pull/772

Reviewed By: myleott

Differential Revision: D16108450

Pulled By: zhangguanheng66

fbshipit-source-id: cd2eb5a6eeeab6c274999b7928c2af14fc211565
parent 417ecb4b
fairseq/modules/multihead_attention.py

```diff
@@ -67,6 +67,12 @@ class MultiheadAttention(nn.Module):
         self.onnx_trace = False

+        self.enable_torch_version = False
+        if hasattr(F, "multi_head_attention_forward"):
+            self.enable_torch_version = True
+        else:
+            self.enable_torch_version = False
+
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
@@ -101,6 +107,29 @@ class MultiheadAttention(nn.Module):
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]

+        if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
+            if self.qkv_same_dim:
+                return F.multi_head_attention_forward(query, key, value,
+                                                      self.embed_dim, self.num_heads,
+                                                      self.in_proj_weight,
+                                                      self.in_proj_bias, self.bias_k, self.bias_v,
+                                                      self.add_zero_attn, self.dropout,
+                                                      self.out_proj.weight, self.out_proj.bias,
+                                                      self.training, key_padding_mask, need_weights,
+                                                      attn_mask)
+            else:
+                return F.multi_head_attention_forward(query, key, value,
+                                                      self.embed_dim, self.num_heads,
+                                                      torch.empty([0]),
+                                                      self.in_proj_bias, self.bias_k, self.bias_v,
+                                                      self.add_zero_attn, self.dropout,
+                                                      self.out_proj.weight, self.out_proj.bias,
+                                                      self.training, key_padding_mask, need_weights,
+                                                      attn_mask, use_separate_proj_weight=True,
+                                                      q_proj_weight=self.q_proj_weight,
+                                                      k_proj_weight=self.k_proj_weight,
+                                                      v_proj_weight=self.v_proj_weight)
+
         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
             if 'prev_key' in saved_state:
```