Use BetterTransformer in WavLM Self-Attention (#2842)
Summary: Closes T137506059

Replaces functional multi-head attention in `WavLMSelfAttention` with the module `torch.nn.MultiheadAttention`. The latter uses a native CPU/CUDA implementation ([BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/)) under certain conditions and can achieve a significant speedup. It also simplifies the code in `WavLMSelfAttention`.

Note: the meaning of the `bias` parameter in `WavLMSelfAttention.forward` has changed slightly, because `torch.nn.MultiheadAttention` has no parameter controlling the presence of bias for the `k`, `v`, and `q` projections independently. WavLM only uses `bias=True`, so this has no effect on users of WavLM or on the tests.

Pull Request resolved: https://github.com/pytorch/audio/pull/2842

Reviewed By: nateanl

Differential Revision: D41186166

Pulled By: sgrigory

fbshipit-source-id: e791c68106ad89f96c1abf046de699cb8ec7b595
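For illustration, a minimal sketch of the pattern this change adopts: wrapping self-attention around the module-based `torch.nn.MultiheadAttention` rather than the functional API. The class `SimpleSelfAttention` below is hypothetical and not the actual WavLM code; it only shows how a single `bias` flag and `need_weights=False` interact with the module, since the fused (BetterTransformer) path is taken only under certain conditions such as eval/inference mode and not requesting attention weights.

```python
import torch

class SimpleSelfAttention(torch.nn.Module):
    """Hypothetical self-attention block built on torch.nn.MultiheadAttention."""

    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True):
        super().__init__()
        # `bias` toggles bias for the q/k/v and output projections together;
        # per-projection control is not exposed by nn.MultiheadAttention.
        self.attention = torch.nn.MultiheadAttention(
            embed_dim, num_heads, bias=bias, batch_first=True
        )

    def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
        # need_weights=False is one of the conditions for the fast path.
        out, _ = self.attention(
            x, x, x, key_padding_mask=key_padding_mask, need_weights=False
        )
        return out


if __name__ == "__main__":
    model = SimpleSelfAttention(embed_dim=768, num_heads=12).eval()
    x = torch.randn(2, 100, 768)  # (batch, time, feature) since batch_first=True
    with torch.inference_mode():
        y = model(x)
    print(y.shape)  # torch.Size([2, 100, 768])
```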