Use scaled_dot_product_attention in WavLM attention (#3252)
Summary: Fixes https://github.com/pytorch/audio/issues/3219. `torch.nn.MultiheadAttention` throws an error when `torch.no_grad()` and an attention mask are used together. This pull request fixes the issue by replacing the attention computation in the forward method with `torch.nn.functional.scaled_dot_product_attention`, which accepts a mask under `torch.no_grad()` (see the sketch below).

Pull Request resolved: https://github.com/pytorch/audio/pull/3252

Reviewed By: mthrok

Differential Revision: D44798634

Pulled By: nateanl

fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665
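A minimal sketch of the replacement call, assuming PyTorch >= 2.0. The tensor shapes and the causal mask here are illustrative, not WavLM's actual projection layout:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim).
batch, num_heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

# Boolean mask where True marks positions allowed to attend
# (a causal mask here, purely for illustration).
attn_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

# Unlike the nn.MultiheadAttention fast path, this works under
# torch.no_grad() with a mask supplied.
with torch.no_grad():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

print(out.shape)  # torch.Size([2, 8, 16, 64])
```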