Commit 26f62dc5 authored by Grigory Sizov, committed by Facebook GitHub Bot

Add WavLM bundles (#2833)

Summary:
Closes T136364380, follow-up to https://github.com/pytorch/audio/issues/2822

- Added "base", "base+", and "large" bundles for WavLM
- Expanded `wav2vec2_pipeline_test.py` to include the new bundles
- Added the new bundles to docs in `pipelines.rst`
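
A minimal usage sketch for the new bundles (the waveform here is dummy data; the API is the existing `Wav2Vec2Bundle` interface, so `get_model` and `extract_features` work the same way as for the wav2vec 2.0 / HuBERT bundles):

```python
import torch
import torchaudio

# Any of the three new bundles works the same way
bundle = torchaudio.pipelines.WAVLM_BASE

# Downloads the checkpoint on first use and returns the model in eval mode
model = bundle.get_model()

# One second of dummy audio at the bundle's sample rate (16 kHz)
waveform = torch.randn(1, int(bundle.sample_rate))

# List of per-layer transformer features; the last entry is the top layer
with torch.inference_mode():
    features, _ = model.extract_features(waveform)
print(len(features), features[-1].shape)  # 12 layers of (1, frames, 768) for "base"
```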

Pull Request resolved: https://github.com/pytorch/audio/pull/2833

Reviewed By: nateanl

Differential Revision: D41194796

Pulled By: sgrigory

fbshipit-source-id: bf8e96c05b6a81ac5c5a014c46adeeac12685328
parent 2d1da45c
@@ -55,7 +55,7 @@ Pretrained Models
    EMFORMER_RNNT_BASE_LIBRISPEECH
-wav2vec 2.0 / HuBERT - SSL
+wav2vec 2.0 / HuBERT / WavLM - SSL
 --------------------------
 Interface
@@ -87,6 +87,9 @@ Pretrained Models
    HUBERT_BASE
    HUBERT_LARGE
    HUBERT_XLARGE
+   WAVLM_BASE
+   WAVLM_BASE_PLUS
+   WAVLM_LARGE
 wav2vec 2.0 / HuBERT - Fine-tuned ASR
 -------------------------------------
......
@@ -24,6 +24,9 @@ from torchaudio.pipelines import (
     WAV2VEC2_LARGE,
     WAV2VEC2_LARGE_LV60K,
     WAV2VEC2_XLSR53,
+    WAVLM_BASE,
+    WAVLM_BASE_PLUS,
+    WAVLM_LARGE,
 )
@@ -37,6 +40,9 @@ from torchaudio.pipelines import (
         HUBERT_BASE,
         HUBERT_LARGE,
         HUBERT_XLARGE,
+        WAVLM_BASE,
+        WAVLM_BASE_PLUS,
+        WAVLM_LARGE,
     ],
 )
 def test_pretraining_models(bundle):
......
@@ -37,6 +37,9 @@ from ._wav2vec2.impl import (
     WAV2VEC2_XLSR53,
     Wav2Vec2ASRBundle,
     Wav2Vec2Bundle,
+    WAVLM_BASE,
+    WAVLM_BASE_PLUS,
+    WAVLM_LARGE,
 )
 from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
@@ -67,6 +70,9 @@ __all__ = [
     "HUBERT_XLARGE",
     "HUBERT_ASR_LARGE",
     "HUBERT_ASR_XLARGE",
+    "WAVLM_BASE",
+    "WAVLM_BASE_PLUS",
+    "WAVLM_LARGE",
     "Tacotron2TTSBundle",
     "TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
     "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
......
@@ -3,7 +3,7 @@ from typing import Any, Dict, Tuple
 import torch
 from torchaudio._internal import load_state_dict_from_url
-from torchaudio.models import wav2vec2_model, Wav2Vec2Model
+from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
 from . import utils
@@ -69,6 +69,10 @@ class Wav2Vec2Bundle:
         Args:
             dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
         """
-        model = wav2vec2_model(**self._params)
+        model_type = self._params.pop("model_type", None)
+        if model_type == "WavLM":
+            model = wavlm_model(**self._params)
+        else:
+            model = wav2vec2_model(**self._params)
         model.load_state_dict(self._get_state_dict(dl_kwargs))
         model.eval()
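
The `pop()` default keeps the existing wav2vec2/HuBERT bundles on the `wav2vec2_model` path, since their param dicts carry no `"model_type"` key. Note also that, per the docstring above, `get_model` forwards `dl_kwargs` to `torch.hub.load_state_dict_from_url`, so download behavior can be customized; a sketch (the cache directory is an illustrative choice):

```python
import torchaudio

# model_dir is a standard torch.hub.load_state_dict_from_url argument,
# forwarded via dl_kwargs; "./checkpoints" is just an example location.
model = torchaudio.pipelines.WAVLM_BASE_PLUS.get_model(
    dl_kwargs={"model_dir": "./checkpoints"}
)
```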
@@ -1177,3 +1181,140 @@ redistributed with the same license.
 Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
 """  # noqa: E501
WAVLM_BASE = Wav2Vec2Bundle(
"wavlm_base.pth",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 2, 2),
(512, 2, 2),
],
"extractor_conv_bias": False,
"encoder_embed_dim": 768,
"encoder_projection_dropout": 0.1,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 12,
"encoder_num_heads": 12,
"encoder_max_distance": 800,
"encoder_num_buckets": 320,
"encoder_attention_dropout": 0.1,
"encoder_ff_interm_features": 3072,
"encoder_ff_interm_dropout": 0.0,
"encoder_dropout": 0.1,
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.05,
"aux_num_out": None,
"model_type": "WavLM",
},
_sample_rate=16000,
)
WAVLM_BASE.__doc__ = """WavLM Base model ("base" architecture),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`, not fine-tuned.
Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
redistributed with the same license.
[`License <https://github.com/microsoft/unilm/blob/65f15af2a307ebb64cfb25adf54375b002e6fe8d/LICENSE>`__,
`Source <https://github.com/microsoft/unilm/tree/65f15af2a307ebb64cfb25adf54375b002e6fe8d/wavlm#pre-trained-models>`__]
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
WAVLM_BASE_PLUS = Wav2Vec2Bundle(
"wavlm_base_plus.pth",
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 2, 2),
(512, 2, 2),
],
"extractor_conv_bias": False,
"encoder_embed_dim": 768,
"encoder_projection_dropout": 0.1,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 12,
"encoder_num_heads": 12,
"encoder_max_distance": 800,
"encoder_num_buckets": 320,
"encoder_attention_dropout": 0.1,
"encoder_ff_interm_features": 3072,
"encoder_ff_interm_dropout": 0.0,
"encoder_dropout": 0.1,
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.05,
"aux_num_out": None,
"model_type": "WavLM",
},
_sample_rate=16000,
)
WAVLM_BASE_PLUS.__doc__ = """WavLM Base+ model ("base" architecture),
pre-trained on 60,000 hours of Libri-Light dataset :cite:`librilight`, 10,000 hours of GigaSpeech :cite:`GigaSpeech2021`,
and 24,000 hours of *VoxPopuli* :cite:`voxpopuli`, not fine-tuned.
Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
redistributed with the same license.
[`License <https://github.com/microsoft/unilm/blob/65f15af2a307ebb64cfb25adf54375b002e6fe8d/LICENSE>`__,
`Source <https://github.com/microsoft/unilm/tree/65f15af2a307ebb64cfb25adf54375b002e6fe8d/wavlm#pre-trained-models>`__]
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
WAVLM_LARGE = Wav2Vec2Bundle(
"wavlm_large.pth",
{
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 2, 2),
(512, 2, 2),
],
"extractor_conv_bias": False,
"encoder_embed_dim": 1024,
"encoder_projection_dropout": 0.1,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 24,
"encoder_num_heads": 16,
"encoder_max_distance": 800,
"encoder_num_buckets": 320,
"encoder_attention_dropout": 0.1,
"encoder_ff_interm_features": 4096,
"encoder_ff_interm_dropout": 0.0,
"encoder_dropout": 0.1,
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.05,
"aux_num_out": None,
"model_type": "WavLM",
},
_sample_rate=16000,
)
WAVLM_LARGE.__doc__ = """WavLM Large model ("large" architecture),
pre-trained on 60,000 hours of Libri-Light dataset :cite:`librilight`, 10,000 hours of GigaSpeech :cite:`GigaSpeech2021`,
and 24,000 hours of *VoxPopuli* :cite:`voxpopuli`, not fine-tuned.
Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
redistributed with the same license.
[`License <https://github.com/microsoft/unilm/blob/65f15af2a307ebb64cfb25adf54375b002e6fe8d/LICENSE>`__,
`Source <https://github.com/microsoft/unilm/tree/65f15af2a307ebb64cfb25adf54375b002e6fe8d/wavlm#pre-trained-models>`__]
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501