"...git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "f07b697b8dd1fee39c7ca7d0d95d40910c1e724d"
Commit a5664ca9 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add XLS-R models (#2959)

Summary:
XLSR (cross-lingual speech representation) is a family of self-supervised models for learning cross-lingual speech representations. It was first proposed in https://arxiv.org/pdf/2006.13979.pdf, where it was trained on 53 languages (the so-called XLSR-53). This PR adds support for the larger XLS-R models from https://arxiv.org/pdf/2111.09296.pdf, which have more parameters (300M, 1B, 2B) and are trained on 128 languages.
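
For illustration, a minimal usage sketch of the new factory functions (a sketch assuming a torchaudio build that includes this change; the snippet is not part of the diff):

import torch
import torchaudio

# Build a randomly initialized XLS-R 300M encoder; these factory functions do
# not download pretrained weights.
model = torchaudio.models.wav2vec2_xlsr_300m().eval()
waveform = torch.randn(1, 16000)  # dummy 1-second clip at 16 kHz
with torch.inference_mode():
    features, lengths = model.extract_features(waveform)
# `features` holds one (batch, frames, 1024) tensor per transformer layer.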

Pull Request resolved: https://github.com/pytorch/audio/pull/2959

Reviewed By: mthrok

Differential Revision: D42397643

Pulled By: nateanl

fbshipit-source-id: 23e8e51a7cde0a226db4f4028db7df8f02b986ce
parent 5dfe0b22
@@ -21,6 +21,9 @@
"wav2vec2_base",
"wav2vec2_large",
"wav2vec2_large_lv60k",
"wav2vec2_xlsr_300m",
"wav2vec2_xlsr_1b",
"wav2vec2_xlsr_2b",
"hubert_base", "hubert_base",
"hubert_large", "hubert_large",
"hubert_xlarge", "hubert_xlarge",
......
...@@ -490,3 +490,9 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop ...@@ -490,3 +490,9 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop
eprint = {1510.08484}, eprint = {1510.08484},
note = {arXiv:1510.08484v1} note = {arXiv:1510.08484v1}
} }
@article{babu2021xls,
title={XLS-R: Self-supervised cross-lingual speech representation learning at scale},
author={Babu, Arun and Wang, Changhan and Tjandra, Andros and Lakhotia, Kushal and Xu, Qiantong and Goyal, Naman and Singh, Kritika and von Platen, Patrick and Saraf, Yatharth and Pino, Juan and others},
journal={arXiv preprint arXiv:2111.09296},
year={2021}
}
{
"_name": "wav2vec2",
"extractor_mode": "layer_norm",
"encoder_layers": 48,
"encoder_embed_dim": 1280,
"encoder_ffn_embed_dim": 5120,
"encoder_attention_heads": 16,
"activation_fn": "gelu",
"dropout": 0.0,
"attention_dropout": 0.0,
"activation_dropout": 0.0,
"encoder_layerdrop": 0.0,
"dropout_input": 0.1,
"dropout_features": 0.1,
"final_dim": 1024,
"layer_norm_first": true,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
"conv_bias": true,
"logit_temp": 0.1,
"quantize_targets": true,
"quantize_input": false,
"same_quantizer": false,
"target_glu": false,
"feature_grad_mult": 1.0,
"latent_vars": 320,
"latent_groups": 2,
"latent_dim": 0,
"mask_length": 10,
"mask_prob": 0.65,
"mask_selection": "static",
"mask_other": 0.0,
"no_mask_overlap": false,
"mask_min_space": 1,
"mask_channel_length": 10,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_channel_other": 0.0,
"no_mask_channel_overlap": false,
"mask_channel_min_space": 1,
"num_negatives": 100,
"negatives_from_everywhere": false,
"cross_sample_negatives": 0,
"codebook_negatives": 0,
"conv_pos": 128,
"conv_pos_groups": 16,
"latent_temp": [
2.0,
0.1,
0.999995
]
}
{
"_name": "wav2vec2",
"extractor_mode": "layer_norm",
"encoder_layers": 48,
"encoder_embed_dim": 1920,
"encoder_ffn_embed_dim": 7680,
"encoder_attention_heads": 16,
"activation_fn": "gelu",
"dropout": 0.0,
"attention_dropout": 0.0,
"activation_dropout": 0.0,
"encoder_layerdrop": 0.0,
"dropout_input": 0.1,
"dropout_features": 0.1,
"final_dim": 1024,
"layer_norm_first": true,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
"conv_bias": true,
"logit_temp": 0.1,
"quantize_targets": true,
"quantize_input": false,
"same_quantizer": false,
"target_glu": false,
"feature_grad_mult": 1.0,
"latent_vars": 320,
"latent_groups": 2,
"latent_dim": 0,
"mask_length": 10,
"mask_prob": 0.65,
"mask_selection": "static",
"mask_other": 0.0,
"no_mask_overlap": false,
"mask_min_space": 1,
"mask_channel_length": 10,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_channel_other": 0.0,
"no_mask_channel_overlap": false,
"mask_channel_min_space": 1,
"num_negatives": 100,
"negatives_from_everywhere": false,
"cross_sample_negatives": 0,
"codebook_negatives": 0,
"conv_pos": 128,
"conv_pos_groups": 16,
"latent_temp": [
2.0,
0.1,
0.999995
]
}
{
"_name": "wav2vec2",
"extractor_mode": "layer_norm",
"encoder_layers": 24,
"encoder_embed_dim": 1024,
"encoder_ffn_embed_dim": 4096,
"encoder_attention_heads": 16,
"activation_fn": "gelu",
"dropout": 0.0,
"attention_dropout": 0.0,
"activation_dropout": 0.0,
"encoder_layerdrop": 0.0,
"dropout_input": 0.0,
"dropout_features": 0.0,
"final_dim": 768,
"layer_norm_first": true,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
"conv_bias": true,
"logit_temp": 0.1,
"quantize_targets": true,
"quantize_input": false,
"same_quantizer": false,
"target_glu": false,
"feature_grad_mult": 1.0,
"latent_vars": 320,
"latent_groups": 2,
"latent_dim": 0,
"mask_length": 10,
"mask_prob": 0.65,
"mask_selection": "static",
"mask_other": 0.0,
"no_mask_overlap": false,
"mask_min_space": 1,
"mask_channel_length": 10,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_channel_other": 0.0,
"no_mask_channel_overlap": false,
"mask_channel_min_space": 1,
"num_negatives": 100,
"negatives_from_everywhere": false,
"cross_sample_negatives": 0,
"codebook_negatives": 0,
"conv_pos": 128,
"conv_pos_groups": 16,
"latent_temp": [
2.0,
0.1,
0.999995
]
}
{
"activation_dropout": 0.0,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2Model"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"codevector_dim": 1024,
"contrastive_logits_temperature": 0.1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"diversity_loss_weight": 0.1,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.0,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 1280,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.075,
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 48,
"num_negatives": 100,
"pad_token_id": 0,
"proj_codevector_dim": 1024,
"torch_dtype": "float32",
"transformers_version": "4.12.0.dev0",
"use_weighted_layer_sum": false
}
{
"activation_dropout": 0.0,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2Model"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"codevector_dim": 1024,
"contrastive_logits_temperature": 0.1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"diversity_loss_weight": 0.1,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.0,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 1920,
"initializer_range": 0.02,
"intermediate_size": 7680,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.075,
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 48,
"num_negatives": 100,
"pad_token_id": 0,
"proj_codevector_dim": 1024,
"torch_dtype": "float32",
"transformers_version": "4.12.0.dev0"
}
{
"activation_dropout": 0.0,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2Model"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"codevector_dim": 768,
"contrastive_logits_temperature": 0.1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"diversity_loss_weight": 0.1,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.0,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.075,
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"num_negatives": 100,
"pad_token_id": 0,
"proj_codevector_dim": 768,
"torch_dtype": "float32",
"transformers_version": "4.12.0.dev0",
"use_weighted_layer_sum": false
}
@@ -3,6 +3,7 @@ from .case_utils import (
HttpServerMixin,
is_ffmpeg_available,
PytorchTestCase,
skipIfCudaSmallMemory,
skipIfNoCtcDecoder,
skipIfNoCuda,
skipIfNoExec,
@@ -38,6 +39,7 @@ __all__ = [
"is_ffmpeg_available",
"skipIfNoCtcDecoder",
"skipIfNoCuda",
"skipIfCudaSmallMemory",
"skipIfNoExec", "skipIfNoExec",
"skipIfNoModule", "skipIfNoModule",
"skipIfNoKaldi", "skipIfNoKaldi",
......
...@@ -208,6 +208,13 @@ skipIfNoCuda = _skipIf( ...@@ -208,6 +208,13 @@ skipIfNoCuda = _skipIf(
reason="CUDA is not available.", reason="CUDA is not available.",
key="NO_CUDA", key="NO_CUDA",
) )
# Skip test if CUDA memory is not enough
# TODO: detect the real CUDA memory size and allow call site to configure how much the test needs
skipIfCudaSmallMemory = _skipIf(
"CI" in os.environ and torch.cuda.is_available(), # temporary
reason="CUDA does not have enough memory.",
key="CUDA_SMALL_MEMORY",
)
skipIfNoSox = _skipIf(
not torchaudio._extension._SOX_INITIALIZED,
reason="Sox features are not available.",
...
@@ -9,9 +9,12 @@ from torchaudio.models.wav2vec2 import (
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
wav2vec2_xlsr_1b,
wav2vec2_xlsr_2b,
wav2vec2_xlsr_300m,
)
from torchaudio.models.wav2vec2.utils import import_fairseq_model
from torchaudio_unittest.common_utils import get_asset_path, skipIfCudaSmallMemory, skipIfNoModule, TorchaudioTestCase
def _load_config(*paths):
@@ -31,6 +34,9 @@ WAV2VEC2_XLSR_53_56K = _load_config("xlsr_53_56k")
HUBERT_BASE = _load_config("hubert_base_ls960")
HUBERT_LARGE_LL60K = _load_config("hubert_large_ll60k")
HUBERT_XLARGE_LL60K = _load_config("hubert_xtralarge_ll60k")
WAV2VEC2_XLSR_300M = _load_config("xlsr_300m")
WAV2VEC2_XLSR_1B = _load_config("xlsr_1b")
WAV2VEC2_XLSR_2B = _load_config("xlsr_2b")
# Finetuning models
WAV2VEC2_BASE_960H = _load_config("wav2vec_small_960h")
WAV2VEC2_LARGE_960H = _load_config("wav2vec_large_960h")
@@ -50,6 +56,14 @@ WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand(
],
name_func=_name_func,
)
XLSR_PRETRAINING_CONFIGS = parameterized.expand(
[
(WAV2VEC2_XLSR_300M, wav2vec2_xlsr_300m),
(WAV2VEC2_XLSR_1B, wav2vec2_xlsr_1b),
(WAV2VEC2_XLSR_2B, wav2vec2_xlsr_2b),
],
name_func=_name_func,
)
HUBERT_PRETRAINING_CONFIGS = parameterized.expand(
[
(HUBERT_BASE, hubert_base),
@@ -136,6 +150,23 @@ class TestFairseqIntegration(TorchaudioTestCase):
for i, (ref, _) in enumerate(refs["layer_results"]):
self.assertEqual(hyp[i], ref.transpose(0, 1))
@skipIfCudaSmallMemory
@XLSR_PRETRAINING_CONFIGS
def test_import_xlsr_pretraining_model(self, config, factory_func):
"""XLS-R pretraining models from fairseq can be imported and yields the same results"""
batch_size, num_frames = 3, 1024
original = self._get_model(config).eval()
imported = import_fairseq_model(original).eval()
x = torch.randn(batch_size, num_frames)
hyp, _ = imported.extract_features(x)
refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1)
for i, (ref, _) in enumerate(refs["layer_results"]):
# There is one element whose difference is over 1e-5 in wav2vec2_xlsr_1b and wav2vec2_xlsr_2b.
atol = 1.0e-05 if factory_func is wav2vec2_xlsr_300m else 1e-4
self.assertEqual(hyp[i], ref.transpose(0, 1), atol=atol, rtol=1.3e-6)
@HUBERT_PRETRAINING_CONFIGS
def test_import_hubert_pretraining_model(self, config, factory_func):
"""HuBERT pretraining models from fairseq can be imported and yield the same results"""
@@ -157,8 +188,7 @@ class TestFairseqIntegration(TorchaudioTestCase):
ref, _ = original.extract_features(x, padding_mask=mask, output_layer=1)
self.assertEqual(hyp[0], ref)
def _test_recreate_pretraining_model(self, config, factory_func):
"""Imported pretraining models can be recreated via a factory function without fairseq."""
batch_size, num_frames = 3, 1024
@@ -188,6 +218,15 @@ class TestFairseqIntegration(TorchaudioTestCase):
self.assertEqual(ref, hyp)
self.assertEqual(ref_lengths, hyp_lengths)
@ALL_PRETRAINING_CONFIGS
def test_wav2vec2_recreate_pretraining_model(self, config, factory_func):
self._test_recreate_pretraining_model(config, factory_func)
@skipIfCudaSmallMemory
@XLSR_PRETRAINING_CONFIGS
def test_xlsr_recreate_pretraining_model(self, config, factory_func):
self._test_recreate_pretraining_model(config, factory_func)
@FINETUNING_CONFIGS
def test_import_finetuning_model(self, config, _):
"""Finetuned wav2vec2 models from fairseq can be imported and yield the same results"""
...
@@ -2,9 +2,24 @@ import json
import torch
from parameterized import parameterized
from torchaudio.models.wav2vec2 import (
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
wav2vec2_xlsr_1b,
wav2vec2_xlsr_2b,
wav2vec2_xlsr_300m,
wavlm_base,
wavlm_large,
)
from torchaudio.models.wav2vec2.utils import import_huggingface_model
from torchaudio_unittest.common_utils import (
get_asset_path,
skipIfCudaSmallMemory,
skipIfNoModule,
TorchaudioTestCase,
zip_equal,
)
def _load_config(*paths):
@@ -24,6 +39,9 @@ HF_LARGE_XLSR_53 = _load_config("wav2vec2-large-xlsr-53")
HF_BASE_10K_VOXPOPULI = _load_config("wav2vec2-base-10k-voxpopuli")
HF_BASE_WAVLM = _load_config("wavlm-base")
HF_LARGE_WAVLM = _load_config("wavlm-large")
HF_XLSR_300M = _load_config("wav2vec2-xls-r-300m")
HF_XLSR_1B = _load_config("wav2vec2-xls-r-1b")
HF_XLSR_2B = _load_config("wav2vec2-xls-r-2b")
# Finetuned
HF_BASE_960H = _load_config("wav2vec2-base-960h")
HF_LARGE_960H = _load_config("wav2vec2-large-960h")
@@ -42,6 +60,14 @@ PRETRAIN_CONFIGS = parameterized.expand(
],
name_func=_name_func,
)
XLSR_PRETRAIN_CONFIGS = parameterized.expand(
[
(HF_XLSR_300M, wav2vec2_xlsr_300m),
(HF_XLSR_1B, wav2vec2_xlsr_1b),
(HF_XLSR_2B, wav2vec2_xlsr_2b),
],
name_func=_name_func,
)
FINETUNE_CONFIGS = parameterized.expand(
[
(HF_BASE_960H, wav2vec2_base),
@@ -156,6 +182,14 @@ class TestHFIntegration(TorchaudioTestCase):
imported = import_huggingface_model(original).eval()
self._test_import_pretrain(original, imported, config)
@skipIfCudaSmallMemory
@XLSR_PRETRAIN_CONFIGS
def test_import_xlsr_pretrain(self, config, _):
"""XLS-R models from HF transformers can be imported and yields the same results"""
original = self._get_model(config).eval()
imported = import_huggingface_model(original).eval()
self._test_import_pretrain(original, imported, config)
@FINETUNE_CONFIGS
def test_import_finetune(self, config, _):
"""wav2vec2 models from HF transformers can be imported and yield the same results"""
...
@@ -20,6 +20,9 @@ from .wav2vec2 import (
wav2vec2_large,
wav2vec2_large_lv60k,
wav2vec2_model,
wav2vec2_xlsr_1b,
wav2vec2_xlsr_2b,
wav2vec2_xlsr_300m,
Wav2Vec2Model,
wavlm_base,
wavlm_large,
@@ -50,6 +53,9 @@ __all__ = [
"hubert_pretrain_base",
"hubert_pretrain_large",
"hubert_pretrain_xlarge",
"wav2vec2_xlsr_300m",
"wav2vec2_xlsr_1b",
"wav2vec2_xlsr_2b",
"Tacotron2", "Tacotron2",
"Conformer", "Conformer",
"Emformer", "Emformer",
......
...@@ -12,6 +12,9 @@ from .model import ( ...@@ -12,6 +12,9 @@ from .model import (
wav2vec2_large, wav2vec2_large,
wav2vec2_large_lv60k, wav2vec2_large_lv60k,
wav2vec2_model, wav2vec2_model,
wav2vec2_xlsr_1b,
wav2vec2_xlsr_2b,
wav2vec2_xlsr_300m,
Wav2Vec2Model,
wavlm_base,
wavlm_large,
@@ -36,4 +39,7 @@ __all__ = [
"hubert_pretrain_large",
"hubert_pretrain_xlarge",
"utils",
"wav2vec2_xlsr_300m",
"wav2vec2_xlsr_1b",
"wav2vec2_xlsr_2b",
]
@@ -1427,3 +1427,153 @@ def wavlm_large(
encoder_layer_drop=encoder_layer_drop,
aux_num_out=aux_num_out,
)
def wav2vec2_xlsr_300m(
encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Builds XLS-R model :cite:`babu2021xls` with 300 millions of parameters. The architecture is compatible
with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
:class:`~torchaudio.models.Wav2Vec2Model`.
Args:
encoder_projection_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_attention_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
aux_num_out (int, optional):
See :py:func:`~torchaudio.models.wav2vec2_model`.
Returns:
Wav2Vec2Model:
The resulting model.
"""
return wav2vec2_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=True,
encoder_embed_dim=1024,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=True,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=aux_num_out,
)
def wav2vec2_xlsr_1b(
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Builds XLS-R model :cite:`babu2021xls` with 1 billion of parameters. The architecture is compatible
with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
:class:`~torchaudio.models.Wav2Vec2Model`.
Args:
encoder_projection_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_attention_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
aux_num_out (int, optional):
See :py:func:`~torchaudio.models.wav2vec2_model`.
Returns:
Wav2Vec2Model:
The resulting model.
"""
return wav2vec2_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=True,
encoder_embed_dim=1280,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=48,
encoder_num_heads=16,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=5120,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=True,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=aux_num_out,
)
def wav2vec2_xlsr_2b(
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Builds XLS-R model :cite:`babu2021xls` with 2 billions of parameters. The architecture is compatible
with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
:class:`~torchaudio.models.Wav2Vec2Model`.
Args:
encoder_projection_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_attention_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_dropout (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`~torchaudio.models.wav2vec2_model`.
aux_num_out (int, optional):
See :py:func:`~torchaudio.models.wav2vec2_model`.
Returns:
Wav2Vec2Model:
The resulting model.
"""
return wav2vec2_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=True,
encoder_embed_dim=1920,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=48,
encoder_num_heads=16,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=7680,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=True,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=aux_num_out,
)
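
For reference, a minimal sketch (not part of this diff) of how a pretrained XLS-R checkpoint could be brought into these new builders; the Hugging Face Hub ID below is an assumption used purely for illustration, and HF transformers must be installed:

import torch
from transformers import Wav2Vec2Model as HFWav2Vec2Model
from torchaudio.models import wav2vec2_xlsr_300m
from torchaudio.models.wav2vec2.utils import import_huggingface_model

# Rough sanity check: the 300M variant built with default arguments should have
# on the order of 3e8 trainable parameters.
model = wav2vec2_xlsr_300m()
print(sum(p.numel() for p in model.parameters()))  # roughly 300M

# Import a pretrained checkpoint from HF transformers into the torchaudio model;
# "facebook/wav2vec2-xls-r-300m" is assumed here for illustration only.
original = HFWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-300m")
imported = import_huggingface_model(original).eval()
waveform = torch.randn(1, 16000)
with torch.inference_mode():
    features, _ = imported.extract_features(waveform)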