Commit 2d1da45c authored by Grigory Sizov, committed by Facebook GitHub Bot

Use BetterTransformer in WavLM Self-Attention (#2842)

Summary:
Closes T137506059

Replaces functional multi-head attention in `WavLMSelfAttention` with the module `torch.nn.MultiheadAttention`. The reason is that the latter uses a native CPU/CUDA implementation ([BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/)) under certain conditions and can achieve a significant speedup. It also simplifies the code in `WavLMSelfAttention`.
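For context, a minimal sketch (not the actual WavLM code; the shapes and the 768/12 configuration are made up) of how `torch.nn.MultiheadAttention` with `batch_first=True` is typically invoked so that the native fast path can kick in at inference time:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 768, 12  # hypothetical configuration
mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, batch_first=True)

# Batch-first input: (batch, seq_len, embed_dim), no transposes needed.
x = torch.randn(2, 100, embed_dim)

# The native (BetterTransformer) fast path is only taken under certain conditions,
# e.g. eval/inference mode and need_weights=False; otherwise the regular path runs.
mha.eval()
with torch.inference_mode():
    out, _ = mha(x, x, x, need_weights=False)
print(out.shape)  # torch.Size([2, 100, 768])
```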

Note: the definition of the `bias` parameter of `WavLMSelfAttention` has changed slightly, because `torch.nn.MultiheadAttention` has no parameter controlling the presence of bias for the `k`, `v`, and `q` projections independently. In WavLM we only use `bias=True`, so this has no effect on users of WavLM or on tests.
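To illustrate the note above (a sketch, not code from this PR): `torch.nn.MultiheadAttention` exposes a single `bias` flag that controls both the packed `in_proj_bias` for `q`/`k`/`v` and the output-projection bias, so per-projection bias switches are no longer available:

```python
import torch.nn as nn

embed_dim, num_heads = 16, 4  # hypothetical sizes
mha = nn.MultiheadAttention(embed_dim, num_heads, bias=True, batch_first=True)

# One packed bias for q/k/v plus the output projection bias, all governed by `bias`:
print(mha.in_proj_bias.shape)   # torch.Size([48])  -- q, k, v biases stacked
print(mha.out_proj.bias.shape)  # torch.Size([16])

# With bias=False, both biases are None instead.
no_bias = nn.MultiheadAttention(embed_dim, num_heads, bias=False, batch_first=True)
print(no_bias.in_proj_bias, no_bias.out_proj.bias)  # None None
```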

Pull Request resolved: https://github.com/pytorch/audio/pull/2842

Reviewed By: nateanl

Differential Revision: D41186166

Pulled By: sgrigory

fbshipit-source-id: e791c68106ad89f96c1abf046de699cb8ec7b595
parent d73f4688
"""Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format. """Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format.
""" """
import logging import logging
from typing import Any, Dict
import torch
from torch.nn import Module from torch.nn import Module
from ..model import wav2vec2_model, Wav2Vec2Model, wavlm_model from ..model import wav2vec2_model, Wav2Vec2Model, wavlm_model
...@@ -73,27 +75,31 @@ def _build(config, original): ...@@ -73,27 +75,31 @@ def _build(config, original):
imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict()) imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
encoder_state_dict = wav2vec2.encoder.state_dict() encoder_state_dict = wav2vec2.encoder.state_dict()
if is_wavlm: # Rename paramaters of linear transformations for compatibility with the HF model if is_wavlm: # Rename paramaters of linear transformations for compatibility with the HF model
encoder_state_dict = {rename_wavlm_key(x): encoder_state_dict[x] for x in encoder_state_dict.keys()} transform_wavlm_encoder_state(encoder_state_dict, config["encoder_num_layers"])
imported.encoder.transformer.load_state_dict(encoder_state_dict) imported.encoder.transformer.load_state_dict(encoder_state_dict)
if is_for_ctc: if is_for_ctc:
imported.aux.load_state_dict(original.lm_head.state_dict()) imported.aux.load_state_dict(original.lm_head.state_dict())
return imported return imported
def rename_wavlm_key(key): def transform_wavlm_encoder_state(state: Dict[str, Any], encoder_num_layers: int):
"""Rename weights and biases of linear transformations, since we define them directly in WavLMSelfAttention, """Converts WavLM encoder state from HuggingFace format. In particular, concatenates linear projection weights and
as opposed to nesting them in Linear modules biases to align with the structure of ``torch.nn.MultiheadAttention``.
""" """
return ( for i in range(encoder_num_layers):
key.replace("k_proj.weight", "k_proj_weight") q_proj_bias = state.pop(f"layers.{i}.attention.q_proj.bias")
.replace("k_proj.bias", "k_proj_bias") k_proj_bias = state.pop(f"layers.{i}.attention.k_proj.bias")
.replace("q_proj.weight", "q_proj_weight") v_proj_bias = state.pop(f"layers.{i}.attention.v_proj.bias")
.replace("q_proj.bias", "q_proj_bias") q_proj_weight = state.pop(f"layers.{i}.attention.q_proj.weight")
.replace("v_proj.weight", "v_proj_weight") k_proj_weight = state.pop(f"layers.{i}.attention.k_proj.weight")
.replace("v_proj.bias", "v_proj_bias") v_proj_weight = state.pop(f"layers.{i}.attention.v_proj.weight")
.replace("out_proj.weight", "out_proj_weight") state[f"layers.{i}.attention.attention.in_proj_bias"] = torch.cat((q_proj_bias, k_proj_bias, v_proj_bias))
.replace("out_proj.bias", "out_proj_bias") state[f"layers.{i}.attention.attention.in_proj_weight"] = torch.cat(
) (q_proj_weight, k_proj_weight, v_proj_weight)
)
state[f"layers.{i}.attention.attention.out_proj.weight"] = state.pop(f"layers.{i}.attention.out_proj.weight")
state[f"layers.{i}.attention.attention.out_proj.bias"] = state.pop(f"layers.{i}.attention.out_proj.bias")
def import_huggingface_model(original: Module) -> Wav2Vec2Model: def import_huggingface_model(original: Module) -> Wav2Vec2Model:
......
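For reference, this conversion runs as part of the existing import entry point; a usage sketch (assuming a `transformers` installation and the `microsoft/wavlm-base` checkpoint, neither of which is part of this diff):

```python
from transformers import WavLMModel

from torchaudio.models.wav2vec2.utils import import_huggingface_model

# Download the Hugging Face WavLM model and convert it to torchaudio's Wav2Vec2Model.
original = WavLMModel.from_pretrained("microsoft/wavlm-base")
imported = import_huggingface_model(original)
imported.eval()
```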
@@ -27,18 +27,19 @@ from typing import Optional, Tuple
 import torch
 from torch import nn, Tensor
-from torch.nn import functional as F
 
 
 class WavLMSelfAttention(nn.Module):
     """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`.
+    Wraps around ``torch.nn.MultiheadAttention``, creating relative position embeddings and passing them to
+    multi-headed attention as a mask.
     Source: https://github.com/microsoft/unilm/blob/2d8302f09c99bca2b82e6e868d81d4281cceebc8/wavlm/modules.py#L303-L763
     Args:
         embed_dim (int): Total dimension of the model.
         num_heads (int): The number of heads.
         dropout (float, optional): Dropout probability on attn_output_weights. (Default: ``0.0``)
-        bias (bool, optional): If ``True``, add bias to projections for queries and values. (Default: ``True``)
+        bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``)
         has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding.
             Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``)
         num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``)
@@ -59,10 +60,7 @@ class WavLMSelfAttention(nn.Module):
     ):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
-        self.dropout_module = nn.Dropout(dropout)
         self.has_relative_attention_bias = has_relative_attention_bias
         self.num_buckets = num_buckets
         self.max_distance = max_distance
@@ -75,20 +73,7 @@ class WavLMSelfAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
-        # Define parameters of the linear transformations. We don't use Linear to avoid problems with quantization.
-        # See also https://github.com/pytorch/audio/pull/2822#discussion_r1014431878
-        self.q_proj_weight, self.k_proj_weight, self.v_proj_weight, self.out_proj_weight = [
-            nn.Parameter(torch.zeros((embed_dim, embed_dim))) for _ in range(4)
-        ]
-        self.k_proj_bias = nn.Parameter(torch.zeros(embed_dim))
-        if bias:
-            self.v_proj_bias, self.q_proj_bias, self.out_proj_bias = [
-                nn.Parameter(torch.zeros((embed_dim))) for _ in range(3)
-            ]
-        else:
-            self.register_parameter("v_proj_bias", None)
-            self.register_parameter("q_proj_bias", None)
-            self.register_parameter("out_proj_bias", None)
+        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True)
         self.gru_rel_pos = gru_rel_pos
         if self.gru_rel_pos:
@@ -197,33 +182,7 @@ class WavLMSelfAttention(nn.Module):
             attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len))
-        bias_k = bias_v = None
-        add_zero_attn = False
-        # multi_head_attention_forward expects query shape (seq_len, batch_size, embed_dim)
-        query = query.transpose(0, 1)
-        concat_bias = torch.cat((self.q_proj_bias, self.k_proj_bias, self.v_proj_bias))
-        attn_output, _ = F.multi_head_attention_forward(
-            query,
-            query,
-            query,
-            self.embed_dim,
-            self.num_heads,
-            torch.empty([0]),
-            concat_bias,
-            bias_k,
-            bias_v,
-            add_zero_attn,
-            self.dropout_module.p,
-            self.out_proj_weight,
-            self.out_proj_bias,
-            self.training,
-            key_padding_mask,
-            need_weights=False,
-            attn_mask=attn_mask_rel_pos,
-            use_separate_proj_weight=True,
-            q_proj_weight=self.q_proj_weight,
-            k_proj_weight=self.k_proj_weight,
-            v_proj_weight=self.v_proj_weight,
-        )
-        attn_output = attn_output.transpose(0, 1)  # Convert back to batch-first
+        attn_output, _ = self.attention(
+            query, query, query, key_padding_mask=key_padding_mask, attn_mask=attn_mask_rel_pos, need_weights=False
+        )
         return attn_output, position_bias