Commit 0582e73c authored by moto

Make the core wav2vec2 factory function public (#1829)

This commit makes the following changes:
1. Make the factory function with full customizability public,
    i.e. `_get_model(...) -> wav2vec2_model(...)`.
2. Change the other architecture-specific factory functions so that they also accept parameters that are not related to the model architecture (such as dropout), as sketched below,
    i.e. `wav2vec2_base() -> wav2vec2_base(encoder_projection_dropout, encoder_attention_dropout, encoder_ff_interm_dropout, ...)`.
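
For example, after this change the public API can be used as follows. This is a rough sketch, not part of the commit itself; the dropout values and `aux_num_out` are illustrative:

```python
import torch
from torchaudio.models import wav2vec2_model, wav2vec2_base, wav2vec2_ft_base

# Full customizability via the now-public core factory function.
# These values replicate the "base" architecture; any of them can be changed.
model = wav2vec2_model(
    extractor_mode="group_norm",
    extractor_conv_layer_config=None,  # None selects the default conv layer config
    extractor_conv_bias=False,
    encoder_embed_dim=768,
    encoder_projection_dropout=0.1,
    encoder_pos_conv_kernel=128,
    encoder_pos_conv_groups=16,
    encoder_num_layers=12,
    encoder_num_heads=12,
    encoder_attention_dropout=0.1,
    encoder_ff_interm_features=3072,
    encoder_ff_interm_dropout=0.1,
    encoder_dropout=0.1,
    encoder_layer_norm_first=False,
    encoder_layer_drop=0.1,
    aux_num_out=None,
)

# The architecture-specific factory functions remain as syntactic sugar,
# now accepting non-architectural parameters such as dropout.
pretrain_model = wav2vec2_base(encoder_layer_drop=0.0)
finetune_model = wav2vec2_ft_base(aux_num_out=32, encoder_attention_dropout=0.0)

# Smoke test: a batch of two 1-second waveforms at 16 kHz.
features, lengths = model(torch.randn(2, 16000))
```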

### Why?

While adding the pre-trained weight support, I realized that separating the API for model construction from the pre-trained weight support yields a simple code organization, thanks to a good separation of concerns. As mentioned in #1821, in this framework,
  1. the model implementation is responsible for the computation logic,
  2. factory functions are responsible for customizability and model construction,
  3. and the pre-trained weight API is responsible for constructing a model and loading pre-trained weights along with the complementary information (such as pre-processing and class labels).

(Note: for simple models, combining 1 and 2 is also okay. A minimal sketch of this layering follows below.)
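
For illustration, here is a minimal sketch of that layering. All names below (`TinyModel`, `tiny_model`, `tiny_model_pretrained`) are hypothetical and not part of torchaudio:

```python
import torch
from torch.hub import load_state_dict_from_url


# 1. Model implementation: responsible only for the computation logic.
class TinyModel(torch.nn.Module):
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# 2. Factory function: responsible for customizability and construction.
def tiny_model(hidden_dim: int = 8) -> TinyModel:
    return TinyModel(hidden_dim=hidden_dim)


# 3. Pre-trained weight API: constructs the model via the factory function
#    and loads the matching pre-trained weights.
def tiny_model_pretrained(url: str) -> TinyModel:
    model = tiny_model(hidden_dim=8)  # must match the architecture of the weights
    model.load_state_dict(load_state_dict_from_url(url))
    return model
```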

This means that the factory functions have to support all the customizability required by the pre-trained weight API. The current implementation imports an internal function, as in `from .model import Wav2Vec2Model, _get_model`, which is a bit strange.

This PR rectifies that by making the core factory function public.
It also clarifies the purpose of keeping the other factory functions as public API: they are just syntactic sugar for constructing an untrained model with a specific architecture. So this commit also adds supplemental parameters to them.
parent 8f270d09
@@ -75,6 +75,12 @@ Wav2Vec2Model
 Factory Functions
 -----------------
 
+wav2vec2_model
+^^^^^^^^^^^^^^
+
+.. autofunction:: wav2vec2_model
+
+
 wav2vec2_base
 ^^^^^^^^^^^^^
...
@@ -226,7 +226,7 @@ class TestFairseqIntegration(TorchaudioTestCase):
         original = self._get_model(config, num_out).eval()
         imported = import_fairseq_model(original).eval()
-        reloaded = factory_func(num_out=num_out)
+        reloaded = factory_func(aux_num_out=num_out)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
...
@@ -221,7 +221,7 @@ class TestHFIntegration(TorchaudioTestCase):
     def test_recreate_finetune(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.aux.out_features)
+        reloaded = factory_func(aux_num_out=imported.aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)
@@ -67,7 +67,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     def test_cpu_smoke_test(self, dtype):
         model = wav2vec2_base()
         self._smoke_test(model, torch.device('cpu'), dtype)
-        model = wav2vec2_ft_base(num_out=32)
+        model = wav2vec2_ft_base(aux_num_out=32)
         self._smoke_test(model, torch.device('cpu'), dtype)
 
     @parameterized.expand([(torch.float32, ), (torch.float64, )])
@@ -75,7 +75,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     def test_cuda_smoke_test(self, dtype):
         model = wav2vec2_base()
         self._smoke_test(model, torch.device('cuda'), dtype)
-        model = wav2vec2_ft_base(num_out=32)
+        model = wav2vec2_ft_base(aux_num_out=32)
         self._smoke_test(model, torch.device('cuda'), dtype)
 
     def _feature_extractor_test(self, model):
@@ -113,7 +113,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @finetune_factory_funcs
     def test_finetune_feature_extractor_test(self, factory_func):
         """`extract_features` method does not fail"""
-        self._feature_extractor_test(factory_func(num_out=32))
+        self._feature_extractor_test(factory_func(aux_num_out=32))
 
     def _test_batch_consistency(self, model):
         model.eval()
@@ -166,7 +166,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @finetune_factory_funcs
     def test_finetune_zero_length(self, factory_func):
         """Passing zero length should not fail"""
-        self._test_zero_length(factory_func(num_out=32))
+        self._test_zero_length(factory_func(aux_num_out=32))
 
     def _test_torchscript(self, model):
         model.eval()
@@ -202,7 +202,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
             self.skipTest(
                 'hubert_ft_xlarge is known to fail on Windows CI. '
                 'See https://github.com/pytorch/pytorch/issues/65776')
-        self._test_torchscript(factory_func(num_out=32))
+        self._test_torchscript(factory_func(aux_num_out=32))
 
     def _test_quantize_smoke_test(self, model):
         model.eval()
@@ -232,7 +232,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @skipIfNoQengine
     def test_finetune_quantize(self, factory_func):
         """Wav2Vec2Model should support basic quantization"""
-        self._test_quantize_smoke_test(factory_func(num_out=32))
+        self._test_quantize_smoke_test(factory_func(aux_num_out=32))
 
     def _test_quantize_torchscript(self, model):
         model.eval()
@@ -271,4 +271,4 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @skipIfNoQengine
     def test_finetune_quantize_torchscript(self, factory_func):
         """Quantized Wav2Vec2Model should be scriptable"""
-        self._test_quantize_torchscript(factory_func(num_out=32))
+        self._test_quantize_torchscript(factory_func(aux_num_out=32))
@@ -5,6 +5,7 @@ from .deepspeech import DeepSpeech
 from .tacotron2 import Tacotron2, tacotron2
 from .wav2vec2 import (
     Wav2Vec2Model,
+    wav2vec2_model,
     wav2vec2_ft_base,
     wav2vec2_ft_large,
     wav2vec2_ft_large_lv60k,
@@ -46,6 +47,7 @@ __all__ = [
     'ConvTasNet',
     'DeepSpeech',
     'Wav2Vec2Model',
+    'wav2vec2_model',
     'wav2vec2_ft_base',
     'wav2vec2_ft_large',
     'wav2vec2_ft_large_lv60k',
...
 from .model import (
     Wav2Vec2Model,
+    wav2vec2_model,
     wav2vec2_ft_base,
     wav2vec2_ft_large,
     wav2vec2_ft_large_lv60k,
@@ -16,6 +17,7 @@ from . import utils
 __all__ = [
     'Wav2Vec2Model',
+    'wav2vec2_model',
     'wav2vec2_ft_base',
     'wav2vec2_ft_large',
     'wav2vec2_ft_large_lv60k',
...
@@ -100,7 +100,7 @@ class Wav2Vec2Model(Module):
         return x, lengths
 
 
-def _get_model(
+def wav2vec2_model(
         extractor_mode: str,
         extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
         extractor_conv_bias: bool,
@@ -118,6 +118,127 @@ def _get_model(
         encoder_layer_drop: float,
         aux_num_out: Optional[int],
 ) -> Wav2Vec2Model:
+    """Build a custom Wav2Vec2Model
+
+    Note:
+        The "feature extractor" below corresponds to
+        `ConvFeatureExtractionModel <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L736>`__
+        in the original ``fairseq`` implementation.
+        This is referred to as "(convolutional) feature encoder" in the *wav2vec 2.0*
+        [:footcite:`baevski2020wav2vec`] paper.
+
+        The "encoder" below corresponds to `TransformerEncoder <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L817>`__,
+        and this is referred to as "Transformer" in the paper.
+
+    Args:
+        extractor_mode (str): Operation mode of feature extractor.
+            Valid values are ``"group_norm"`` or ``"layer_norm"``.
+            If ``"group_norm"``, then a single normalization is applied
+            in the first convolution block. Otherwise, all the convolution
+            blocks will have layer normalization.
+            This option corresponds to ``extractor_mode`` from ``fairseq``.
+        extractor_conv_layer_config (list of integer tuples or None):
+            Configuration of convolution layers in feature extractor.
+            List of convolution configuration,
+            i.e. ``[(output_channel, kernel_size, stride), ...]``
+
+            If ``None`` is provided, then the following default value is used.
+
+            .. code-block:: python
+
+               [
+                   (512, 10, 5),
+                   (512, 3, 2),
+                   (512, 3, 2),
+                   (512, 3, 2),
+                   (512, 3, 2),
+                   (512, 2, 2),
+                   (512, 2, 2),
+               ]
+
+            This option corresponds to ``conv_feature_layers`` from ``fairseq``.
+        extractor_conv_bias (bool):
+            Whether to include bias term to each convolution operation.
+            This option corresponds to ``conv_bias`` from ``fairseq``.
+        encoder_embed_dim (int):
+            The dimension of embedding in encoder.
+            This option corresponds to ``encoder_embed_dim`` from ``fairseq``.
+        encoder_projection_dropout (float):
+            The dropout probability applied after the input feature is projected
+            to ``encoder_embed_dim``.
+            This option corresponds to ``dropout_input`` from ``fairseq``.
+        encoder_pos_conv_kernel (int):
+            The kernel size of convolutional positional embeddings.
+            This option corresponds to ``conv_pos`` from ``fairseq``.
+        encoder_pos_conv_groups (int):
+            The number of groups of convolutional positional embeddings.
+            This option corresponds to ``conv_pos_groups`` from ``fairseq``.
+        encoder_num_layers (int):
+            The number of self attention layers in transformer block.
+            This option corresponds to ``encoder_layers`` from ``fairseq``.
+        encoder_num_heads (int):
+            The number of heads in self attention layers.
+            This option corresponds to ``encoder_attention_heads`` from ``fairseq``.
+        encoder_attention_dropout (float):
+            The dropout probability applied after softmax in self-attention layer.
+            This option corresponds to ``attention_dropout`` from ``fairseq``.
+        encoder_ff_interm_features (int):
+            The dimension of hidden features in feed forward layer.
+            This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.
+        encoder_ff_interm_dropout (float):
+            The dropout probability applied in feed forward layer.
+            This option corresponds to ``activation_dropout`` from ``fairseq``.
+        encoder_dropout (float):
+            The dropout probability applied at the end of feed forward layer.
+            This option corresponds to ``dropout`` from ``fairseq``.
+        encoder_layer_norm_first (bool):
+            Control the order of layer norm in transformer layer and each encoder layer.
+            If True, in transformer layer, layer norm is applied before features are fed
+            to encoder layers. In encoder layer, two layer norms are applied before and after
+            self attention.
+            If False, in transformer layer, layer norm is applied after features are fed
+            to encoder layers. In encoder layer, two layer norms are applied after self
+            attention, before and after feed forward.
+            This option corresponds to ``layer_norm_first`` from ``fairseq``.
+        encoder_layer_drop (float):
+            Probability to drop each encoder layer during training.
+            This option corresponds to ``layerdrop`` from ``fairseq``.
+        aux_num_out (int or None):
+            When provided, attach an extra linear layer on top of encoder, which can be
+            used for fine-tuning.
+
+    Returns:
+        Wav2Vec2Model:
+            The resulting model.
+    """  # noqa: E501
     if extractor_conv_layer_config is None:
         extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
@@ -144,251 +265,400 @@ def _get_model(
     return Wav2Vec2Model(feature_extractor, encoder, aux)
 
 
-def wav2vec2_base() -> Wav2Vec2Model:
+def wav2vec2_base(
+        encoder_projection_dropout: float = 0.1,
+        encoder_attention_dropout: float = 0.1,
+        encoder_ff_interm_dropout: float = 0.1,
+        encoder_dropout: float = 0.1,
+        encoder_layer_drop: float = 0.1,
+) -> Wav2Vec2Model:
     """Build wav2vec2 model with "base" configuration
 
     This is one of the model architectures used in *wav2vec 2.0*
     [:footcite:`baevski2020wav2vec`] for pretraining.
 
+    Args:
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
         Wav2Vec2Model:
     """
-    return _get_model(
+    return wav2vec2_model(
         extractor_mode="group_norm",
         extractor_conv_layer_config=None,
         extractor_conv_bias=False,
         encoder_embed_dim=768,
-        encoder_projection_dropout=0.1,
+        encoder_projection_dropout=encoder_projection_dropout,
         encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=12,
         encoder_num_heads=12,
-        encoder_attention_dropout=0.1,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=3072,
-        encoder_ff_interm_dropout=0.1,
-        encoder_dropout=0.1,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=False,
-        encoder_layer_drop=0.1,
+        encoder_layer_drop=encoder_layer_drop,
         aux_num_out=None,
     )
 
 
-def wav2vec2_ft_base(num_out: int) -> Wav2Vec2Model:
+def wav2vec2_ft_base(
+        aux_num_out: int,
+        encoder_projection_dropout: float = 0.1,
+        encoder_attention_dropout: float = 0.1,
+        encoder_ff_interm_dropout: float = 0.1,
+        encoder_dropout: float = 0.1,
+        encoder_layer_drop: float = 0.1,
+) -> Wav2Vec2Model:
     """Build "base" wav2vec2 with an extra linear module
 
     This is one of the model architectures used in *wav2vec 2.0*
     [:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
 
     Args:
-        num_out: int
-            The number of output labels.
+        aux_num_out (int):
+            The output dimension of the extra linear module.
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
         Wav2Vec2Model:
     """
-    return _get_model(
+    return wav2vec2_model(
         extractor_mode="group_norm",
         extractor_conv_layer_config=None,
         extractor_conv_bias=False,
         encoder_embed_dim=768,
-        encoder_projection_dropout=0.1,
+        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=12,
         encoder_num_heads=12,
-        encoder_attention_dropout=0.1,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=3072,
-        encoder_ff_interm_dropout=0.1,
-        encoder_dropout=0.1,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=False,
-        encoder_layer_drop=0.1,
-        aux_num_out=num_out,
+        encoder_layer_drop=encoder_layer_drop,
+        aux_num_out=aux_num_out,
     )
 
 
-def wav2vec2_large() -> Wav2Vec2Model:
+def wav2vec2_large(
+        encoder_projection_dropout: float = 0.1,
+        encoder_attention_dropout: float = 0.1,
+        encoder_ff_interm_dropout: float = 0.1,
+        encoder_dropout: float = 0.1,
+        encoder_layer_drop: float = 0.1,
+) -> Wav2Vec2Model:
     """Build wav2vec2 model with "large" configuration
 
     This is one of the model architectures used in *wav2vec 2.0*
     [:footcite:`baevski2020wav2vec`] for pretraining.
 
+    Args:
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
         Wav2Vec2Model:
     """
-    return _get_model(
+    return wav2vec2_model(
         extractor_mode="group_norm",
         extractor_conv_layer_config=None,
         extractor_conv_bias=False,
         encoder_embed_dim=1024,
-        encoder_projection_dropout=0.1,
+        encoder_projection_dropout=encoder_projection_dropout,
         encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=24,
         encoder_num_heads=16,
-        encoder_attention_dropout=0.1,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=4096,
-        encoder_ff_interm_dropout=0.1,
-        encoder_dropout=0.1,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=False,
-        encoder_layer_drop=0.1,
+        encoder_layer_drop=encoder_layer_drop,
         aux_num_out=None,
     )
 
 
-def wav2vec2_ft_large(num_out: int) -> Wav2Vec2Model:
+def wav2vec2_ft_large(
+        aux_num_out: int,
+        encoder_projection_dropout: float = 0.1,
+        encoder_attention_dropout: float = 0.1,
+        encoder_ff_interm_dropout: float = 0.1,
+        encoder_dropout: float = 0.1,
+        encoder_layer_drop: float = 0.1,
+) -> Wav2Vec2Model:
     """Build "large" wav2vec2.0 model with an extra linear module
 
     This is one of the model architectures used in *wav2vec 2.0*
     [:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
 
     Args:
-        num_out: int
-            The number of output labels.
+        aux_num_out (int):
+            The output dimension of the extra linear module.
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
         Wav2Vec2Model:
     """
-    return _get_model(
+    return wav2vec2_model(
         extractor_mode="group_norm",
         extractor_conv_layer_config=None,
         extractor_conv_bias=False,
         encoder_embed_dim=1024,
-        encoder_projection_dropout=0.1,
+        encoder_projection_dropout=encoder_projection_dropout,
         encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=24,
         encoder_num_heads=16,
-        encoder_attention_dropout=0.1,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=4096,
-        encoder_ff_interm_dropout=0.1,
-        encoder_dropout=0.1,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=False,
-        encoder_layer_drop=0.1,
-        aux_num_out=num_out,
+        encoder_layer_drop=encoder_layer_drop,
+        aux_num_out=aux_num_out,
    )
 
 
-def wav2vec2_large_lv60k() -> Wav2Vec2Model:
+def wav2vec2_large_lv60k(
+        encoder_projection_dropout: float = 0.1,
+        encoder_attention_dropout: float = 0.0,
+        encoder_ff_interm_dropout: float = 0.1,
+        encoder_dropout: float = 0.0,
+        encoder_layer_drop: float = 0.1,
+) -> Wav2Vec2Model:
     """Build wav2vec2.0 model with "Large LV-60k" configuration
 
     This is one of the model architectures used in *wav2vec 2.0*
     [:footcite:`baevski2020wav2vec`] for pretraining.
 
+    Args:
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
-        Wav2Vec2Model: The resulting model.
+        Wav2Vec2Model:
     """
-    return _get_model(
+    return wav2vec2_model(
         extractor_mode="layer_norm",
         extractor_conv_layer_config=None,
         extractor_conv_bias=True,
         encoder_embed_dim=1024,
-        encoder_projection_dropout=0.1,
+        encoder_projection_dropout=encoder_projection_dropout,
         encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=24,
         encoder_num_heads=16,
-        encoder_attention_dropout=0.0,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=4096,
-        encoder_ff_interm_dropout=0.1,
-        encoder_dropout=0.0,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=True,
-        encoder_layer_drop=0.1,
+        encoder_layer_drop=encoder_layer_drop,
         aux_num_out=None,
     )
 
 
-def wav2vec2_ft_large_lv60k(num_out: int) -> Wav2Vec2Model:
+def wav2vec2_ft_large_lv60k(
+        aux_num_out: int,
+        encoder_projection_dropout: float = 0.1,
+        encoder_attention_dropout: float = 0.0,
+        encoder_ff_interm_dropout: float = 0.1,
+        encoder_dropout: float = 0.0,
+        encoder_layer_drop: float = 0.1,
+) -> Wav2Vec2Model:
     """Build "Large LV-60k" wav2vec2.0 with an extra linear module
 
     This is one of the model architectures used in *wav2vec 2.0*
     [:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
 
     Args:
-        num_out: int
-            The number of output labels.
+        aux_num_out (int):
+            The output dimension of the extra linear module.
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
         Wav2Vec2Model: The resulting model.
     """
-    return _get_model(
+    return wav2vec2_model(
         extractor_mode="layer_norm",
         extractor_conv_layer_config=None,
         extractor_conv_bias=True,
         encoder_embed_dim=1024,
-        encoder_projection_dropout=0.1,
+        encoder_projection_dropout=encoder_projection_dropout,
         encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=24,
         encoder_num_heads=16,
-        encoder_attention_dropout=0.0,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=4096,
-        encoder_ff_interm_dropout=0.1,
-        encoder_dropout=0.0,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=True,
-        encoder_layer_drop=0.1,
-        aux_num_out=num_out,
+        encoder_layer_drop=encoder_layer_drop,
+        aux_num_out=aux_num_out,
     )
 
 
-def hubert_base() -> Wav2Vec2Model:
+def hubert_base(
+        encoder_projection_dropout: float = 0.1,
+        encoder_attention_dropout: float = 0.1,
+        encoder_ff_interm_dropout: float = 0.0,
+        encoder_dropout: float = 0.1,
+        encoder_layer_drop: float = 0.05,
+) -> Wav2Vec2Model:
     """Build HuBERT model with "Base" configuration
 
     This is one of the model architectures used in *HuBERT*
     [:footcite:`hsu2021hubert`] for pretraining.
 
+    Args:
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
         HuBERT: The resulting model.
     """
-    return _get_model(
+    return wav2vec2_model(
         extractor_mode='group_norm',
         extractor_conv_layer_config=None,
         extractor_conv_bias=False,
         encoder_embed_dim=768,
-        encoder_projection_dropout=0.1,
+        encoder_projection_dropout=encoder_projection_dropout,
         encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=12,
         encoder_num_heads=12,
-        encoder_attention_dropout=0.1,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=3072,
-        encoder_ff_interm_dropout=0.0,
-        encoder_dropout=0.1,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=False,
-        encoder_layer_drop=0.05,
+        encoder_layer_drop=encoder_layer_drop,
         aux_num_out=None,
     )
 
 
-def hubert_large() -> Wav2Vec2Model:
+def hubert_large(
+        encoder_projection_dropout: float = 0.0,
+        encoder_attention_dropout: float = 0.0,
+        encoder_ff_interm_dropout: float = 0.0,
+        encoder_dropout: float = 0.0,
+        encoder_layer_drop: float = 0.0,
+) -> Wav2Vec2Model:
     """Build HuBERT model with "Large" configuration
 
     This is one of the model architectures used in *HuBERT*
     [:footcite:`hsu2021hubert`] for pretraining.
 
+    Args:
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
         HuBERT: The resulting model.
     """
-    return _get_model(
+    return wav2vec2_model(
         extractor_mode='layer_norm',
         extractor_conv_layer_config=None,
         extractor_conv_bias=False,
         encoder_embed_dim=1024,
-        encoder_projection_dropout=0.0,
+        encoder_projection_dropout=encoder_projection_dropout,
         encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=24,
         encoder_num_heads=16,
-        encoder_attention_dropout=0.0,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=4096,
-        encoder_ff_interm_dropout=0.0,
-        encoder_dropout=0.0,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=True,
-        encoder_layer_drop=0.0,
+        encoder_layer_drop=encoder_layer_drop,
         aux_num_out=None,
     )
 
 
-def hubert_ft_large(num_out) -> Wav2Vec2Model:
+def hubert_ft_large(
+        aux_num_out: int,
+        encoder_projection_dropout: float = 0.0,
+        encoder_attention_dropout: float = 0.0,
+        encoder_ff_interm_dropout: float = 0.1,
+        encoder_dropout: float = 0.0,
+        encoder_layer_drop: float = 0.1,
+) -> Wav2Vec2Model:
     """Build "Large" HuBERT model with an extra linear module
@@ -396,89 +666,134 @@ def hubert_ft_large(num_out) -> Wav2Vec2Model:
     [:footcite:`hsu2021hubert`] for fine-tuning for ASR task.
 
     Args:
-        num_out: int
-            The number of output labels.
+        aux_num_out (int):
+            The output dimension of the extra linear module.
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
         Wav2Vec2Model:
     """
-    return _get_model(
+    return wav2vec2_model(
         extractor_mode='layer_norm',
         extractor_conv_layer_config=None,
         extractor_conv_bias=False,
         encoder_embed_dim=1024,
-        encoder_projection_dropout=0.0,
+        encoder_projection_dropout=encoder_projection_dropout,
         encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=24,
         encoder_num_heads=16,
-        encoder_attention_dropout=0.0,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=4096,
-        encoder_ff_interm_dropout=0.1,
-        encoder_dropout=0.0,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=True,
-        encoder_layer_drop=0.1,
-        aux_num_out=num_out,
+        encoder_layer_drop=encoder_layer_drop,
+        aux_num_out=aux_num_out,
     )
 
 
-def hubert_xlarge() -> Wav2Vec2Model:
+def hubert_xlarge(
+        encoder_projection_dropout: float = 0.0,
+        encoder_attention_dropout: float = 0.0,
+        encoder_ff_interm_dropout: float = 0.0,
+        encoder_dropout: float = 0.0,
+        encoder_layer_drop: float = 0.0,
+) -> Wav2Vec2Model:
     """Build HuBERT model with "extra large" configuration
 
     This is one of the model architectures used in *HuBERT*
     [:footcite:`hsu2021hubert`] for pretraining.
 
+    Args:
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
         HuBERT: The resulting model.
     """
-    return _get_model(
+    return wav2vec2_model(
         extractor_mode='layer_norm',
         extractor_conv_layer_config=None,
         extractor_conv_bias=False,
         encoder_embed_dim=1280,
-        encoder_projection_dropout=0.0,
+        encoder_projection_dropout=encoder_projection_dropout,
         encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=48,
         encoder_num_heads=16,
-        encoder_attention_dropout=0.0,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=5120,
-        encoder_ff_interm_dropout=0.0,
-        encoder_dropout=0.0,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=True,
-        encoder_layer_drop=0.0,
+        encoder_layer_drop=encoder_layer_drop,
         aux_num_out=None,
     )
 
 
-def hubert_ft_xlarge(num_out) -> Wav2Vec2Model:
+def hubert_ft_xlarge(
+        aux_num_out: int,
+        encoder_projection_dropout: float = 0.0,
+        encoder_attention_dropout: float = 0.0,
+        encoder_ff_interm_dropout: float = 0.1,
+        encoder_dropout: float = 0.0,
+        encoder_layer_drop: float = 0.1,
+) -> Wav2Vec2Model:
     """Build "extra large" HuBERT model with an extra linear module
 
     This is one of the model architectures used in *HuBERT*
     [:footcite:`hsu2021hubert`] for fine-tuning for ASR task.
 
     Args:
-        num_out: int
-            The number of output labels.
+        aux_num_out (int):
+            The output dimension of the extra linear module.
+        encoder_projection_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_attention_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_ff_interm_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_dropout (float):
+            See :py:func:`wav2vec2_model`.
+        encoder_layer_drop (float):
+            See :py:func:`wav2vec2_model`.
+
     Returns:
         Wav2Vec2Model: The resulting model.
     """
-    return _get_model(
+    return wav2vec2_model(
        extractor_mode='layer_norm',
         extractor_conv_layer_config=None,
         extractor_conv_bias=False,
         encoder_embed_dim=1280,
-        encoder_projection_dropout=0.0,
+        encoder_projection_dropout=encoder_projection_dropout,
         encoder_pos_conv_kernel=128,
         encoder_pos_conv_groups=16,
         encoder_num_layers=48,
         encoder_num_heads=16,
-        encoder_attention_dropout=0.0,
+        encoder_attention_dropout=encoder_attention_dropout,
         encoder_ff_interm_features=5120,
-        encoder_ff_interm_dropout=0.1,
-        encoder_dropout=0.0,
+        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+        encoder_dropout=encoder_dropout,
         encoder_layer_norm_first=True,
-        encoder_layer_drop=0.1,
-        aux_num_out=num_out,
+        encoder_layer_drop=encoder_layer_drop,
+        aux_num_out=aux_num_out,
     )
@@ -3,7 +3,7 @@ from typing import Dict, Tuple, Any, Optional
 from torch.hub import load_state_dict_from_url
 
-from .model import _get_model, Wav2Vec2Model
+from .model import wav2vec2_model, Wav2Vec2Model
 
 __all__ = []
@@ -67,7 +67,7 @@ class Wav2Vec2PretrainedModelBundle:
         Args:
             dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
         """
-        model = _get_model(**self._params)
+        model = wav2vec2_model(**self._params)
         url = f'https://download.pytorch.org/models/audio/{self._path}'
         dl_kwargs = {} if dl_kwargs is None else dl_kwargs
         state_dict = load_state_dict_from_url(url, **dl_kwargs)
...
@@ -6,7 +6,7 @@ import re
 from torch.nn import Module
 
-from ..model import Wav2Vec2Model, _get_model
+from ..model import Wav2Vec2Model, wav2vec2_model
 
 def _parse_config(w2v_model):
@@ -190,27 +190,27 @@ def import_fairseq_model(original: Module) -> Wav2Vec2Model:
 def _import_wav2vec2_finetuning(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original.w2v_model)
-    model = _get_model(**config, aux_num_out=original.proj.out_features)
+    model = wav2vec2_model(**config, aux_num_out=original.proj.out_features)
     model.load_state_dict(_convert_state_dict(original.state_dict()))
     return model
 
 
 def _import_wav2vec2_pretraining(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original)
-    model = _get_model(**config, aux_num_out=None)
+    model = wav2vec2_model(**config, aux_num_out=None)
     model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
     return model
 
 
 def _import_hubert_finetuning(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original.w2v_model)
-    model = _get_model(**config, aux_num_out=original.proj.out_features)
+    model = wav2vec2_model(**config, aux_num_out=original.proj.out_features)
     model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
     return model
 
 
 def _import_hubert_pretraining(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original)
-    model = _get_model(**config, aux_num_out=None)
+    model = wav2vec2_model(**config, aux_num_out=None)
     model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
     return model
@@ -4,7 +4,7 @@ import logging
 from torch.nn import Module
 
-from ..model import Wav2Vec2Model, _get_model
+from ..model import Wav2Vec2Model, wav2vec2_model
 
 _LG = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ def _build(config, original):
             '"lm_head" module is not imported.')
         aux_num_out = None
         wav2vec2 = original
-    imported = _get_model(**config, aux_num_out=aux_num_out)
+    imported = wav2vec2_model(**config, aux_num_out=aux_num_out)
     imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
     imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
     imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
...