Integrate torch.nn and fairseq MultiheadAttention (#772)
Summary:
Integrate the torch.nn and fairseq MultiheadAttention modules. Going forward, both libraries will benefit from shared performance optimizations.
Under the following circumstances, the MultiheadAttention computation still remains in fairseq:
1. onnx trace
2. incremental state
3. static kv
We plan to gradually migrate those capabilities to PyTorch's core library.
Fairseq users can use the attribute self.enable_torch_version to force the calculation into either the torch or the fairseq path. We use the following script to verify that both versions yield the same results.
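The routing between the two paths happens inside MultiheadAttention.forward. The snippet below is a minimal sketch of that dispatch condition, written as a standalone helper for illustration; it is not the verbatim fairseq source, and the helper name is hypothetical.
```
def uses_torch_path(module, incremental_state=None, static_kv=False):
    # Stay on the fairseq implementation whenever a fairseq-specific feature
    # (ONNX tracing, incremental decoding state, static key/value) is in play;
    # otherwise dispatch to torch.nn.functional.multi_head_attention_forward.
    return (
        getattr(module, "enable_torch_version", False)
        and not getattr(module, "onnx_trace", False)
        and incremental_state is None
        and not static_kv
    )
```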
------------------------------------------------------------------------------------
```
import torch
from fairseq.modules import MultiheadAttention
import time
embed_dim = 64
kv_embed_dim = 1208
num_heads = 16
src_len = 20
tgt_len = 30
bsz = 10
model = MultiheadAttention(embed_dim, num_heads, kdim=kv_embed_dim, vdim=kv_embed_dim,
                           bias=True, add_bias_kv=True, add_zero_attn=True)
query = torch.rand((src_len, bsz, embed_dim))
key = torch.rand((src_len, bsz, kv_embed_dim))
value = torch.rand((src_len, bsz, kv_embed_dim))
# Additive attention mask: -inf hides a position, 0.0 keeps it visible.
attn_mask = torch.randint(0, 2, (src_len, src_len)).float()
attn_mask.masked_fill_(attn_mask == 0, float('-inf'))
attn_mask.masked_fill_(attn_mask > 0, float('0.0'))
# Boolean key padding mask, replicated across the batch dimension.
seq_mask = torch.randint(0, 2, (1, src_len))
key_padding_mask = seq_mask
for i in range(bsz - 1):
    key_padding_mask = torch.cat([key_padding_mask, seq_mask], dim=0)
key_padding_mask = key_padding_mask == 1
# Apply torch.nn version
model.enable_torch_version = True
torch_output, torch_weight = model(query, key, value, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
# Apply fairseq version
model.enable_torch_version = False
fairseq_output, fairseq_weight = model(query, key, value, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
print("torch and fairseq generate same results: outputs are same ? ",
torch.allclose(torch_output, fairseq_output, atol=5e-6, rtol=1e-6),
", weights are same ? ",
torch.allclose(torch_weight, fairseq_weight, atol=5e-6, rtol=1e-6)
)
```
------------------------------------------------------------------------------------
Expected results:
torch and fairseq generate same results: outputs are same ? True , weights are same ? True
------------------------------------------------------------------------------------
Similar performance is expected from the two versions. With the following setup, we obtained initial benchmark results:
#########################
embed_dim = 32
kv_embed_dim = 32
num_heads = 4
src_len = 3
tgt_len = 2
bsz = 4
num_samples = 50000
#########################
torch-version MultiheadAttention cpu time: 0.46589 ms per iteration.
fairseq-version MultiheadAttention cpu time: 0.47861 ms per iteration.
torch-version MultiheadAttention gpu time: 0.82330 ms per iteration.
fairseq-version MultiheadAttention gpu time: 0.79410 ms per iteration.
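For reference, a minimal CPU-only timing loop in the spirit of the benchmark above might look like the following. The exact benchmarking script is not part of this PR description; the per-iteration averaging over num_samples and the eval/no-grad setup are assumptions.
```
import time
import torch
from fairseq.modules import MultiheadAttention

# Benchmark setup matching the configuration listed above (CPU only).
embed_dim, num_heads, src_len, bsz, num_samples = 32, 4, 3, 4, 50000
model = MultiheadAttention(embed_dim, num_heads).eval()
query = torch.rand((src_len, bsz, embed_dim))

with torch.no_grad():
    for use_torch in (True, False):
        model.enable_torch_version = use_torch
        start = time.time()
        for _ in range(num_samples):
            model(query, query, query)  # self-attention style call
        per_iter_ms = (time.time() - start) / num_samples * 1000
        name = "torch" if use_torch else "fairseq"
        print(f"{name}-version MultiheadAttention cpu time: {per_iter_ms:.5f} ms per iteration.")
```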
Pull Request resolved: https://github.com/pytorch/fairseq/pull/772
Reviewed By: myleott
Differential Revision: D16108450
Pulled By: zhangguanheng66
fbshipit-source-id: cd2eb5a6eeeab6c274999b7928c2af14fc211565