Unverified commit 274ada80, authored by moto, committed by GitHub
Browse files

Merge factory functions of pre-training model and fine-tuned model (#1830)

This commit merges wav2vec2/hubert factory functions for pre-training and fine-tuning. In #1829, we added parameters to customize the models that are not part of architecture, and `aux_num_out` falls into this category, so it is no longer necessary to have separate functions. This concludes the wav2vec2/HuBERT API update in release 0.10.

The summary of BC-breaking changes on wav2vec2 APIs between 0.9 and 0.10 (when this commit is incorporated)
1. `Wav2Vec2Model.extract_features`
In 0.9, it was returning the output from `FeatureExtractor` module. In 0.10, it returns the list of outputs from the intermediate layers of `TransformerEncoder` block.
2. `wav2vec2_base(num_out: int)` -> `wav2vec2_base(<dropout_params:float>, aux_num_out: Optional[int]=None)`
    - `num_out` was renamed to `aux_num_out` and made optional. If it is omitted, the resulting model does not have the linear layer for fine-tuning.
    - Added dropout parameters.
parent 60aeb78a
......@@ -96,21 +96,6 @@ wav2vec2_large_lv60k
.. autofunction:: wav2vec2_large_lv60k
wav2vec2_ft_base
^^^^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_ft_base
wav2vec2_ft_large
^^^^^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_ft_large
wav2vec2_ft_large_lv60k
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_ft_large_lv60k
hubert_base
^^^^^^^^^^^
......@@ -126,16 +111,6 @@ hubert_xlarge
.. autofunction:: hubert_xlarge
hubert_ft_large
^^^^^^^^^^^^^^^^
.. autofunction:: hubert_ft_large
hubert_ft_xlarge
^^^^^^^^^^^^^^^^^
.. autofunction:: hubert_ft_xlarge
Pre-trained Models
------------------
......
......@@ -2,17 +2,12 @@ import json
import torch
from torchaudio.models.wav2vec2 import (
wav2vec2_ft_base,
wav2vec2_ft_large,
wav2vec2_ft_large_lv60k,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
hubert_base,
hubert_large,
hubert_xlarge,
hubert_ft_large,
hubert_ft_xlarge,
)
from torchaudio.models.wav2vec2.utils import (
import_fairseq_model,
......@@ -74,12 +69,12 @@ ALL_PRETRAINING_CONFIGS = parameterized.expand([
(HUBERT_XLARGE_LL60K, hubert_xlarge),
], name_func=_name_func)
FINETUNING_CONFIGS = parameterized.expand([
(WAV2VEC2_BASE_960H, wav2vec2_ft_base),
(WAV2VEC2_LARGE_960H, wav2vec2_ft_large),
(WAV2VEC2_LARGE_LV60K_960H, wav2vec2_ft_large_lv60k),
(WAV2VEC2_LARGE_LV60K_SELF_960H, wav2vec2_ft_large_lv60k),
(HUBERT_LARGE, hubert_ft_large),
(HUBERT_XLARGE, hubert_ft_xlarge),
(WAV2VEC2_BASE_960H, wav2vec2_base),
(WAV2VEC2_LARGE_960H, wav2vec2_large),
(WAV2VEC2_LARGE_LV60K_960H, wav2vec2_large_lv60k),
(WAV2VEC2_LARGE_LV60K_SELF_960H, wav2vec2_large_lv60k),
(HUBERT_LARGE, hubert_large),
(HUBERT_XLARGE, hubert_xlarge),
], name_func=_name_func)
......
......@@ -2,9 +2,6 @@ import json
import torch
from torchaudio.models.wav2vec2 import (
wav2vec2_ft_base,
wav2vec2_ft_large,
wav2vec2_ft_large_lv60k,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
......@@ -50,11 +47,11 @@ PRETRAIN_CONFIGS = parameterized.expand([
(HF_BASE_10K_VOXPOPULI, wav2vec2_base),
], name_func=_name_func)
FINETUNE_CONFIGS = parameterized.expand([
(HF_BASE_960H, wav2vec2_ft_base),
(HF_LARGE_960H, wav2vec2_ft_large),
(HF_LARGE_LV60_960H, wav2vec2_ft_large_lv60k),
(HF_LARGE_LV60_SELF_960H, wav2vec2_ft_large_lv60k),
(HF_LARGE_XLSR_DE, wav2vec2_ft_large_lv60k),
(HF_BASE_960H, wav2vec2_base),
(HF_LARGE_960H, wav2vec2_large),
(HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
(HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k),
(HF_LARGE_XLSR_DE, wav2vec2_large_lv60k),
], name_func=_name_func)
......
......@@ -5,17 +5,12 @@ import torch.nn.functional as F
from typing import Tuple
from torchaudio.models.wav2vec2 import (
wav2vec2_ft_base,
wav2vec2_ft_large,
wav2vec2_ft_large_lv60k,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
hubert_base,
hubert_large,
hubert_xlarge,
hubert_ft_large,
hubert_ft_xlarge,
)
from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
......@@ -36,7 +31,7 @@ def _name_func(testcase_func, i, param):
return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"
pretrain_factory_funcs = parameterized.expand([
factory_funcs = parameterized.expand([
(wav2vec2_base, ),
(wav2vec2_large, ),
(wav2vec2_large_lv60k, ),
......@@ -46,15 +41,6 @@ pretrain_factory_funcs = parameterized.expand([
], name_func=_name_func)
finetune_factory_funcs = parameterized.expand([
(wav2vec2_ft_base, ),
(wav2vec2_ft_large, ),
(wav2vec2_ft_large_lv60k, ),
(hubert_ft_large, ),
(hubert_ft_xlarge, ),
], name_func=_name_func)
class TestWav2Vec2Model(TorchaudioTestCase):
def _smoke_test(self, model, device, dtype):
model = model.to(device=device, dtype=dtype)
......@@ -74,7 +60,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
def test_cpu_smoke_test(self, dtype):
model = wav2vec2_base()
self._smoke_test(model, torch.device('cpu'), dtype)
model = wav2vec2_ft_base(aux_num_out=32)
model = wav2vec2_base(aux_num_out=32)
self._smoke_test(model, torch.device('cpu'), dtype)
@parameterized.expand([(torch.float32, ), (torch.float64, )])
......@@ -82,7 +68,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
def test_cuda_smoke_test(self, dtype):
model = wav2vec2_base()
self._smoke_test(model, torch.device('cuda'), dtype)
model = wav2vec2_ft_base(aux_num_out=32)
model = wav2vec2_base(aux_num_out=32)
self._smoke_test(model, torch.device('cuda'), dtype)
def _feature_extractor_test(self, model):
......@@ -112,13 +98,8 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self.assertEqual(all_features[i], features[i])
assert lengths_.shape == torch.Size([batch_size])
@pretrain_factory_funcs
def test_pretrain_feature_extractor_test(self, factory_func):
"""`extract_features` method does not fail"""
self._feature_extractor_test(factory_func())
@finetune_factory_funcs
def test_finetune_feature_extractor_test(self, factory_func):
@factory_funcs
def test_extract_feature(self, factory_func):
"""`extract_features` method does not fail"""
self._feature_extractor_test(factory_func(aux_num_out=32))
......@@ -142,17 +123,17 @@ class TestWav2Vec2Model(TorchaudioTestCase):
# We allow max atol=0.005 -> 0.5%
self.assertEqual(single_prob, batch_prob, atol=0.005, rtol=0)
@pretrain_factory_funcs
@factory_funcs
def test_pretrain_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close
"""
self._test_batch_consistency(factory_func())
@pretrain_factory_funcs
@factory_funcs
def test_finetune_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close
"""
self._test_batch_consistency(factory_func())
self._test_batch_consistency(factory_func(aux_num_out=32))
def _test_zero_length(self, model):
model.eval()
......@@ -165,12 +146,12 @@ class TestWav2Vec2Model(TorchaudioTestCase):
_, output_lengths = model.extract_features(waveforms, input_lengths)
self.assertEqual(torch.zeros_like(output_lengths), output_lengths)
@pretrain_factory_funcs
@factory_funcs
def test_pretrain_zero_length(self, factory_func):
"""Passing zero length should not fail"""
self._test_zero_length(factory_func())
@finetune_factory_funcs
@factory_funcs
def test_finetune_zero_length(self, factory_func):
"""Passing zero length should not fail"""
self._test_zero_length(factory_func(aux_num_out=32))
......@@ -193,7 +174,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
@pretrain_factory_funcs
@factory_funcs
def test_pretrain_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable"""
if factory_func is hubert_xlarge and os.name == 'nt' and os.environ.get('CI') == 'true':
......@@ -202,12 +183,12 @@ class TestWav2Vec2Model(TorchaudioTestCase):
'See https://github.com/pytorch/pytorch/issues/65776')
self._test_torchscript(factory_func())
@finetune_factory_funcs
@factory_funcs
def test_finetune_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable"""
if factory_func is hubert_ft_xlarge and os.name == 'nt' and os.environ.get('CI') == 'true':
if factory_func is hubert_xlarge and os.name == 'nt' and os.environ.get('CI') == 'true':
self.skipTest(
'hubert_ft_xlarge is known to fail on Windows CI. '
'hubert_xlarge is known to fail on Windows CI. '
'See https://github.com/pytorch/pytorch/issues/65776')
self._test_torchscript(factory_func(aux_num_out=32))
......@@ -229,15 +210,9 @@ class TestWav2Vec2Model(TorchaudioTestCase):
_, _ = quantized(waveforms, lengths)
@pretrain_factory_funcs
@skipIfNoQengine
def test_pretrain_quantize(self, factory_func):
"""Wav2Vec2Model should support basic quantization"""
self._test_quantize_smoke_test(factory_func())
@finetune_factory_funcs
@factory_funcs
@skipIfNoQengine
def test_finetune_quantize(self, factory_func):
def test_quantize(self, factory_func):
"""Wav2Vec2Model should support basic quantization"""
self._test_quantize_smoke_test(factory_func(aux_num_out=32))
......@@ -268,14 +243,8 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
@pretrain_factory_funcs
@skipIfNoQengine
def test_pretrain_quantize_torchscript(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable"""
self._test_quantize_torchscript(factory_func())
@finetune_factory_funcs
@factory_funcs
@skipIfNoQengine
def test_finetune_quantize_torchscript(self, factory_func):
def test_quantize_torchscript(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable"""
self._test_quantize_torchscript(factory_func(aux_num_out=32))
......@@ -6,17 +6,12 @@ from .tacotron2 import Tacotron2, tacotron2
from .wav2vec2 import (
Wav2Vec2Model,
wav2vec2_model,
wav2vec2_ft_base,
wav2vec2_ft_large,
wav2vec2_ft_large_lv60k,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
hubert_base,
hubert_large,
hubert_xlarge,
hubert_ft_large,
hubert_ft_xlarge,
)
from .wav2vec2.pretrained import (
Wav2Vec2PretrainedModelBundle,
......@@ -48,17 +43,12 @@ __all__ = [
'DeepSpeech',
'Wav2Vec2Model',
'wav2vec2_model',
'wav2vec2_ft_base',
'wav2vec2_ft_large',
'wav2vec2_ft_large_lv60k',
'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
'hubert_base',
'hubert_large',
'hubert_xlarge',
'hubert_ft_large',
'hubert_ft_xlarge',
'Wav2Vec2PretrainedModelBundle',
'WAV2VEC2_BASE',
'WAV2VEC2_LARGE',
......
from .model import (
Wav2Vec2Model,
wav2vec2_model,
wav2vec2_ft_base,
wav2vec2_ft_large,
wav2vec2_ft_large_lv60k,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
hubert_base,
hubert_large,
hubert_xlarge,
hubert_ft_large,
hubert_ft_xlarge,
)
from . import utils
__all__ = [
'Wav2Vec2Model',
'wav2vec2_model',
'wav2vec2_ft_base',
'wav2vec2_ft_large',
'wav2vec2_ft_large_lv60k',
'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
'hubert_base',
'hubert_large',
'hubert_xlarge',
'hubert_ft_large',
'hubert_ft_xlarge',
'utils',
]
......@@ -271,11 +271,9 @@ def wav2vec2_base(
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build wav2vec2 model with "base" configuration
This is one of the model architecture used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for pretraining.
"""Build Wav2Vec2Model with "base" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]
Args:
encoder_projection_dropout (float):
......@@ -288,59 +286,12 @@ def wav2vec2_base(
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model:
"""
return wav2vec2_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=768,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=12,
encoder_num_heads=12,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=3072,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=False,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=None,
)
def wav2vec2_ft_base(
aux_num_out: int,
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1,
) -> Wav2Vec2Model:
"""Build "base" wav2vec2 with an extra linear module
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
Args:
aux_num_out (int):
The output dimension of the extra linear module.
encoder_projection_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_attention_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
aux_num_out (int or None, optional):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model:
The resulting model.
"""
return wav2vec2_model(
extractor_mode="group_norm",
......@@ -368,11 +319,9 @@ def wav2vec2_large(
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build wav2vec2 model with "large" configuration
This is one of the model architecture used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for pretraining.
"""Build Wav2Vec2Model with "large" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]
Args:
encoder_projection_dropout (float):
......@@ -385,59 +334,12 @@ def wav2vec2_large(
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model:
"""
return wav2vec2_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1024,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=False,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=None,
)
def wav2vec2_ft_large(
aux_num_out: int,
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1,
) -> Wav2Vec2Model:
"""Build "large" wav2vec2.0 model with an extra linear module
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
Args:
aux_num_out (int):
The output dimension of the extra linear module.
encoder_projection_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_attention_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
aux_num_out (int or None, optional):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model:
The resulting model.
"""
return wav2vec2_model(
extractor_mode="group_norm",
......@@ -465,11 +367,9 @@ def wav2vec2_large_lv60k(
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build wav2vec2.0 model with "Large LV-60k" configuration
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for pretraining.
"""Build Wav2Vec2Model with "large lv-60k" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]
Args:
encoder_projection_dropout (float):
......@@ -482,60 +382,12 @@ def wav2vec2_large_lv60k(
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model:
"""
return wav2vec2_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=True,
encoder_embed_dim=1024,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=True,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=None,
)
def wav2vec2_ft_large_lv60k(
aux_num_out: int,
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.1,
) -> Wav2Vec2Model:
"""Build "Large LV-60k" wav2vec2.0 with an extra linear module
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
Args:
aux_num_out (int):
The output dimension of the extra linear module.
encoder_projection_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_attention_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
aux_num_out (int or None, optional):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model: The resulting model.
Wav2Vec2Model:
The resulting model.
"""
return wav2vec2_model(
extractor_mode="layer_norm",
......@@ -563,11 +415,9 @@ def hubert_base(
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.05,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build HuBERT model with "Base" configuration
This is one of the model architectures used in *HuBERT*
[:footcite:`hsu2021hubert`] for pretraining.
"""Build HuBERT model with "base" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
Args:
encoder_projection_dropout (float):
......@@ -580,9 +430,12 @@ def hubert_base(
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`wav2vec2_model`.
aux_num_out (int or None, optional):
See :py:func:`wav2vec2_model`.
Returns:
HuBERT: The resulting model.
Wav2Vec2Model:
The resulting model.
"""
return wav2vec2_model(
extractor_mode='group_norm',
......@@ -600,7 +453,7 @@ def hubert_base(
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=False,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=None,
aux_num_out=aux_num_out,
)
......@@ -610,11 +463,9 @@ def hubert_large(
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build HuBERT model with "Large" configuration
This is one of the model architectures used in *HuBERT*
[:footcite:`hsu2021hubert`] for pretraining.
"""Build HuBERT model with "large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
Args:
encoder_projection_dropout (float):
......@@ -627,60 +478,12 @@ def hubert_large(
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`wav2vec2_model`.
Returns:
HuBERT: The resulting model.
"""
return wav2vec2_model(
extractor_mode='layer_norm',
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1024,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=True,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=None,
)
def hubert_ft_large(
aux_num_out: int,
encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.1,
) -> Wav2Vec2Model:
"""Build "Large" HuBERT model with an extra linear module
This is one of the model architecture used in *HuBERT*
[:footcite:`hsu2021hubert`] for fine-tuning for ASR task.
Args:
aux_num_out (int):
The output dimension of the extra linear module.
encoder_projection_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_attention_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
aux_num_out (int or None, optional):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model:
The resulting model.
"""
return wav2vec2_model(
extractor_mode='layer_norm',
......@@ -708,11 +511,9 @@ def hubert_xlarge(
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build HuBERT model with "extra large" configuration
This is one of the model architectures used in *HuBERT*
[:footcite:`hsu2021hubert`] for pretraining.
"""Build HuBERT model with "extra large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
Args:
encoder_projection_dropout (float):
......@@ -725,59 +526,12 @@ def hubert_xlarge(
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
See :py:func:`wav2vec2_model`.
Returns:
HuBERT: The resulting model.
"""
return wav2vec2_model(
extractor_mode='layer_norm',
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1280,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=48,
encoder_num_heads=16,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=5120,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=True,
encoder_layer_drop=encoder_layer_drop,
aux_num_out=None,
)
def hubert_ft_xlarge(
aux_num_out: int,
encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.1,
) -> Wav2Vec2Model:
"""Build "extra large" HuBERT model with an extra linear module
This is one of the model architecture used in *HuBERT*
[:footcite:`hsu2021hubert`] for fine-tuning for ASR task.
Args:
aux_num_out (int):
The output dimension of the extra linear module.
encoder_projection_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_attention_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_ff_interm_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_dropout (float):
See :py:func:`wav2vec2_model`.
encoder_layer_drop (float):
aux_num_out (int or None, optional):
See :py:func:`wav2vec2_model`.
Returns:
Wav2Vec2Model: The resulting model.
Wav2Vec2Model:
The resulting model.
"""
return wav2vec2_model(
extractor_mode='layer_norm',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment