Commit bd76d3d7 authored by Grigory Sizov, committed by Facebook GitHub Bot

Add WavLM model (#2822)

Summary:
Closes T136364380

Added [WavLM Model](https://github.com/microsoft/UniSpeech/tree/main/WavLM):
- Added `WavLMSelfAttention` class (from [original implementation](https://github.com/microsoft/UniSpeech/blob/2e9dde8bf815a5f5fd958e3435e5641f59f96928/WavLM/modules.py)) and adjusted existing Encoder and Transformer classes to be compatible with it
- Added factory functions `wavlm_model`, `wavlm_base`, `wavlm_large` to `models/wav2vec2/model.py` (see the usage sketch after this list)
- Added bundles for the base and large models to pipelines. **TODO**: pre-trained model weights are not yet uploaded to `download.pytorch.org`; permissions have not been granted yet.
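
A minimal usage sketch of the new factory functions, mirroring the smoke tests added in this PR (requires a torchaudio build that includes this change):

```python
import torch
from torchaudio.models import wavlm_base

model = wavlm_base(aux_num_out=32).eval()  # omit aux_num_out to build the model without the aux head
waveforms = torch.randn(3, 1024)           # (batch, num_frames), as in the smoke tests
with torch.no_grad():
    logits, lengths = model(waveforms)     # lengths is None when no `lengths` argument is passed
print(logits.shape)                        # (batch, num_output_frames, aux_num_out)
```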

## Tests
- Expanded HuggingFace integration tests to cover WavLM. For these tests, added JSON configs for the base and large models from HF ([base](https://huggingface.co/microsoft/wavlm-base/blob/main/config.json), [large](https://huggingface.co/microsoft/wavlm-large/blob/main/config.json)) to the test assets
- Expanded TorchScript and quantization tests to cover WavLM

## Comments
There are a few workarounds I had to introduce:
- Quantization tests for WavLM were breaking down at [`torch.cat`](https://github.com/pytorch/audio/pull/2822/files#diff-6f1486901c94320ec0610a460dc674638fab9d104a61564ff7b59353a8b8547cR466) ~~until I excluded the arguments of `torch.cat` from quantization [here](https://github.com/pytorch/audio/pull/2822/files#diff-6f1486901c94320ec0610a460dc674638fab9d104a61564ff7b59353a8b8547cR368-R369). I haven't found a better way to fix it, let me know if there is one~~ The reason for this seems to be that quantization replaces the `.bias` and `.weight` attributes of a `Linear` module with methods. Since we were using the weights and biases directly, the code was breaking. The final solution, suggested by nateanl, was to define the attention weights and biases directly in `WavLMSelfAttention`, skipping the `Linear` layers (see the first sketch after this list).
- ~~WavLM uses position embedding in the first layer of the encoder, but not in the subsequent ones. So the [UniSpeech](https://github.com/microsoft/UniSpeech/blob/2e9dde8bf815a5f5fd958e3435e5641f59f96928/WavLM/modules.py#L342) and [HF](https://github.com/huggingface/transformers/blob/b047472650cba259621549ac27b18fd2066ce18e/src/transformers/models/wavlm/modeling_wavlm.py#L441-L442) implementations only create this embedding module in the layers where it's used. However, we can't do this here because it breaks TorchScript. As a solution, I add a dummy `Identity` module to `WavLMSelfAttention` when the actual embedding is not needed: [here](https://github.com/pytorch/audio/pull/2822/files#diff-6f1486901c94320ec0610a460dc674638fab9d104a61564ff7b59353a8b8547cR361-R368).~~ Thanks nateanl for resolving this!
- I had to add dummy `position_bias` and `key_padding_mask` arguments to `SelfAttention.forward` to make the TorchScript tests pass. Since both `SelfAttention` and `WavLMSelfAttention` are called from `EncoderLayer`, they need to have compatible signatures. Having a variable number of arguments with `**kwargs` or checking the object class doesn't seem to work with TorchScript, so I instead made both types of attention accept `position_bias` and `key_padding_mask` arguments (see the second sketch after this list).
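
First sketch, a minimal illustration of the quantization issue and the chosen workaround (the module names below are made up, not code from this PR): accessing a `Linear`'s `.weight`/`.bias` directly breaks after dynamic quantization, while plain `nn.Parameter`s are left untouched.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.quantization as tq

class ProjWithLinear(nn.Module):
    """Accesses `.weight`/`.bias` of a Linear directly, as the attention code originally did."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # After dynamic quantization, `self.proj` becomes a quantized Linear whose
        # weight/bias are methods, so this attribute access no longer yields tensors.
        return F.linear(x, self.proj.weight, self.proj.bias)

class ProjWithParams(nn.Module):
    """Registers the weight/bias as plain Parameters, as WavLMSelfAttention now does."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj_weight = nn.Parameter(torch.zeros(dim, dim))
        self.proj_bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.proj_weight, self.proj_bias)

x = torch.randn(2, 8)
bad = tq.quantize_dynamic(ProjWithLinear(), qconfig_spec={nn.Linear}, dtype=torch.qint8)
# bad(x)  # breaks: `self.proj.weight` is now a bound method, not a Tensor
good = tq.quantize_dynamic(ProjWithParams(), qconfig_spec={nn.Linear}, dtype=torch.qint8)
print(good(x).shape)  # plain Parameters are untouched by dynamic quantization
```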
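
Second sketch, a minimal illustration of the signature alignment (again with made-up class names, not code from this PR): both attention variants expose the same `forward` signature and return type, so the encoder layer can call either one under TorchScript, which supports neither `**kwargs` nor class-based dispatch here.

```python
from typing import Optional, Tuple
import torch
from torch import nn, Tensor

class PlainAttention(nn.Module):
    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,    # unused, kept only for signature compatibility
        key_padding_mask: Optional[Tensor] = None,  # unused, kept only for signature compatibility
    ) -> Tuple[Tensor, Optional[Tensor]]:
        return x, None

class RelPosAttention(nn.Module):
    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        return x, position_bias

# Both variants script cleanly because their signatures and return types match.
torch.jit.script(PlainAttention())
torch.jit.script(RelPosAttention())
```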

Nit: do we still need to specify `__all__` if there are no wildcard imports in `__init__.py`, e.g. in `torchaudio/models/__init__.py`?

Pull Request resolved: https://github.com/pytorch/audio/pull/2822

Reviewed By: nateanl

Differential Revision: D41121855

Pulled By: sgrigory

fbshipit-source-id: 9f4f787e5810010de4e74cb704063a26c66767d7
parent 5e2507f5
......@@ -439,3 +439,19 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop
journal={arXiv preprint arXiv:1805.10190},
year={2018}
}
@article{chen2022wavlm,
title={WavLM: Large-scale self-supervised pre-training for full stack speech processing},
author={Chen, Sanyuan and Wang, Chengyi and Chen, Zhengyang and Wu, Yu and Liu, Shujie and Chen, Zhuo and Li, Jinyu and Kanda, Naoyuki and Yoshioka, Takuya and Xiao, Xiong and others},
journal={IEEE Journal of Selected Topics in Signal Processing},
volume={16},
number={6},
pages={1505--1518},
year={2022},
publisher={IEEE}
}
@inproceedings{GigaSpeech2021,
title={GigaSpeech: An Evolving, Multi-domain ASR Corpus with 10,000 Hours of Transcribed Audio},
booktitle={Proc. Interspeech 2021},
year=2021,
author={Guoguo Chen and Shuzhou Chai and Guanbo Wang and Jiayu Du and Wei-Qiang Zhang and Chao Weng and Dan Su and Daniel Povey and Jan Trmal and Junbo Zhang and Mingjie Jin and Sanjeev Khudanpur and Shinji Watanabe and Shuaijiang Zhao and Wei Zou and Xiangang Li and Xuchen Yao and Yongqing Wang and Yujun Wang and Zhao You and Zhiyong Yan}
}
{
"activation_dropout": 0.0,
"adapter_kernel_size": 3,
"adapter_stride": 2,
"add_adapter": false,
"apply_spec_augment": true,
"architectures": [
"WavLMModel"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"classifier_proj_size": 256,
"codevector_dim": 256,
"contrastive_logits_temperature": 0.1,
"conv_bias": false,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"diversity_loss_weight": 0.1,
"do_stable_layer_norm": false,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_norm": "group",
"feat_proj_dropout": 0.1,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.0,
"freeze_feat_extract_train": true,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"layerdrop": 0.05,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_feature_length": 10,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_masks": 2,
"mask_time_min_space": 1,
"mask_time_other": 0.0,
"mask_time_prob": 0.05,
"mask_time_selection": "static",
"max_bucket_distance": 800,
"model_type": "wavlm",
"no_mask_channel_overlap": false,
"no_mask_time_overlap": false,
"num_adapter_layers": 3,
"num_attention_heads": 12,
"num_buckets": 320,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_ctc_classes": 80,
"num_feat_extract_layers": 7,
"num_hidden_layers": 12,
"num_negatives": 100,
"output_hidden_size": 768,
"pad_token_id": 0,
"proj_codevector_dim": 256,
"tokenizer_class": "Wav2Vec2CTCTokenizer",
"torch_dtype": "float32",
"transformers_version": "4.15.0.dev0",
"use_weighted_layer_sum": false,
"vocab_size": 32
}
{
"activation_dropout": 0.0,
"adapter_kernel_size": 3,
"adapter_stride": 2,
"add_adapter": false,
"apply_spec_augment": true,
"architectures": [
"WavLMModel"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"classifier_proj_size": 256,
"codevector_dim": 768,
"contrastive_logits_temperature": 0.1,
"conv_bias": false,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"diversity_loss_weight": 0.1,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.0,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_feature_length": 10,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_masks": 2,
"mask_time_min_space": 1,
"mask_time_other": 0.0,
"mask_time_prob": 0.075,
"mask_time_selection": "static",
"max_bucket_distance": 800,
"model_type": "wavlm",
"num_adapter_layers": 3,
"num_attention_heads": 16,
"num_buckets": 320,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_ctc_classes": 80,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"num_negatives": 100,
"output_hidden_size": 1024,
"pad_token_id": 0,
"proj_codevector_dim": 768,
"replace_prob": 0.5,
"tokenizer_class": "Wav2Vec2CTCTokenizer",
"torch_dtype": "float32",
"transformers_version": "4.15.0.dev0",
"use_weighted_layer_sum": false,
"vocab_size": 32
}
......@@ -16,6 +16,7 @@ from .case_utils import (
TempDirMixin,
TestBaseMixin,
TorchaudioTestCase,
zip_equal,
)
from .data_utils import get_asset_path, get_sinusoid, get_spectrogram, get_whitenoise
from .func_utils import torch_script
......@@ -57,4 +58,5 @@ __all__ = [
"get_image",
"rgb_to_gray",
"rgb_to_yuv_ccir",
"zip_equal",
]
......@@ -6,6 +6,7 @@ import sys
import tempfile
import time
import unittest
from itertools import zip_longest
import torch
import torchaudio
......@@ -245,3 +246,16 @@ skipIfPy310 = _skipIf(
),
key="ON_PYTHON_310",
)
def zip_equal(*iterables):
"""With the regular Python `zip` function, if one iterable is longer than the other,
the remainder portions are ignored.This is resolved in Python 3.10 where we can use
`strict=True` in the `zip` function
From https://github.com/pytorch/text/blob/c047efeba813ac943cb8046a49e858a8b529d577/test/torchtext_unittest/common/case_utils.py#L45-L54 # noqa: E501
"""
sentinel = object()
for combo in zip_longest(*iterables, fillvalue=sentinel):
if sentinel in combo:
raise ValueError("Iterables have different lengths")
yield combo
......@@ -2,9 +2,9 @@ import json
import torch
from parameterized import parameterized
from torchaudio.models.wav2vec2 import wav2vec2_base, wav2vec2_large, wav2vec2_large_lv60k
from torchaudio.models.wav2vec2 import wav2vec2_base, wav2vec2_large, wav2vec2_large_lv60k, wavlm_base, wavlm_large
from torchaudio.models.wav2vec2.utils import import_huggingface_model
from torchaudio_unittest.common_utils import get_asset_path, skipIfNoModule, TorchaudioTestCase
from torchaudio_unittest.common_utils import get_asset_path, skipIfNoModule, TorchaudioTestCase, zip_equal
def _load_config(*paths):
......@@ -22,6 +22,8 @@ HF_LARGE = _load_config("wav2vec2-large")
HF_LARGE_LV60 = _load_config("wav2vec2-large-lv60")
HF_LARGE_XLSR_53 = _load_config("wav2vec2-large-xlsr-53")
HF_BASE_10K_VOXPOPULI = _load_config("wav2vec2-base-10k-voxpopuli")
HF_BASE_WAVLM = _load_config("wavlm-base")
HF_LARGE_WAVLM = _load_config("wavlm-large")
# Finetuned
HF_BASE_960H = _load_config("wav2vec2-base-960h")
HF_LARGE_960H = _load_config("wav2vec2-large-960h")
......@@ -50,6 +52,13 @@ FINETUNE_CONFIGS = parameterized.expand(
],
name_func=_name_func,
)
WAVLM_CONFIGS = parameterized.expand(
[
(HF_BASE_WAVLM, wavlm_base),
(HF_LARGE_WAVLM, wavlm_large),
],
name_func=_name_func,
)
@skipIfNoModule("transformers")
......@@ -68,12 +77,14 @@ class TestHFIntegration(TorchaudioTestCase):
# However, somehow, once "transformers" is imported, `is_module_available`
# starts to fail. Therefore, we defer importing "transformers" until
# the actual tests are started.
from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model, WavLMConfig, WavLMModel
if config["architectures"] == ["Wav2Vec2Model"]:
return Wav2Vec2Model(Wav2Vec2Config(**config))
if config["architectures"] == ["Wav2Vec2ForCTC"]:
return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
if config["architectures"] == ["WavLMModel"]:
return WavLMModel(WavLMConfig(**config))
raise ValueError(f'Unexpected arch: {config["architectures"]}')
def _test_import_pretrain(self, original, imported, config):
......@@ -97,9 +108,8 @@ class TestHFIntegration(TorchaudioTestCase):
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
mask = torch.randn(b, 1, l, l)
(ref,) = original_(x, attention_mask=mask, output_attentions=False)
hyp = imported_(x, mask)
hyp, _ = imported_(x, mask) # Ignore returned position_bias, which is always None for Wav2Vec2 and HuBERT
self.assertEqual(ref, hyp)
# The whole Encoder Transformer
b, l, e = 16, 3, config["hidden_size"]
......@@ -115,11 +125,6 @@ class TestHFIntegration(TorchaudioTestCase):
hyp = imported.aux(x)
self.assertEqual(ref, hyp)
# The whole model without mask
x = torch.randn(3, 1024)
ref = original(x).logits
hyp, _ = imported(x)
self.assertEqual(ref, hyp)
# The whole model without mask
batch_size, num_frames = 3, 1024
x = torch.randn(batch_size, num_frames)
ref = original(x).logits
......@@ -159,6 +164,51 @@ class TestHFIntegration(TorchaudioTestCase):
self._test_import_pretrain(original.wav2vec2, imported, config)
self._test_import_finetune(original, imported, config)
@WAVLM_CONFIGS
def test_import_pretrain_wavlm(self, config, _):
"""WavLM models from HF transformers can be imported and yield the same results"""
original = self._get_model(config).eval()
imported = import_huggingface_model(original).eval()
# FeatureExtractor
x = torch.randn(3, 1024)
ref = original.feature_extractor(x).transpose(1, 2)
hyp, _ = imported.feature_extractor(x, None)
self.assertEqual(ref, hyp)
# Feature projection
x = torch.randn(3, 10, config["conv_dim"][-1])
ref = original.feature_projection(x)[0]
hyp = imported.encoder.feature_projection(x)
self.assertEqual(ref, hyp)
# Convolutional Positional Encoder
x = torch.randn(3, 256, config["hidden_size"])
ref = original.encoder.pos_conv_embed(x)
hyp = imported.encoder.transformer.pos_conv_embed(x)
self.assertEqual(ref, hyp)
position_bias = None
position_bias_imp = None
assert len(original.encoder.layers) > 0
for original_, imported_ in zip_equal(original.encoder.layers, imported.encoder.transformer.layers):
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
mask = torch.randn(b, l) > 0.5 # HF WavLM model expects the mask to be binary
# HF WavLM model (original_) takes in "attention_mask" but actually uses it as a key padding mask:
# https://github.com/huggingface/transformers/blob/b047472650cba259621549ac27b18fd2066ce18e/src/transformers/models/wavlm/modeling_wavlm.py#L495
ref, position_bias = original_(x, attention_mask=mask, output_attentions=False, position_bias=position_bias)
hyp, position_bias_imp = imported_(x, key_padding_mask=mask.ne(1), position_bias=position_bias_imp)
# Masked-out elements are undefined in the output
ref_filled = ref.masked_fill(~mask.unsqueeze(2), 0)
hyp_filled = hyp.masked_fill(~mask.unsqueeze(2), 0)
self.assertEqual(ref_filled, hyp_filled)
# The whole Encoder Transformer
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
ref = original.encoder(x).last_hidden_state
hyp = imported.encoder.transformer(x)
self.assertEqual(ref, hyp)
def _test_recreate(self, imported, reloaded, config):
# FeatureExtractor
x = torch.randn(3, 1024)
......@@ -221,3 +271,50 @@ class TestHFIntegration(TorchaudioTestCase):
reloaded.load_state_dict(imported.state_dict())
reloaded.eval()
self._test_recreate(imported, reloaded, config)
@WAVLM_CONFIGS
def test_recreate_wavlm(self, config, factory_func):
"""Imported models can be recreated via a factory function without Hugging Face transformers."""
imported = import_huggingface_model(self._get_model(config)).eval()
reloaded = factory_func()
reloaded.load_state_dict(imported.state_dict())
reloaded.eval()
# FeatureExtractor
x = torch.randn(3, 1024)
ref, _ = imported.feature_extractor(x, None)
hyp, _ = reloaded.feature_extractor(x, None)
self.assertEqual(ref, hyp)
# Feature projection
x = torch.randn(3, 10, config["conv_dim"][-1])
ref = imported.encoder.feature_projection(x)
hyp = reloaded.encoder.feature_projection(x)
self.assertEqual(ref, hyp)
# Convolutional Positional Encoder
x = torch.randn(3, 256, config["hidden_size"])
ref = imported.encoder.transformer.pos_conv_embed(x)
hyp = reloaded.encoder.transformer.pos_conv_embed(x)
self.assertEqual(ref, hyp)
# Encoder Transformer Layer
position_bias_ref = None
position_bias_hyp = None
for imported_, reloaded_ in zip(imported.encoder.transformer.layers, reloaded.encoder.transformer.layers):
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
mask = torch.randn(b, l) > 0.5 # HuggingFace WavLM expects the mask to be binary
ref, position_bias_ref = imported_(x, key_padding_mask=mask, position_bias=position_bias_ref)
hyp, position_bias_hyp = reloaded_(x, key_padding_mask=mask, position_bias=position_bias_hyp)
self.assertEqual(ref, hyp)
# The whole Encoder Transformer
# TODO: Add mask pattern. Expected mask shapes and values are different.
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
mask = torch.randn(b, 1, l, l)
ref = imported.encoder.transformer(x)
hyp = reloaded.encoder.transformer(x)
self.assertEqual(ref, hyp)
# The whole model
x = torch.randn(3, 1024)
ref, _ = imported(x)
hyp, _ = reloaded(x)
self.assertEqual(ref, hyp)
......@@ -11,6 +11,8 @@ from torchaudio.models.wav2vec2 import (
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
wavlm_base,
wavlm_large,
)
from torchaudio_unittest.common_utils import skipIfNoCuda, skipIfNoQengine, torch_script, TorchaudioTestCase
......@@ -37,6 +39,14 @@ factory_funcs = parameterized.expand(
name_func=_name_func,
)
factory_funcs_wavlm = parameterized.expand(
[
(wavlm_base,),
(wavlm_large,),
],
name_func=_name_func,
)
class TestWav2Vec2Model(TorchaudioTestCase):
def _smoke_test(self, model, device, dtype):
......@@ -263,3 +273,128 @@ class TestWav2Vec2Model(TorchaudioTestCase):
def test_quantize_torchscript(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable"""
self._test_quantize_torchscript(factory_func(aux_num_out=32))
class TestWavLMModel(TorchaudioTestCase):
def _smoke_test(self, model, device, dtype):
model = model.to(device=device, dtype=dtype)
model = model.eval()
batch_size, num_frames = 3, 1024
waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
model(waveforms)
@parameterized.expand([(torch.float32,), (torch.float64,)])
def test_cpu_smoke_test(self, dtype):
model = wavlm_base()
self._smoke_test(model, torch.device("cpu"), dtype)
model = wavlm_base(aux_num_out=32)
self._smoke_test(model, torch.device("cpu"), dtype)
@parameterized.expand([(torch.float32,), (torch.float64,)])
@skipIfNoCuda
def test_cuda_smoke_test(self, dtype):
model = wavlm_base()
self._smoke_test(model, torch.device("cuda"), dtype)
model = wavlm_base(aux_num_out=32)
self._smoke_test(model, torch.device("cuda"), dtype)
def _test_batch_consistency(self, model):
model.eval()
batch_size, max_frames = 5, 5 * 1024
waveforms = torch.randn(batch_size, max_frames)
# Batch process
batch_logits, _ = model(waveforms)
# Per-sample process
for i in range(batch_size):
single_logit, _ = model(waveforms[i : i + 1])
batch_logit = batch_logits[i : i + 1]
# Convert to probability so that it's easier to interpret the diff
single_prob = F.softmax(single_logit, dim=2)
batch_prob = F.softmax(batch_logit, dim=2)
# We allow max atol=0.005 -> 0.5%
self.assertEqual(single_prob, batch_prob, atol=0.005, rtol=0)
@factory_funcs_wavlm
def test_pretrain_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close"""
self._test_batch_consistency(factory_func())
@factory_funcs_wavlm
def test_finetune_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close"""
self._test_batch_consistency(factory_func(aux_num_out=32))
def _test_torchscript(self, model):
model.eval()
batch_size, num_frames = 3, 1024
waveforms = torch.randn(batch_size, num_frames)
# Compute results with original model
ref_out, ref_len = model(waveforms)
# Compute results with scripted model
scripted = torch_script(model)
hyp_out, hyp_len = scripted(waveforms)
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
@factory_funcs_wavlm
def test_pretrain_torchscript(self, factory_func):
"""WavLM model should be scriptable"""
self._test_torchscript(factory_func())
@factory_funcs_wavlm
def test_finetune_torchscript(self, factory_func):
"""WavLM model with a head should be scriptable"""
self._test_torchscript(factory_func(aux_num_out=32))
def _test_quantize_smoke_test(self, model):
model.eval()
batch_size, num_frames = 3, 1024
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
waveforms = torch.randn(batch_size, num_frames)
_, _ = quantized(waveforms)
@factory_funcs_wavlm
@skipIfNoQengine
def test_quantize(self, factory_func):
"""WavLM should support basic quantization"""
self._test_quantize_smoke_test(factory_func(aux_num_out=32))
def _test_quantize_torchscript(self, model):
model.eval()
batch_size, num_frames = 3, 1024
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
waveforms = torch.randn(batch_size, num_frames)
ref_out, ref_len = quantized(waveforms)
# Script
scripted = torch_script(quantized)
hyp_out, hyp_len = scripted(waveforms)
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
@factory_funcs_wavlm
@skipIfNoQengine
def test_quantize_torchscript(self, factory_func):
"""Quantized WavLM model should be scriptable"""
self._test_quantize_torchscript(factory_func(aux_num_out=32))
......@@ -21,6 +21,9 @@ from .wav2vec2 import (
wav2vec2_large_lv60k,
wav2vec2_model,
Wav2Vec2Model,
wavlm_base,
wavlm_large,
wavlm_model,
)
from .wavernn import WaveRNN
......@@ -33,6 +36,9 @@ __all__ = [
"DeepSpeech",
"Wav2Vec2Model",
"HuBERTPretrainModel",
"wavlm_model",
"wavlm_base",
"wavlm_large",
"wav2vec2_model",
"wav2vec2_base",
"wav2vec2_large",
......
......@@ -13,11 +13,17 @@ from .model import (
wav2vec2_large_lv60k,
wav2vec2_model,
Wav2Vec2Model,
wavlm_base,
wavlm_large,
wavlm_model,
)
__all__ = [
"Wav2Vec2Model",
"HuBERTPretrainModel",
"wavlm_model",
"wavlm_base",
"wavlm_large",
"wav2vec2_model",
"wav2vec2_base",
"wav2vec2_large",
......
......@@ -5,6 +5,8 @@ import torch
from torch import nn, Tensor
from torch.nn import Module, Parameter
from .wavlm_attention import WavLMSelfAttention
_LG = logging.getLogger(__name__)
......@@ -243,7 +245,7 @@ class SelfAttention(Module):
embed_dim (int): Total dimension of the model.
num_heads (int): The number of heads.
dropout (float, optional):
Dropout probabiliry on attn_output_weights. Default: ``0.0``
Dropout probability on attn_output_weights. Default: ``0.0``
"""
def __init__(
......@@ -273,15 +275,21 @@ class SelfAttention(Module):
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
) -> Tensor:
position_bias: Optional[Tensor] = None,
key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
attention_mask (Tensor or None, optional):
attention_mask (Tensor or ``None``, optional):
shape: ``[batch_size, 1, sequence_length, sequence_length]``
position_bias: Not used. Only for compatibility with :py:class:`WavLMSelfAttention`.
key_padding_mask (Tensor or ``None``): Not used. Only for compatibility with
:py:class:`WavLMSelfAttention`.
Returns:
Tensor: The resulting tensor. shape: ``[batch, sequence_length, embed_dim]``
(Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility
with :py:class:`WavLMSelfAttention`).
Attention output shape: ``[batch, sequence_length, embed_dim]``.
"""
if x.ndim != 3 or x.shape[2] != self.embed_dim:
raise ValueError(
......@@ -314,7 +322,7 @@ class SelfAttention(Module):
output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
output = self.out_proj(output)
return output
return output, None # Necessary for compatibility with WavLMSelfAttention
class FeedForward(Module):
......@@ -371,19 +379,32 @@ class EncoderLayer(Module):
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
):
position_bias: Optional[Tensor] = None,
key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x (Tensor): shape: `(batch, sequence_length, embed_dim)`
attention_mask (Tensor or None, optional):
shape: `(batch, 1, sequence_length, sequence_length)`
x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``.
attention_mask (Tensor or ``None``, optional): attention mask
of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``)
position_bias (Tensor or ``None``, optional): position bias of shape
``(batch_size * num_heads, src_len, src_len)``.
Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``)
key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``.
Only used for WavLM model, ignored otherwise. (Default: ``None``)
Returns:
(x, position_bias): Shapes are the same as in the input. Position bias is only relevant for the WavLM model,
``None`` otherwise.
"""
residual = x
if self.layer_norm_first:
x = self.layer_norm(x)
x = self.attention(x, attention_mask)
x, position_bias = self.attention(
x, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask
)
x = self.dropout(x)
x = residual + x
......@@ -392,7 +413,7 @@ class EncoderLayer(Module):
else:
x = self.layer_norm(x)
x = self.final_layer_norm(x + self.feed_forward(x))
return x
return x, position_bias
class Transformer(Module):
......@@ -425,15 +446,15 @@ class Transformer(Module):
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
):
position_bias: Optional[Tensor] = None,
) -> Tensor:
x = self._preprocess(x)
for layer in self.layers:
if not (self.training and torch.rand(1).item() <= self.layer_drop):
x = layer(x, attention_mask)
x, position_bias = layer(x, attention_mask, position_bias=position_bias)
if not self.layer_norm_first:
x = self.layer_norm(x)
return x
def get_intermediate_outputs(
......@@ -449,7 +470,7 @@ class Transformer(Module):
ret: List[Tensor] = []
x = self._preprocess(x)
for layer in self.layers:
x = layer(x, attention_mask)
x, _ = layer(x, attention_mask) # Ignore position_bias
ret.append(x)
if num_layers is not None and len(ret) >= num_layers:
return ret
......@@ -752,6 +773,85 @@ def _get_encoder(
return Encoder(feature_projection, transformer)
def _get_wavlm_encoder(
in_features: int,
embed_dim: int,
dropout_input: float,
pos_conv_kernel: int,
pos_conv_groups: int,
num_layers: int,
num_heads: int,
num_buckets: int,
max_distance: int,
attention_dropout: float,
ff_interm_features: int,
ff_interm_dropout: float,
dropout: float,
layer_norm_first: bool,
layer_drop: float,
) -> Encoder:
"""
Construct encoder for WavLM model :cite:`chen2022wavlm`. The structure of the encoder and most of the arguments are
the same as in :py:func:`_get_encoder`, so refer there for documentation. The only difference from the Wav2Vec2 encoder
is the usage of `WavLMSelfAttention` instead of `SelfAttention` and two additional parameters: `num_buckets` and
`max_distance`.
Args:
in_features (int): See :py:func:`_get_encoder`.
embed_dim (int): See :py:func:`_get_encoder`.
dropout_input (float): See :py:func:`_get_encoder`.
pos_conv_kernel (int): See :py:func:`_get_encoder`.
pos_conv_groups (int): See :py:func:`_get_encoder`.
num_layers (int): See :py:func:`_get_encoder`.
num_heads (int): See :py:func:`_get_encoder`.
num_buckets (int): Number of buckets for relative position embedding.
max_distance (int): Maximum distance for relative position embedding.
attention_dropout (float): See :py:func:`_get_encoder`.
ff_interm_features (int): See :py:func:`_get_encoder`.
ff_interm_dropout (float): See :py:func:`_get_encoder`.
dropout (float): See :py:func:`_get_encoder`.
layer_norm_first (bool): See :py:func:`_get_encoder`.
layer_drop (float): See :py:func:`_get_encoder`.
"""
feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)
# Original impl
# https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
encoder_layers = nn.ModuleList()
for i in range(num_layers):
attention = WavLMSelfAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_buckets=num_buckets,
max_distance=max_distance,
dropout=attention_dropout,
has_relative_attention_bias=(i == 0), # Position embedding is only necessary in the first layer.
)
feed_forward = FeedForward(
io_features=embed_dim,
intermediate_features=ff_interm_features,
intermediate_dropout=ff_interm_dropout,
output_dropout=dropout,
)
encoder_layers.append(
EncoderLayer(
attention=attention,
dropout=dropout,
layer_norm_first=layer_norm_first,
feed_forward=feed_forward,
)
)
transformer = Transformer(
pos_conv_embed=pos_conv,
dropout=dropout,
layers=encoder_layers,
layer_norm_first=not layer_norm_first,
layer_drop=layer_drop,
)
return Encoder(feature_projection, transformer)
def _compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[Tensor],
......
......@@ -1216,3 +1216,219 @@ def hubert_pretrain_xlarge(
final_dim=1024,
feature_grad_mult=feature_grad_mult,
)
def wavlm_model(
extractor_mode: str,
extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
extractor_conv_bias: bool,
encoder_embed_dim: int,
encoder_projection_dropout: float,
encoder_pos_conv_kernel: int,
encoder_pos_conv_groups: int,
encoder_num_layers: int,
encoder_num_heads: int,
encoder_num_buckets: int,
encoder_max_distance: int,
encoder_attention_dropout: float,
encoder_ff_interm_features: int,
encoder_ff_interm_dropout: float,
encoder_dropout: float,
encoder_layer_norm_first: bool,
encoder_layer_drop: float,
aux_num_out: Optional[int],
) -> Wav2Vec2Model:
"""Builds custom WaveLM model :cite:`chen2022wavlm`. The architecture is compatible
with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output object is
:class:`~torchaudio.models.Wav2Vec2Model`. Most of the arguments have the same meaning
as in :py:func:`wav2vec2_model` so please refer there for documentation.
Args:
extractor_mode (str): Operation mode of feature extractor.
See :py:func:`wav2vec2_model`.
extractor_conv_layer_config (list of integer tuples or None):
See :py:func:`wav2vec2_model`.
extractor_conv_bias (bool):
See :py:func:`wav2vec2_model`.
encoder_embed_dim (int):
See :py:func:`wav2vec2_model`.
encoder_projection_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_pos_conv_kernel (int):
See :py:func:`wav2vec2_model`.
encoder_pos_conv_groups (int):
See :py:func:`wav2vec2_model`.
encoder_num_layers (int):
See :py:func:`wav2vec2_model`.
encoder_num_heads (int):
See :py:func:`wav2vec2_model`.
encoder_num_buckets (int):
Number of buckets for relative position embedding.
encoder_max_distance (int):
Maximum distance for relative position embedding.
encoder_attention_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_ff_interm_features (int):
See :py:func:`wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_layer_norm_first (bool):
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`wav2vec2_model`.
aux_num_out (int or None):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model:
The resulting model.
"""
if extractor_conv_layer_config is None:
extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
feature_extractor = components._get_feature_extractor(
extractor_mode, extractor_conv_layer_config, extractor_conv_bias
)
encoder = components._get_wavlm_encoder(
in_features=extractor_conv_layer_config[-1][0],
embed_dim=encoder_embed_dim,
dropout_input=encoder_projection_dropout,
pos_conv_kernel=encoder_pos_conv_kernel,
pos_conv_groups=encoder_pos_conv_groups,
num_layers=encoder_num_layers,
num_heads=encoder_num_heads,
num_buckets=encoder_num_buckets,
max_distance=encoder_max_distance,
attention_dropout=encoder_attention_dropout,
ff_interm_features=encoder_ff_interm_features,
ff_interm_dropout=encoder_ff_interm_dropout,
dropout=encoder_dropout,
layer_norm_first=encoder_layer_norm_first,
layer_drop=encoder_layer_drop,
)
aux = None
if aux_num_out is not None:
aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
return Wav2Vec2Model(feature_extractor, encoder, aux)
def wavlm_base(
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Builds "base" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible
with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
:class:`~torchaudio.models.Wav2Vec2Model`.
Args:
encoder_projection_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_attention_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`wav2vec2_model`.
aux_num_out (int, optional):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model:
The resulting model.
"""
return wavlm_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=768,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=12,
encoder_num_heads=12,
encoder_num_buckets=320,
encoder_max_distance=800,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=3072,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=False,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=aux_num_out,
)
def wavlm_large(
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Builds "large" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible
with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
:class:`~torchaudio.models.Wav2Vec2Model`.
Args:
encoder_projection_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_attention_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`wav2vec2_model`.
aux_num_out (int, optional):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model:
The resulting model.
"""
return wavlm_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1024,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_num_buckets=320,
encoder_max_distance=800,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=True,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=aux_num_out,
)
......@@ -4,7 +4,7 @@ import logging
from torch.nn import Module
from ..model import wav2vec2_model, Wav2Vec2Model
from ..model import wav2vec2_model, Wav2Vec2Model, wavlm_model
_LG = logging.getLogger(__name__)
......@@ -30,23 +30,72 @@ def _get_config(cfg):
return config
def _get_config_wavlm(cfg):
config = {
"extractor_mode": f"{cfg.feat_extract_norm}_norm",
"extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
"extractor_conv_bias": cfg.conv_bias,
"encoder_embed_dim": cfg.hidden_size,
"encoder_projection_dropout": cfg.feat_proj_dropout,
"encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings,
"encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups,
"encoder_num_layers": cfg.num_hidden_layers,
"encoder_num_heads": cfg.num_attention_heads,
"encoder_num_buckets": cfg.num_buckets,
"encoder_max_distance": cfg.max_bucket_distance,
"encoder_attention_dropout": cfg.attention_dropout,
"encoder_ff_interm_features": cfg.intermediate_size,
"encoder_ff_interm_dropout": cfg.activation_dropout,
"encoder_dropout": cfg.hidden_dropout,
"encoder_layer_norm_first": cfg.do_stable_layer_norm,
"encoder_layer_drop": cfg.layerdrop,
}
return config
def _build(config, original):
if original.__class__.__name__ == "Wav2Vec2ForCTC":
is_for_ctc = original.__class__.__name__ in ["Wav2Vec2ForCTC", "WavLMForCTC"]
if is_for_ctc:
aux_num_out = original.config.vocab_size
wav2vec2 = original.wav2vec2
else:
_LG.warning("The model is not an instance of Wav2Vec2ForCTC. " '"lm_head" module is not imported.')
_LG.warning(
"The model is not an instance of Wav2Vec2ForCTC or WavLMForCTC. " '"lm_head" module is not imported.'
)
aux_num_out = None
wav2vec2 = original
imported = wav2vec2_model(**config, aux_num_out=aux_num_out)
is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"]
if is_wavlm:
imported = wavlm_model(**config, aux_num_out=aux_num_out)
else:
imported = wav2vec2_model(**config, aux_num_out=aux_num_out)
imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
if original.__class__.__name__ == "Wav2Vec2ForCTC":
encoder_state_dict = wav2vec2.encoder.state_dict()
if is_wavlm: # Rename parameters of linear transformations for compatibility with the HF model
encoder_state_dict = {rename_wavlm_key(x): encoder_state_dict[x] for x in encoder_state_dict.keys()}
imported.encoder.transformer.load_state_dict(encoder_state_dict)
if is_for_ctc:
imported.aux.load_state_dict(original.lm_head.state_dict())
return imported
def rename_wavlm_key(key):
"""Rename weights and biases of linear transformations, since we define them directly in WavLMSelfAttention,
as opposed to nesting them in Linear modules
"""
return (
key.replace("k_proj.weight", "k_proj_weight")
.replace("k_proj.bias", "k_proj_bias")
.replace("q_proj.weight", "q_proj_weight")
.replace("q_proj.bias", "q_proj_bias")
.replace("v_proj.weight", "v_proj_weight")
.replace("v_proj.bias", "v_proj_bias")
.replace("out_proj.weight", "out_proj_weight")
.replace("out_proj.bias", "out_proj_bias")
)
def import_huggingface_model(original: Module) -> Wav2Vec2Model:
"""Builds :class:`Wav2Vec2Model` from the corresponding model object of
`Transformers <https://huggingface.co/transformers/>`_.
......@@ -68,7 +117,11 @@ def import_huggingface_model(original: Module) -> Wav2Vec2Model:
"""
_LG.info("Importing model.")
_LG.info("Loading model configuration.")
config = _get_config(original.config)
is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"]
if is_wavlm:
config = _get_config_wavlm(original.config)
else:
config = _get_config(original.config)
_LG.debug(" - config: %s", config)
_LG.info("Building model.")
imported = _build(config, original)
......
"""
The MIT License (MIT)
Copyright (c) Microsoft Corporation
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import math
from typing import Optional, Tuple
import torch
from torch import nn, Tensor
from torch.nn import functional as F
class WavLMSelfAttention(nn.Module):
"""Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`.
Source: https://github.com/microsoft/unilm/blob/2d8302f09c99bca2b82e6e868d81d4281cceebc8/wavlm/modules.py#L303-L763
Args:
embed_dim (int): Total dimension of the model.
num_heads (int): The number of heads.
dropout (float, optional): Dropout probability on attn_output_weights. (Default: ``0.0``)
bias (bool, optional): If ``True``, add bias to projections for queries and values. (Default: ``True``)
has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding.
Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``)
num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``)
max_distance (int, optional): Maximum distance for relative position embedding. (Default: ``128``)
gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``False``)
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True,
has_relative_attention_bias: bool = False,
num_buckets: int = 32,
max_distance: int = 128,
gru_rel_pos: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout_module = nn.Dropout(dropout)
self.has_relative_attention_bias = has_relative_attention_bias
self.num_buckets = num_buckets
self.max_distance = max_distance
if has_relative_attention_bias:
self.rel_attn_embed = nn.Embedding(num_buckets, num_heads)
else:
self.rel_attn_embed = None
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
# Define parameters of the linear transformations. We don't use Linear to avoid problems with quantization.
# See also https://github.com/pytorch/audio/pull/2822#discussion_r1014431878
self.q_proj_weight, self.k_proj_weight, self.v_proj_weight, self.out_proj_weight = [
nn.Parameter(torch.zeros((embed_dim, embed_dim))) for _ in range(4)
]
self.k_proj_bias = nn.Parameter(torch.zeros(embed_dim))
if bias:
self.v_proj_bias, self.q_proj_bias, self.out_proj_bias = [
nn.Parameter(torch.zeros((embed_dim))) for _ in range(3)
]
else:
self.register_parameter("v_proj_bias", None)
self.register_parameter("q_proj_bias", None)
self.register_parameter("out_proj_bias", None)
self.gru_rel_pos = gru_rel_pos
if self.gru_rel_pos:
self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
self.gru_rel_pos_const = nn.Parameter(torch.ones(1, num_heads, 1, 1))
self.has_position_bias = True
def compute_bias(self, query_length: int, key_length: int) -> Tensor:
"""Compute relative position embeddings for WavLM model.
Args:
query_length (int): Query position can take values between 0 and ``query_length - 1``.
key_length (int): Key position can take values between 0 and ``key_length - 1``.
Returns:
Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings
"""
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position # Shape (query_length, key_length)
relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
values = self.rel_attn_embed(relative_position_bucket) # Shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1])
return values
def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True):
"""Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM
paper :cite:`chen2022wavlm`.
Args:
relative_positions (Tensor): Relative offsets between query and key positions,
of shape ``(query_length, key_length)``.
bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the resulting
matrix. If ``False``, the elements above the diagonal (i.e. with negative relative offsets) will be set
to zero. (Default ``True``)
Returns:
Tensor of shape ``(query_length, key_length)`` filled with bucketed relative position values.
"""
num_buckets = self.num_buckets
max_distance = self.max_distance
# Shape (query_length, key_length)
relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long)
if bidirectional:
num_buckets = num_buckets // 2
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
relative_positions = torch.abs(relative_positions)
else:
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
max_exact = num_buckets // 2
is_small = relative_positions < max_exact
relative_position_if_large = max_exact + (
torch.log(relative_positions.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
return relative_buckets
def forward(
self,
query: Tensor,
key_padding_mask: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
position_bias: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``.
key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape
`(batch, src_len)`, where padding elements are indicated by 1s. (Default: ``None``)
attention_mask (Tensor or None, optional): Needs to be ``None``. The argument exists for compatibility with
``EncoderLayer``. (Default: ``None``)
position_bias (Tensor or None, optional): Position bias of shape
``(batch_size * num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be
generated in the first layer and then passed from each encoder layer to the next one.
(Default: ``None``)
Returns:
attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``.
position_bias (Tensor or None): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``.
"""
bsz, seq_len, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert attention_mask is None
if self.rel_attn_embed is not None and position_bias is None:
position_bias = self.compute_bias(seq_len, seq_len)
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, seq_len, seq_len)
attn_mask_rel_pos: Optional[Tensor] = None
if position_bias is not None:
attn_mask_rel_pos = position_bias
if self.gru_rel_pos: # Apply gating on relative position bias
query_layer = query.view(bsz, seq_len, self.num_heads, -1)
query_layer = query_layer.permute(0, 2, 1, 3)
gate_a, gate_b = torch.sigmoid(
self.gru_rel_pos_linear(query_layer).view(bsz, self.num_heads, seq_len, 2, 4).sum(-1, keepdim=False)
).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len))
bias_k = bias_v = None
add_zero_attn = False
# multi_head_attention_forward expects query shape (seq_len, batch_size, embed_dim)
query = query.transpose(0, 1)
concat_bias = torch.cat((self.q_proj_bias, self.k_proj_bias, self.v_proj_bias))
attn_output, _ = F.multi_head_attention_forward(
query,
query,
query,
self.embed_dim,
self.num_heads,
torch.empty([0]),
concat_bias,
bias_k,
bias_v,
add_zero_attn,
self.dropout_module.p,
self.out_proj_weight,
self.out_proj_bias,
self.training,
key_padding_mask,
need_weights=False,
attn_mask=attn_mask_rel_pos,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
)
attn_output = attn_output.transpose(0, 1) # Convert back to batch-first
return attn_output, position_bias