Commit 41d007b4 authored by Grigory Sizov's avatar Grigory Sizov Committed by Facebook GitHub Bot
Browse files

Follow up on WavLM bundles (#2895)

Summary:
Addressed mthrok's comments in https://github.com/pytorch/audio/pull/2833:
- Moved model type from `_params` directly into the bundle definition. For now I defined model type as "WavLM" for WavLM bundles and "Wav2Vec2" for everything else. We can also distinguish between different Wav2Vec2 flavours — HuBERT, VoxPopuli, etc. — but at the moment this won't imply any functional differences, so I didn't do it
- Expanded the title underline to match the title length

Pull Request resolved: https://github.com/pytorch/audio/pull/2895

Reviewed By: nateanl, mthrok

Differential Revision: D41799875

Pulled By: sgrigory

fbshipit-source-id: 0730d4f91ed60e900643bb74d6cccdd7aa5d7b39
parent 88927e84
......@@ -59,6 +59,9 @@ Factory Functions
hdemucs_low
hdemucs_medium
hdemucs_high
wavlm_model
wavlm_base
wavlm_large
Utility Functions
-----------------
......
......@@ -56,7 +56,7 @@ Pretrained Models
wav2vec 2.0 / HuBERT / WavLM - SSL
--------------------------
----------------------------------
Interface
^^^^^^^^^
......
......@@ -12,10 +12,10 @@ class Wav2Vec2Model(Module):
"""Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`.
Note:
To build the model, please use one of the factory functions.
:py:func:`wav2vec2_model`, :py:func:`wav2vec2_base`, :py:func:`wav2vec2_large`,
:py:func:`wav2vec2_large_lv60k`, :py:func:`hubert_base`, :py:func:`hubert_large`,
and :py:func:`hubert_xlarge`.
To build the model, please use one of the factory functions: :py:func:`wav2vec2_model`,
:py:func:`wav2vec2_base`, :py:func:`wav2vec2_large`, :py:func:`wav2vec2_large_lv60k`,
:py:func:`hubert_base`, :py:func:`hubert_large`, :py:func:`hubert_xlarge`,
:py:func:`wavlm_model`, :py:func:`wavlm_base`, and :py:func:`wavlm_large`.
See Also:
* :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning)
......
......@@ -73,6 +73,7 @@ class Wav2Vec2Bundle:
_params: Dict[str, Any]
_sample_rate: float
_normalize_waveform: bool
_model_type: str
@property
def sample_rate(self) -> float:
......@@ -115,8 +116,7 @@ class Wav2Vec2Bundle:
- HUBERT_ASR_XLARGE
- WAVLM_LARGE
"""
model_type = self._params.pop("model_type", None)
if model_type == "WavLM":
if self._model_type == "WavLM":
model = wavlm_model(**self._params)
else:
model = wav2vec2_model(**self._params)
......@@ -245,6 +245,7 @@ WAV2VEC2_BASE = Wav2Vec2Bundle(
},
_sample_rate=16000,
_normalize_waveform=False,
_model_type="Wav2Vec2",
)
WAV2VEC2_BASE.__doc__ = """Wav2vec 2.0 model ("base" architecture),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
......@@ -289,6 +290,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=False,
_model_type="Wav2Vec2",
)
WAV2VEC2_ASR_BASE_10M.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
......@@ -335,6 +337,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=False,
_model_type="Wav2Vec2",
)
WAV2VEC2_ASR_BASE_100H.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
......@@ -381,6 +384,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=False,
_model_type="Wav2Vec2",
)
WAV2VEC2_ASR_BASE_960H.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
......@@ -425,6 +429,7 @@ WAV2VEC2_LARGE = Wav2Vec2Bundle(
},
_sample_rate=16000,
_normalize_waveform=False,
_model_type="Wav2Vec2",
)
WAV2VEC2_LARGE.__doc__ = """Wav2vec 2.0 model ("large" architecture),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
......@@ -469,6 +474,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=False,
_model_type="Wav2Vec2",
)
WAV2VEC2_ASR_LARGE_10M.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
......@@ -515,6 +521,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=False,
_model_type="Wav2Vec2",
)
WAV2VEC2_ASR_LARGE_100H.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
......@@ -561,6 +568,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=False,
_model_type="Wav2Vec2",
)
WAV2VEC2_ASR_LARGE_960H.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
......@@ -605,6 +613,7 @@ WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
},
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
WAV2VEC2_LARGE_LV60K.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
......@@ -649,6 +658,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
WAV2VEC2_ASR_LARGE_LV60K_10M.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
......@@ -693,6 +703,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
WAV2VEC2_ASR_LARGE_LV60K_100H.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
......@@ -738,6 +749,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
WAV2VEC2_ASR_LARGE_LV60K_960H.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* :cite:`librilight` dataset, and
......@@ -782,6 +794,7 @@ WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
},
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
WAV2VEC2_XLSR53.__doc__ = """Wav2vec 2.0 model ("base" architecture),
pre-trained on 56,000 hours of unlabeled audio from multiple datasets (
......@@ -829,6 +842,7 @@ HUBERT_BASE = Wav2Vec2Bundle(
},
_sample_rate=16000,
_normalize_waveform=False,
_model_type="Wav2Vec2",
)
HUBERT_BASE.__doc__ = """HuBERT model ("base" architecture),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
......@@ -872,6 +886,7 @@ HUBERT_LARGE = Wav2Vec2Bundle(
},
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
HUBERT_LARGE.__doc__ = """HuBERT model ("large" architecture),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
......@@ -915,6 +930,7 @@ HUBERT_XLARGE = Wav2Vec2Bundle(
},
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
HUBERT_XLARGE.__doc__ = """HuBERT model ("extra large" architecture),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
......@@ -959,6 +975,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
HUBERT_ASR_LARGE.__doc__ = """HuBERT model ("large" architecture),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
......@@ -1004,6 +1021,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
_labels=utils._get_en_labels(),
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
HUBERT_ASR_XLARGE.__doc__ = """HuBERT model ("extra large" architecture),
pre-trained on 60,000 hours of unlabeled audio from
......@@ -1053,6 +1071,7 @@ VOXPOPULI_ASR_BASE_10K_DE = Wav2Vec2ASRBundle(
_sample_rate=16000,
_normalize_waveform=False,
_remove_aux_axis=(1, 2, 3, 35),
_model_type="Wav2Vec2",
)
VOXPOPULI_ASR_BASE_10K_DE.__doc__ = """wav2vec 2.0 model ("base" architecture),
pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
......@@ -1100,6 +1119,7 @@ VOXPOPULI_ASR_BASE_10K_EN = Wav2Vec2ASRBundle(
_sample_rate=16000,
_normalize_waveform=False,
_remove_aux_axis=(1, 2, 3, 31),
_model_type="Wav2Vec2",
)
VOXPOPULI_ASR_BASE_10K_EN.__doc__ = """wav2vec 2.0 model ("base" architecture),
pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
......@@ -1147,6 +1167,7 @@ VOXPOPULI_ASR_BASE_10K_ES = Wav2Vec2ASRBundle(
_sample_rate=16000,
_normalize_waveform=False,
_remove_aux_axis=(1, 2, 3, 35),
_model_type="Wav2Vec2",
)
VOXPOPULI_ASR_BASE_10K_ES.__doc__ = """wav2vec 2.0 model ("base" architecture),
pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
......@@ -1192,6 +1213,7 @@ VOXPOPULI_ASR_BASE_10K_FR = Wav2Vec2ASRBundle(
_labels=utils._get_fr_labels(),
_sample_rate=16000,
_normalize_waveform=False,
_model_type="Wav2Vec2",
)
VOXPOPULI_ASR_BASE_10K_FR.__doc__ = """wav2vec 2.0 model ("base" architecture),
pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
......@@ -1239,6 +1261,7 @@ VOXPOPULI_ASR_BASE_10K_IT = Wav2Vec2ASRBundle(
_sample_rate=16000,
_normalize_waveform=False,
_remove_aux_axis=(1, 2, 3),
_model_type="Wav2Vec2",
)
VOXPOPULI_ASR_BASE_10K_IT.__doc__ = """wav2vec 2.0 model ("base" architecture),
pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
......@@ -1283,8 +1306,8 @@ WAVLM_BASE = Wav2Vec2Bundle(
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.05,
"aux_num_out": None,
"model_type": "WavLM",
},
_model_type="WavLM",
_sample_rate=16000,
_normalize_waveform=False,
)
......@@ -1294,7 +1317,7 @@ pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`71
Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
redistributed with the same license.
[`License <https://github.com/microsoft/unilm/blob/65f15af2a307ebb64cfb25adf54375b002e6fe8d/LICENSE>`__,
`Source https://github.com/microsoft/unilm/tree/65f15af2a307ebb64cfb25adf54375b002e6fe8d/wavlm#pre-trained-models>`__]
`Source <https://github.com/microsoft/unilm/tree/65f15af2a307ebb64cfb25adf54375b002e6fe8d/wavlm#pre-trained-models>`__]
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
......@@ -1329,8 +1352,8 @@ WAVLM_BASE_PLUS = Wav2Vec2Bundle(
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.05,
"aux_num_out": None,
"model_type": "WavLM",
},
_model_type="WavLM",
_sample_rate=16000,
_normalize_waveform=False,
)
......@@ -1341,7 +1364,7 @@ and 24,000 hours of *VoxPopuli* :cite:`voxpopuli`, not fine-tuned.
Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
redistributed with the same license.
[`License <https://github.com/microsoft/unilm/blob/65f15af2a307ebb64cfb25adf54375b002e6fe8d/LICENSE>`__,
`Source https://github.com/microsoft/unilm/tree/65f15af2a307ebb64cfb25adf54375b002e6fe8d/wavlm#pre-trained-models>`__]
`Source <https://github.com/microsoft/unilm/tree/65f15af2a307ebb64cfb25adf54375b002e6fe8d/wavlm#pre-trained-models>`__]
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
......@@ -1376,8 +1399,8 @@ WAVLM_LARGE = Wav2Vec2Bundle(
"encoder_layer_norm_first": False,
"encoder_layer_drop": 0.05,
"aux_num_out": None,
"model_type": "WavLM",
},
_model_type="WavLM",
_sample_rate=16000,
_normalize_waveform=True,
)
......@@ -1388,7 +1411,7 @@ and 24,000 hours of *VoxPopuli* :cite:`voxpopuli`, not fine-tuned.
Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
redistributed with the same license.
[`License <https://github.com/microsoft/unilm/blob/65f15af2a307ebb64cfb25adf54375b002e6fe8d/LICENSE>`__,
`Source https://github.com/microsoft/unilm/tree/65f15af2a307ebb64cfb25adf54375b002e6fe8d/wavlm#pre-trained-models>`__]
`Source <https://github.com/microsoft/unilm/tree/65f15af2a307ebb64cfb25adf54375b002e6fe8d/wavlm#pre-trained-models>`__]
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment