Commit 9b7b64e4 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add pre-trained pipelines for XLS-R models (#2978)

Summary:
The PR adds three `Wav2Vec2Bundle` pipeline objects for XLS-R models:
- WAV2VEC2_XLSR_300M
- WAV2VEC2_XLSR_1B
- WAV2VEC2_XLSR_2B

All three models use layer normalization in the feature extraction layers, hence `_normalize_waveform` is set to `True`.
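
For context, a minimal usage sketch (not part of this diff; the audio path and the resampling step are illustrative) of how the new bundles are consumed through the existing `Wav2Vec2Bundle` interface:

```python
import torch
import torchaudio
from torchaudio.pipelines import WAV2VEC2_XLSR_300M

bundle = WAV2VEC2_XLSR_300M
model = bundle.get_model()  # downloads and caches the pre-trained weights

# "sample.wav" is a placeholder path; resample to the bundle's expected rate.
waveform, sample_rate = torchaudio.load("sample.wav")
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

with torch.inference_mode():
    # extract_features returns one hidden-state tensor per transformer layer.
    features, _ = model.extract_features(waveform)
```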

Pull Request resolved: https://github.com/pytorch/audio/pull/2978

Reviewed By: hwangjeff

Differential Revision: D42501491

Pulled By: nateanl

fbshipit-source-id: 2429ec880cc14798034843381e458e1b4664dac3
parent 82ded7e7
@@ -84,6 +84,9 @@ Pretrained Models
WAV2VEC2_LARGE
WAV2VEC2_LARGE_LV60K
WAV2VEC2_XLSR53
WAV2VEC2_XLSR_300M
WAV2VEC2_XLSR_1B
WAV2VEC2_XLSR_2B
HUBERT_BASE
HUBERT_LARGE
HUBERT_XLARGE
...
@@ -496,3 +496,11 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop
journal={arXiv preprint arXiv:2111.09296},
year={2021}
}
@inproceedings{valk2021voxlingua107,
title={VoxLingua107: a dataset for spoken language recognition},
author={Valk, J{\"o}rgen and Alum{\"a}e, Tanel},
booktitle={2021 IEEE Spoken Language Technology Workshop (SLT)},
pages={652--658},
year={2021},
organization={IEEE}
}
import os
import pytest
import torchaudio
from torchaudio.pipelines import (
@@ -24,6 +26,8 @@ from torchaudio.pipelines import (
    WAV2VEC2_LARGE,
    WAV2VEC2_LARGE_LV60K,
    WAV2VEC2_XLSR53,
    WAV2VEC2_XLSR_1B,
    WAV2VEC2_XLSR_300M,
    WAVLM_BASE,
    WAVLM_BASE_PLUS,
    WAVLM_LARGE,
@@ -50,6 +54,19 @@ def test_pretraining_models(bundle):
    bundle.get_model()


@pytest.mark.skipif("CI" not in os.environ, reason="Run tests only in CI environment.")
@pytest.mark.parametrize(
    "bundle",
    [
        WAV2VEC2_XLSR_300M,
        WAV2VEC2_XLSR_1B,
    ],
)
def test_xlsr_pretraining_models(bundle):
    """Smoke test of downloading weights for pretraining models"""
    bundle.get_model()


@pytest.mark.parametrize(
    "bundle,lang,expected",
    [
...
@@ -35,6 +35,9 @@ from ._wav2vec2.impl import (
    WAV2VEC2_LARGE,
    WAV2VEC2_LARGE_LV60K,
    WAV2VEC2_XLSR53,
    WAV2VEC2_XLSR_1B,
    WAV2VEC2_XLSR_2B,
    WAV2VEC2_XLSR_300M,
    Wav2Vec2ASRBundle,
    Wav2Vec2Bundle,
    WAVLM_BASE,
@@ -60,6 +63,9 @@ __all__ = [
"WAV2VEC2_ASR_LARGE_LV60K_100H", "WAV2VEC2_ASR_LARGE_LV60K_100H",
"WAV2VEC2_ASR_LARGE_LV60K_960H", "WAV2VEC2_ASR_LARGE_LV60K_960H",
"WAV2VEC2_XLSR53", "WAV2VEC2_XLSR53",
"WAV2VEC2_XLSR_300M",
"WAV2VEC2_XLSR_1B",
"WAV2VEC2_XLSR_2B",
"VOXPOPULI_ASR_BASE_10K_EN", "VOXPOPULI_ASR_BASE_10K_EN",
"VOXPOPULI_ASR_BASE_10K_ES", "VOXPOPULI_ASR_BASE_10K_ES",
"VOXPOPULI_ASR_BASE_10K_DE", "VOXPOPULI_ASR_BASE_10K_DE",
...
@@ -110,6 +110,9 @@ class Wav2Vec2Bundle:
- WAV2VEC2_ASR_LARGE_LV60K_100H
- WAV2VEC2_ASR_LARGE_LV60K_960H
- WAV2VEC2_XLSR53
- WAV2VEC2_XLSR_300M
- WAV2VEC2_XLSR_1B
- WAV2VEC2_XLSR_2B
- HUBERT_LARGE
- HUBERT_XLARGE
- HUBERT_ASR_LARGE
@@ -1415,3 +1418,152 @@ redistributed with the same license.
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
""" # noqa: E501

WAV2VEC2_XLSR_300M = Wav2Vec2Bundle(
    "wav2vec2_xlsr_300m.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": True,
        "encoder_embed_dim": 1024,
        "encoder_projection_dropout": 0.0,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 24,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.0,
        "encoder_ff_interm_features": 4096,
        "encoder_ff_interm_dropout": 0.0,
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": True,
        "encoder_layer_drop": 0.0,
        "aux_num_out": None,
    },
    _model_type="Wav2Vec2",
    _sample_rate=16000,
    _normalize_waveform=True,
)
WAV2VEC2_XLSR_300M.__doc__ = """XLS-R model with 300 million parameters,
pre-trained on 436,000 hours of unlabeled audio from multiple datasets (
*Multilingual LibriSpeech* :cite:`Pratap_2020`,
*CommonVoice* :cite:`ardila2020common`,
*VoxLingua107* :cite:`valk2021voxlingua107`,
*BABEL* :cite:`Gales2014SpeechRA`, and
*VoxPopuli* :cite:`voxpopuli`) in 128 languages,
not fine-tuned.
Originally published by the authors of *XLS-R* :cite:`babu2021xls` under MIT License and
redistributed with the same license.
[`License <https://github.com/facebookresearch/fairseq/blob/30c912b73c0f88d41171879b2f03226a171004ef/LICENSE>`__,
`Source <https://github.com/facebookresearch/fairseq/tree/30c912b73c0f88d41171879b2f03226a171004ef/examples/wav2vec/xlsr#xls-r>`__]
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
""" # noqa: E501

WAV2VEC2_XLSR_1B = Wav2Vec2Bundle(
    "wav2vec2_xlsr_1b.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": True,
        "encoder_embed_dim": 1280,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 48,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.0,
        "encoder_ff_interm_features": 5120,
        "encoder_ff_interm_dropout": 0.0,
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": True,
        "encoder_layer_drop": 0.0,
        "aux_num_out": None,
    },
    _model_type="Wav2Vec2",
    _sample_rate=16000,
    _normalize_waveform=True,
)
WAV2VEC2_XLSR_1B.__doc__ = """XLS-R model with 1 billion parameters,
pre-trained on 436,000 hours of unlabeled audio from multiple datasets (
*Multilingual LibriSpeech* :cite:`Pratap_2020`,
*CommonVoice* :cite:`ardila2020common`,
*VoxLingua107* :cite:`valk2021voxlingua107`,
*BABEL* :cite:`Gales2014SpeechRA`, and
*VoxPopuli* :cite:`voxpopuli`) in 128 languages,
not fine-tuned.
Originally published by the authors of *XLS-R* :cite:`babu2021xls` under MIT License and
redistributed with the same license.
[`License <https://github.com/facebookresearch/fairseq/blob/30c912b73c0f88d41171879b2f03226a171004ef/LICENSE>`__,
`Source <https://github.com/facebookresearch/fairseq/tree/30c912b73c0f88d41171879b2f03226a171004ef/examples/wav2vec/xlsr#xls-r>`__]
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
""" # noqa: E501

WAV2VEC2_XLSR_2B = Wav2Vec2Bundle(
    "wav2vec2_xlsr_2b.pth",
    {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        "extractor_conv_bias": True,
        "encoder_embed_dim": 1920,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 48,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.0,
        "encoder_ff_interm_features": 7680,
        "encoder_ff_interm_dropout": 0.0,
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": True,
        "encoder_layer_drop": 0.0,
        "aux_num_out": None,
    },
    _model_type="Wav2Vec2",
    _sample_rate=16000,
    _normalize_waveform=True,
)
WAV2VEC2_XLSR_2B.__doc__ = """XLS-R model with 2 billion parameters,
pre-trained on 436,000 hours of unlabeled audio from multiple datasets (
*Multilingual LibriSpeech* :cite:`Pratap_2020`,
*CommonVoice* :cite:`ardila2020common`,
*VoxLingua107* :cite:`valk2021voxlingua107`,
*BABEL* :cite:`Gales2014SpeechRA`, and
*VoxPopuli* :cite:`voxpopuli`) in 128 languages,
not fine-tuned.
Originally published by the authors of *XLS-R* :cite:`babu2021xls` under MIT License and
redistributed with the same license.
[`License <https://github.com/facebookresearch/fairseq/blob/30c912b73c0f88d41171879b2f03226a171004ef/LICENSE>`__,
`Source <https://github.com/facebookresearch/fairseq/tree/30c912b73c0f88d41171879b2f03226a171004ef/examples/wav2vec/xlsr#xls-r>`__]
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
""" # noqa: E501