Commit 2d1da45c authored by Grigory Sizov, committed by Facebook GitHub Bot

Use BetterTransformer in WavLM Self-Attention (#2842)

Summary:
Closes T137506059

Replaces functional multi-head attention in `WavLMSelfAttention` with the `torch.nn.MultiheadAttention` module. The reason is that the latter uses a native CPU/CUDA implementation ([BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/)) under certain conditions and can achieve a significant speedup. It also simplifies the code in `WavLMSelfAttention`.

Note: the definition of the `bias` parameter in `WavLMSelfAttention.forward` has changed slightly, because `torch.nn.MultiheadAttention` has no parameters controlling the presence of bias for the `k`, `v`, and `q` projections independently. WavLM only uses `bias=True`, so this has no effect on users of WavLM or on the tests.
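
For context, a minimal self-contained sketch (illustrative sizes and variable names, not code from this diff) of the module-based call pattern that can be eligible for the native fastpath at inference time:

```python
import torch

# Illustrative sizes only; WavLM uses its own embed_dim / num_heads per configuration.
embed_dim, num_heads, batch_size, seq_len = 64, 4, 2, 16
attention = torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, batch_first=True)
attention.eval()

query = torch.randn(batch_size, seq_len, embed_dim)
with torch.inference_mode():
    # Self-attention: the same tensor is passed as query, key, and value.
    # need_weights=False is one of the conditions for the fastpath to be used.
    attn_output, _ = attention(query, query, query, need_weights=False)
print(attn_output.shape)  # torch.Size([2, 16, 64])
```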

Pull Request resolved: https://github.com/pytorch/audio/pull/2842

Reviewed By: nateanl

Differential Revision: D41186166

Pulled By: sgrigory

fbshipit-source-id: e791c68106ad89f96c1abf046de699cb8ec7b595
parent d73f4688
"""Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format.
"""
import logging
from typing import Any, Dict
import torch
from torch.nn import Module
from ..model import wav2vec2_model, Wav2Vec2Model, wavlm_model
@@ -73,28 +75,32 @@ def _build(config, original):
    imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
    encoder_state_dict = wav2vec2.encoder.state_dict()
    if is_wavlm:  # Rename parameters of linear transformations for compatibility with the HF model
        encoder_state_dict = {rename_wavlm_key(x): encoder_state_dict[x] for x in encoder_state_dict.keys()}
        transform_wavlm_encoder_state(encoder_state_dict, config["encoder_num_layers"])
    imported.encoder.transformer.load_state_dict(encoder_state_dict)
    if is_for_ctc:
        imported.aux.load_state_dict(original.lm_head.state_dict())
    return imported
def rename_wavlm_key(key):
    """Rename weights and biases of linear transformations, since we define them directly in WavLMSelfAttention,
    as opposed to nesting them in Linear modules.
    """
    return (
        key.replace("k_proj.weight", "k_proj_weight")
        .replace("k_proj.bias", "k_proj_bias")
        .replace("q_proj.weight", "q_proj_weight")
        .replace("q_proj.bias", "q_proj_bias")
        .replace("v_proj.weight", "v_proj_weight")
        .replace("v_proj.bias", "v_proj_bias")
        .replace("out_proj.weight", "out_proj_weight")
        .replace("out_proj.bias", "out_proj_bias")
    )


def transform_wavlm_encoder_state(state: Dict[str, Any], encoder_num_layers: int):
    """Converts WavLM encoder state from HuggingFace format. In particular, concatenates linear projection weights and
    biases to align with the structure of ``torch.nn.MultiheadAttention``.
    """
    for i in range(encoder_num_layers):
        q_proj_bias = state.pop(f"layers.{i}.attention.q_proj.bias")
        k_proj_bias = state.pop(f"layers.{i}.attention.k_proj.bias")
        v_proj_bias = state.pop(f"layers.{i}.attention.v_proj.bias")
        q_proj_weight = state.pop(f"layers.{i}.attention.q_proj.weight")
        k_proj_weight = state.pop(f"layers.{i}.attention.k_proj.weight")
        v_proj_weight = state.pop(f"layers.{i}.attention.v_proj.weight")
        state[f"layers.{i}.attention.attention.in_proj_bias"] = torch.cat((q_proj_bias, k_proj_bias, v_proj_bias))
        state[f"layers.{i}.attention.attention.in_proj_weight"] = torch.cat(
            (q_proj_weight, k_proj_weight, v_proj_weight)
        )
        state[f"layers.{i}.attention.attention.out_proj.weight"] = state.pop(f"layers.{i}.attention.out_proj.weight")
        state[f"layers.{i}.attention.attention.out_proj.bias"] = state.pop(f"layers.{i}.attention.out_proj.bias")
def import_huggingface_model(original: Module) -> Wav2Vec2Model:
"""Builds :class:`Wav2Vec2Model` from the corresponding model object of
@@ -27,18 +27,19 @@ from typing import Optional, Tuple
import torch
from torch import nn, Tensor
from torch.nn import functional as F
class WavLMSelfAttention(nn.Module):
"""Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`.
Wraps around ``torch.nn.MultiheadAttention``, creating relaive position embeddings and passing them to multi-headed
attention as a mask.
Source: https://github.com/microsoft/unilm/blob/2d8302f09c99bca2b82e6e868d81d4281cceebc8/wavlm/modules.py#L303-L763
Args:
embed_dim (int): Total dimension of the model.
num_heads (int): The number of heads.
dropout (float, optional): Dropout probability on attn_output_weights. (Default: to ``0.0``)
bias (bool, optional): If ``True``, add bias to projections for queries and values. (Default: ``True``)
bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``)
has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding.
Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``)
num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``)
@@ -59,10 +60,7 @@ class WavLMSelfAttention(nn.Module):
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_module = nn.Dropout(dropout)
        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance
@@ -75,20 +73,7 @@ class WavLMSelfAttention(nn.Module):
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # Define parameters of the linear transformations. We don't use Linear to avoid problems with quantization.
        # See also https://github.com/pytorch/audio/pull/2822#discussion_r1014431878
        self.q_proj_weight, self.k_proj_weight, self.v_proj_weight, self.out_proj_weight = [
            nn.Parameter(torch.zeros((embed_dim, embed_dim))) for _ in range(4)
        ]
        self.k_proj_bias = nn.Parameter(torch.zeros(embed_dim))
        if bias:
            self.v_proj_bias, self.q_proj_bias, self.out_proj_bias = [
                nn.Parameter(torch.zeros((embed_dim))) for _ in range(3)
            ]
        else:
            self.register_parameter("v_proj_bias", None)
            self.register_parameter("q_proj_bias", None)
            self.register_parameter("out_proj_bias", None)
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True)
        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
@@ -197,33 +182,7 @@ class WavLMSelfAttention(nn.Module):
            attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len))
        bias_k = bias_v = None
        add_zero_attn = False
        # multi_head_attention_forward expects query shape (seq_len, batch_size, embed_dim)
        query = query.transpose(0, 1)
        concat_bias = torch.cat((self.q_proj_bias, self.k_proj_bias, self.v_proj_bias))
        attn_output, _ = F.multi_head_attention_forward(
            query,
            query,
            query,
            self.embed_dim,
            self.num_heads,
            torch.empty([0]),
            concat_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            self.dropout_module.p,
            self.out_proj_weight,
            self.out_proj_bias,
            self.training,
            key_padding_mask,
            need_weights=False,
            attn_mask=attn_mask_rel_pos,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj_weight,
            k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight,
        )
        attn_output, _ = self.attention(
            query, query, query, key_padding_mask=key_padding_mask, attn_mask=attn_mask_rel_pos, need_weights=False
        )
        attn_output = attn_output.transpose(0, 1)  # Convert back to batch-first
        return attn_output, position_bias
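

# Illustrative usage sketch (hypothetical helper and sizes, not part of the module): the relative
# position bias computed above is handed to ``torch.nn.MultiheadAttention`` as a float ``attn_mask``
# of shape (batch_size * num_heads, seq_len, seq_len), which is added to the attention logits.
def _sketch_rel_pos_bias_as_attn_mask() -> None:
    embed_dim, num_heads, batch_size, seq_len = 16, 4, 2, 10
    mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    x = torch.randn(batch_size, seq_len, embed_dim)
    # A float 3D mask acts as an additive bias on the attention scores for each head.
    rel_pos_bias = torch.randn(batch_size * num_heads, seq_len, seq_len)
    out, _ = mha(x, x, x, attn_mask=rel_pos_bias, need_weights=False)
    assert out.shape == (batch_size, seq_len, embed_dim)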