Unverified Commit b2e9f1e4 authored by moto, committed by GitHub

[BC-Breaking] Split pretraining and finetuning factory functions (#1783)

* [BC-Breaking] Split pretraining and finetuning factory functions

Previously, the wav2vec2 factory functions only generated the fine-tuning
architecture used in the wav2vec 2.0 paper for the ASR task, that is, the
pre-training architecture plus a Linear module. They did not provide a
straightforward way to generate architectures for pre-training.

The goal of the original implementation was to allow inference of wav2vec2
in non-Python environments via TorchScript. Now we would like to expand it
to pre-training/fine-tuning and to the HuBERT model as well.

Therefore, we need factory functions for both pre-training and fine-tuning.
This commit introduces new factory functions so that pre-training and
fine-tuning architectures are generated by separate functions.

1. New functions for ASR fine-tuning.

We introduce `wav2vec2_asr_XXX` functions, which generate the architecture
used for the fine-tuning task in the wav2vec 2.0 paper. *1

2. Re-purpose the old functions

The existing `wav2vec2_XXX` functions now generate the pre-training
architecture only (no Linear module).

Note
*1 This architecture is just one way to define a fine-tuning architecture;
it is not a universal definition. The new `wav2vec2_asr_XXX` functions are
designed to provide these specific fine-tuning configurations and are not
meant to support generic architectures for downstream tasks.
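
A minimal usage sketch of the new split, based on the signatures introduced
in this commit (the `num_out=32` value and the input shapes below are
arbitrary illustrations, not part of the change itself):

>>> import torch
>>> from torchaudio.models import wav2vec2_base, wav2vec2_asr_base
>>>
>>> # Pre-training architecture only (no Linear module); takes no arguments.
>>> pretrain_model = wav2vec2_base()
>>>
>>> # ASR fine-tuning architecture; `num_out` is the number of output labels.
>>> asr_model = wav2vec2_asr_base(num_out=32)
>>>
>>> # Both share the same forward interface: (waveforms, lengths).
>>> waveforms = torch.randn(3, 1024)
>>> lengths = torch.tensor([1024, 1024, 1024])
>>> features, feature_lengths = pretrain_model(waveforms, lengths)
>>> logits, logit_lengths = asr_model(waveforms, lengths)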
parent cf0adb28
......@@ -91,6 +91,21 @@ wav2vec2_large_lv60k
.. autofunction:: wav2vec2_large_lv60k
wav2vec2_asr_base
^^^^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_asr_base
wav2vec2_asr_large
^^^^^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_asr_large
wav2vec2_asr_large_lv60k
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_asr_large_lv60k
.. currentmodule:: torchaudio.models.wav2vec2.utils
Utility Functions
......
import json
import sys
import torch
from torchaudio.models.wav2vec2 import (
wav2vec2_asr_base,
wav2vec2_asr_large,
wav2vec2_asr_large_lv60k,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
......@@ -27,7 +31,7 @@ def _name_func(testcase_func, i, param):
return f'{testcase_func.__name__}_{i}_{param[0][1].__name__}'
# Pretrined (not fine-tuned) models
# Pretraining (not fine-tuned) models
BASE = _load_config('wav2vec_small')
LARGE = _load_config('libri960_big')
LARGE_LV60K = _load_config('wav2vec_vox_new')
......@@ -39,17 +43,17 @@ LARGE_LV60K_960H = _load_config('wav2vec_large_lv60k_960h')
LARGE_LV60K_SELF_960H = _load_config('wav2vec_large_lv60k_self_960h')
# Config and corresponding factory functions
PRETRAINED_CONFIGS = parameterized.expand([
PRETRAINING_CONFIGS = parameterized.expand([
(BASE, wav2vec2_base),
(LARGE, wav2vec2_large),
(LARGE_LV60K, wav2vec2_large_lv60k),
(XLSR_53_56K, wav2vec2_large_lv60k),
], name_func=_name_func)
FINETUNED_CONFIGS = parameterized.expand([
(BASE_960H, wav2vec2_base),
(LARGE_960H, wav2vec2_large),
(LARGE_LV60K_960H, wav2vec2_large_lv60k),
(LARGE_LV60K_SELF_960H, wav2vec2_large_lv60k),
(BASE_960H, wav2vec2_asr_base),
(LARGE_960H, wav2vec2_asr_large),
(LARGE_LV60K_960H, wav2vec2_asr_large_lv60k),
(LARGE_LV60K_SELF_960H, wav2vec2_asr_large_lv60k),
], name_func=_name_func)
......@@ -61,7 +65,7 @@ class TestFairseqIntegration(TorchaudioTestCase):
1. Models loaded with fairseq can be imported.
2. The same model can be recreated without fairseq.
"""
def _get_model(self, config, num_out):
def _get_model(self, config, num_out=None):
import copy
from omegaconf import OmegaConf
from fairseq.models.wav2vec.wav2vec2 import (
......@@ -81,31 +85,36 @@ class TestFairseqIntegration(TorchaudioTestCase):
return Wav2Vec2Model(Wav2Vec2Config(**config))
raise ValueError(f'Unexpected configuration: {config["_name"]}')
@PRETRAINED_CONFIGS
def test_import_pretrained_model(self, config, _):
"""Pretrained wav2vec2 models from fairseq can be imported and yields the same results"""
num_out = 28
@PRETRAINING_CONFIGS
def test_import_pretraining_model(self, config, _):
"""Wav2vec2 pretraining models from fairseq can be imported and yields the same results"""
batch_size, num_frames = 3, 1024
original = self._get_model(config, num_out).eval()
imported = import_fairseq_model(original, 28).eval()
atol = 1.1e-05 if sys.platform == "darwin" else 1e-05
# macOS CI jobs fail due to a very small discrepancy
# AssertionError: False is not true : Tensors failed to compare as equal!
# With rtol=1.3e-06 and atol=1e-05, found 1 element(s) (out of 6144)
# whose difference(s) exceeded the margin of error (including 0 nan comparisons).
# The greatest difference was 1.0967254638671875e-05 (-0.12493154406547546 vs.
# -0.12494251132011414), which occurred at index (1, 1, 169).
original = self._get_model(config).eval()
imported = import_fairseq_model(original).eval()
x = torch.randn(batch_size, num_frames)
hyp, _ = imported.extract_features(x)
refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1)
for i, (ref, _) in enumerate(refs['layer_results']):
self.assertEqual(hyp[i], ref.transpose(0, 1))
self.assertEqual(hyp[i], ref.transpose(0, 1), atol=atol, rtol=1.3e-06)
@PRETRAINED_CONFIGS
def test_recreate_pretrained_model(self, config, factory_func):
"""Imported pretrained models can be recreated via a factory function without fairseq."""
num_out = 28
@PRETRAINING_CONFIGS
def test_recreate_pretraining_model(self, config, factory_func):
"""Imported pretraining models can be recreated via a factory function without fairseq."""
batch_size, num_frames = 3, 1024
original = self._get_model(config, num_out).eval()
imported = import_fairseq_model(original, 28).eval()
original = self._get_model(config).eval()
imported = import_fairseq_model(original).eval()
reloaded = factory_func(num_out=num_out)
reloaded = factory_func()
reloaded.load_state_dict(imported.state_dict())
reloaded.eval()
......
......@@ -2,6 +2,9 @@ import json
import torch
from torchaudio.models.wav2vec2 import (
wav2vec2_asr_base,
wav2vec2_asr_large,
wav2vec2_asr_large_lv60k,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
......@@ -47,11 +50,11 @@ PRETRAIN_CONFIGS = parameterized.expand([
(HF_BASE_10K_VOXPOPULI, wav2vec2_base),
], name_func=_name_func)
FINETUNE_CONFIGS = parameterized.expand([
(HF_BASE_960H, wav2vec2_base),
(HF_LARGE_960H, wav2vec2_large),
(HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
(HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k),
(HF_LARGE_XLSR_DE, wav2vec2_large_lv60k),
(HF_BASE_960H, wav2vec2_asr_base),
(HF_LARGE_960H, wav2vec2_asr_large),
(HF_LARGE_LV60_960H, wav2vec2_asr_large_lv60k),
(HF_LARGE_LV60_SELF_960H, wav2vec2_asr_large_lv60k),
(HF_LARGE_XLSR_DE, wav2vec2_asr_large_lv60k),
], name_func=_name_func)
......@@ -81,7 +84,7 @@ class TestHFIntegration(TorchaudioTestCase):
return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
raise ValueError(f'Unexpected arch: {config["architectures"]}')
def _test_import_pretrain(self, original, imported, config, ):
def _test_import_pretrain(self, original, imported, config):
torch.manual_seed(0)
# FeatureExtractor
x = torch.randn(3, 1024)
......@@ -115,7 +118,7 @@ class TestHFIntegration(TorchaudioTestCase):
self.assertEqual(ref, hyp)
def _test_import_finetune(self, original, imported, config):
# Readout
# Aux
x = torch.randn(3, 10, config["hidden_size"])
ref = original.lm_head(x)
hyp = imported.aux(x)
......@@ -193,7 +196,8 @@ class TestHFIntegration(TorchaudioTestCase):
ref = imported.encoder.transformer(x)
hyp = reloaded.encoder.transformer(x)
self.assertEqual(ref, hyp)
# Readout
# Aux
if imported.aux is not None:
x = torch.randn(3, 10, config["hidden_size"])
ref = imported.aux(x)
hyp = reloaded.aux(x)
......@@ -208,7 +212,7 @@ class TestHFIntegration(TorchaudioTestCase):
def test_recreate_pretrain(self, config, factory_func):
"""Imported models can be recreated via a factory function without Hugging Face transformers."""
imported = import_huggingface_model(self._get_model(config)).eval()
reloaded = factory_func(num_out=imported.aux.out_features)
reloaded = factory_func()
reloaded.load_state_dict(imported.state_dict())
reloaded.eval()
self._test_recreate(imported, reloaded, config)
......
......@@ -2,6 +2,9 @@ import torch
import torch.nn.functional as F
from torchaudio.models.wav2vec2 import (
wav2vec2_asr_base,
wav2vec2_asr_large,
wav2vec2_asr_large_lv60k,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
......@@ -19,16 +22,22 @@ def _name_func(testcase_func, i, param):
return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"
factory_funcs = parameterized.expand([
pretrain_factory_funcs = parameterized.expand([
(wav2vec2_base, ),
(wav2vec2_large, ),
(wav2vec2_large_lv60k, ),
], name_func=_name_func)
finetune_factory_funcs = parameterized.expand([
(wav2vec2_asr_base, ),
(wav2vec2_asr_large, ),
(wav2vec2_asr_large_lv60k, ),
], name_func=_name_func)
class TestWav2Vec2Model(TorchaudioTestCase):
def _smoke_test(self, device, dtype):
model = wav2vec2_base(num_out=32)
def _smoke_test(self, model, device, dtype):
model = model.to(device=device, dtype=dtype)
model = model.eval()
......@@ -44,19 +53,23 @@ class TestWav2Vec2Model(TorchaudioTestCase):
@parameterized.expand([(torch.float32, ), (torch.float64, )])
def test_cpu_smoke_test(self, dtype):
self._smoke_test(torch.device('cpu'), dtype)
model = wav2vec2_base()
self._smoke_test(model, torch.device('cpu'), dtype)
model = wav2vec2_asr_base(num_out=32)
self._smoke_test(model, torch.device('cpu'), dtype)
@parameterized.expand([(torch.float32, ), (torch.float64, )])
@skipIfNoCuda
def test_cuda_smoke_test(self, dtype):
self._smoke_test(torch.device('cuda'), dtype)
model = wav2vec2_base()
self._smoke_test(model, torch.device('cuda'), dtype)
model = wav2vec2_asr_base(num_out=32)
self._smoke_test(model, torch.device('cuda'), dtype)
@factory_funcs
def test_feature_extractor_test(self, factory_func):
"""`extract_features` method does not fail"""
def _feature_extractor_test(self, model):
batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval()
model.eval()
num_layers = len(model.encoder.transformer.layers)
torch.manual_seed(0)
......@@ -80,14 +93,19 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self.assertEqual(all_features[i], features[i])
assert lengths_.shape == torch.Size([batch_size])
@factory_funcs
def test_batch_consistency(self, factory_func):
"""Results from sigle process and batched process should be reasonably close
"""
batch_size, max_frames = 5, 5 * 1024
@pretrain_factory_funcs
def test_pretrain_feature_extractor_test(self, factory_func):
"""`extract_features` method does not fail"""
self._feature_extractor_test(factory_func())
model = factory_func(num_out=32).eval()
@finetune_factory_funcs
def test_finetune_feature_extractor_test(self, factory_func):
"""`extract_features` method does not fail"""
self._feature_extractor_test(factory_func(num_out=32))
def _test_batch_consistency(self, model):
model.eval()
batch_size, max_frames = 5, 5 * 1024
torch.manual_seed(0)
waveforms = torch.randn(batch_size, max_frames)
input_lengths = torch.tensor([i * 3200 for i in range(1, 6)])
......@@ -105,24 +123,43 @@ class TestWav2Vec2Model(TorchaudioTestCase):
# We allow max atol=0.005 -> 0.5%
self.assertEqual(single_prob, batch_prob, atol=0.005, rtol=0)
@factory_funcs
def test_zero_length(self, factory_func):
"""Passing zero length should not fail"""
model = factory_func(num_out=32).eval()
@pretrain_factory_funcs
def test_pretrain_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close
"""
self._test_batch_consistency(factory_func())
@finetune_factory_funcs
def test_finetune_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close
"""
self._test_batch_consistency(factory_func(num_out=32))
def _test_zero_length(self, model):
model.eval()
torch.manual_seed(0)
batch_size = 3
waveforms = torch.randn(batch_size, 1024)
input_lengths = torch.zeros(batch_size)
_, output_lengths = model(waveforms, input_lengths)
self.assertEqual(torch.zeros_like(output_lengths), output_lengths)
_, output_lengths = model.extract_features(waveforms, input_lengths)
self.assertEqual(torch.zeros_like(output_lengths), output_lengths)
@factory_funcs
def test_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable"""
batch_size, num_frames = 3, 1024
@pretrain_factory_funcs
def test_pretrain_zero_length(self, factory_func):
"""Passing zero length should not fail"""
self._test_zero_length(factory_func())
@finetune_factory_funcs
def test_finetune_zero_length(self, factory_func):
"""Passing zero length should not fail"""
self._test_zero_length(factory_func(num_out=32))
model = factory_func(num_out=32).eval()
def _test_torchscript(self, model):
model.eval()
batch_size, num_frames = 3, 1024
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
......@@ -137,13 +174,19 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
@factory_funcs
@skipIfNoQengine
def test_quantize(self, factory_func):
"""Wav2Vec2Model should support basic quantization"""
batch_size, num_frames = 3, 1024
@pretrain_factory_funcs
def test_pretrain_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable"""
self._test_torchscript(factory_func())
model = factory_func(num_out=32).eval()
@finetune_factory_funcs
def test_finetune_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable"""
self._test_torchscript(factory_func(num_out=32))
def _test_quantize_smoke_test(self, model):
model.eval()
batch_size, num_frames = 3, 1024
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
......@@ -159,13 +202,22 @@ class TestWav2Vec2Model(TorchaudioTestCase):
_, _ = quantized(waveforms, lengths)
@factory_funcs
@pretrain_factory_funcs
@skipIfNoQengine
def test_quantize_torchscript(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable"""
batch_size, num_frames = 3, 1024
def test_pretrain_quantize(self, factory_func):
"""Wav2Vec2Model should support basic quantization"""
self._test_quantize_smoke_test(factory_func())
model = factory_func(num_out=32).eval()
@finetune_factory_funcs
@skipIfNoQengine
def test_finetune_quantize(self, factory_func):
"""Wav2Vec2Model should support basic quantization"""
self._test_quantize_smoke_test(factory_func(num_out=32))
def _test_quantize_torchscript(self, model):
model.eval()
batch_size, num_frames = 3, 1024
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
......@@ -188,3 +240,15 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
@pretrain_factory_funcs
@skipIfNoQengine
def test_pretrain_quantize_torchscript(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable"""
self._test_quantize_torchscript(factory_func())
@finetune_factory_funcs
@skipIfNoQengine
def test_finetune_quantize_torchscript(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable"""
self._test_quantize_torchscript(factory_func(num_out=32))
......@@ -5,6 +5,9 @@ from .deepspeech import DeepSpeech
from .tacotron2 import Tacotron2, tacotron2
from .wav2vec2 import (
Wav2Vec2Model,
wav2vec2_asr_base,
wav2vec2_asr_large,
wav2vec2_asr_large_lv60k,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
......@@ -18,6 +21,9 @@ __all__ = [
'ConvTasNet',
'DeepSpeech',
'Wav2Vec2Model',
'wav2vec2_asr_base',
'wav2vec2_asr_large',
'wav2vec2_asr_large_lv60k',
'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
......
from .model import (
Wav2Vec2Model,
wav2vec2_asr_base,
wav2vec2_asr_large,
wav2vec2_asr_large_lv60k,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
......@@ -8,6 +11,9 @@ from . import utils
__all__ = [
'Wav2Vec2Model',
'wav2vec2_asr_base',
'wav2vec2_asr_large',
'wav2vec2_asr_large_lv60k',
'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
......
......@@ -617,8 +617,6 @@ def _get_encoder(
Probability to drop each encoder layer during training.
This option corresponds to "layerdrop" from fairseq.
Expected values are 0.1 for both Base and Large arch.
num_out (int):
The dimension of the output. The number of labels.
See Also:
* "encoder_embed_dim"
......
......@@ -116,7 +116,7 @@ def _get_model(
encoder_dropout: float,
encoder_layer_norm_first: bool,
encoder_layer_drop: float,
aux_num_out: int,
aux_num_out: Optional[int],
) -> Wav2Vec2Model:
if extractor_conv_layer_config is None:
extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
......@@ -138,34 +138,53 @@ def _get_model(
layer_norm_first=encoder_layer_norm_first,
layer_drop=encoder_layer_drop,
)
aux = torch.nn.Linear(
in_features=encoder_embed_dim,
out_features=aux_num_out,
)
aux = None
if aux_num_out is not None:
aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
return Wav2Vec2Model(feature_extractor, encoder, aux)
def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
"""Build wav2vec2.0 model with "Base" configuration from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`].
def wav2vec2_base() -> Wav2Vec2Model:
"""Build wav2vec2 model with "base" configuration
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for pretraining.
Returns:
Wav2Vec2Model:
"""
return _get_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=768,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=12,
encoder_num_heads=12,
encoder_attention_dropout=0.1,
encoder_ff_interm_features=3072,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.1,
encoder_layer_norm_first=False,
encoder_layer_drop=0.1,
aux_num_out=None,
)
def wav2vec2_asr_base(num_out: int) -> Wav2Vec2Model:
"""Build "base" wav2vec2 with an extra linear module
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
Args:
num_out: int
The number of output labels.
Returns:
Wav2Vec2Model: The resulting model.
Example - Reload fine-tuned model from Hugging Face:
>>> # Session 1 - Convert pretrained model from Hugging Face and save the parameters.
>>> from torchaudio.models.wav2vec2.utils import import_huggingface_model
>>>
>>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = import_huggingface_model(original)
>>> torch.save(model.state_dict(), "wav2vec2-base-960h.pt")
>>>
>>> # Session 2 - Load model and the parameters
>>> model = wav2vec2_base(num_out=32)
>>> model.load_state_dict(torch.load("wav2vec2-base-960h.pt"))
Wav2Vec2Model:
"""
return _get_model(
extractor_mode="group_norm",
......@@ -187,27 +206,47 @@ def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
)
def wav2vec2_large(num_out: int) -> Wav2Vec2Model:
"""Build wav2vec2.0 model with "Large" configuration from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`].
def wav2vec2_large() -> Wav2Vec2Model:
"""Build wav2vec2 model with "large" configuration
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for pretraining.
Returns:
Wav2Vec2Model:
"""
return _get_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1024,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=0.1,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.1,
encoder_layer_norm_first=False,
encoder_layer_drop=0.1,
aux_num_out=None,
)
def wav2vec2_asr_large(num_out: int) -> Wav2Vec2Model:
"""Build "large" wav2vec2.0 model with an extra linear module
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
Args:
num_out: int
The number of output labels.
Returns:
Wav2Vec2Model: The resulting model.
Example - Reload fine-tuned model from Hugging Face:
>>> # Session 1 - Convert pretrained model from Hugging Face and save the parameters.
>>> from torchaudio.models.wav2vec2.utils import import_huggingface_model
>>>
>>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
>>> model = import_huggingface_model(original)
>>> torch.save(model.state_dict(), "wav2vec2-base-960h.pt")
>>>
>>> # Session 2 - Load model and the parameters
>>> model = wav2vec2_large(num_out=32)
>>> model.load_state_dict(torch.load("wav2vec2-base-960h.pt"))
Wav2Vec2Model:
"""
return _get_model(
extractor_mode="group_norm",
......@@ -229,8 +268,40 @@ def wav2vec2_large(num_out: int) -> Wav2Vec2Model:
)
def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model:
"""Build wav2vec2.0 model with "Large LV-60k" configuration from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`].
def wav2vec2_large_lv60k() -> Wav2Vec2Model:
"""Build wav2vec2.0 model with "Large LV-60k" configuration
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for pretraining.
Returns:
Wav2Vec2Model: The resulting model.
"""
return _get_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=True,
encoder_embed_dim=1024,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=0.0,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.0,
encoder_layer_norm_first=True,
encoder_layer_drop=0.1,
aux_num_out=None,
)
def wav2vec2_asr_large_lv60k(num_out: int) -> Wav2Vec2Model:
"""Build "Large LV-60k" wav2vec2.0 with an extra linear module
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
Args:
num_out: int
......@@ -238,18 +309,6 @@ def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model:
Returns:
Wav2Vec2Model: The resulting model.
Example - Reload fine-tuned model from Hugging Face:
>>> # Session 1 - Convert pretrained model from Hugging Face and save the parameters.
>>> from torchaudio.models.wav2vec2.utils import import_huggingface_model
>>>
>>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
>>> model = import_huggingface_model(original)
>>> torch.save(model.state_dict(), "wav2vec2-base-960h.pt")
>>>
>>> # Session 2 - Load model and the parameters
>>> model = wav2vec2_large_lv60k(num_out=32)
>>> model.load_state_dict(torch.load("wav2vec2-base-960h.pt"))
"""
return _get_model(
extractor_mode="layer_norm",
......
......@@ -3,14 +3,13 @@
For this module to work, you need `fairseq`.
"""
import re
from typing import Optional
from torch.nn import Module
from ..model import Wav2Vec2Model, _get_model
def _parse_config(w2v_model, num_out):
def _parse_config(w2v_model):
encoder = w2v_model.encoder
conv_layers = w2v_model.feature_extractor.conv_layers
......@@ -46,7 +45,6 @@ def _parse_config(w2v_model, num_out):
'encoder_dropout': encoder.layers[0].dropout3.p,
'encoder_layer_norm_first': encoder.layer_norm_first,
'encoder_layer_drop': encoder.layerdrop,
'aux_num_out': num_out,
}
return config
......@@ -108,7 +106,8 @@ def _map_key(key):
if match:
return f"encoder.transformer.layers.{match.group(1)}.final_layer_norm.{match.group(2)}"
match = re.match(r"proj\.(weight|bias)", key)
# Encoder - Readout layer
# Auxiliary Module
# Only relevant when loading fine-tuned models
if match:
return f"aux.{match.group(1)}"
raise ValueError(f'Unexpected key: {key_}')
......@@ -123,9 +122,7 @@ def _convert_state_dict(state_dict):
return converted
def import_fairseq_model(
original: Module,
num_out: Optional[int] = None) -> Wav2Vec2Model:
def import_fairseq_model(original: Module) -> Wav2Vec2Model:
"""Build Wav2Vec2Model from pretrained parameters published by `fairseq`_.
Args:
......@@ -133,9 +130,6 @@ def import_fairseq_model(
An instance of fairseq's Wav2Vec2.0 model class.
Either ``fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder`` or
``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model``.
num_out (int or None, optional):
The number of output labels. Required only when the original model is
an instance of ``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model``.
Returns:
Wav2Vec2Model: Imported model.
......@@ -147,7 +141,7 @@ def import_fairseq_model(
>>> model_file = 'wav2vec_small.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> imported = import_fairseq_model(original, num_out=28)
>>> imported = import_fairseq_model(original)
>>>
>>> # Perform feature extraction
>>> waveform, _ = torchaudio.load('audio.wav')
......@@ -179,12 +173,7 @@ def import_fairseq_model(
"""
class_ = original.__class__.__name__
if class_ == 'Wav2Vec2Model':
if num_out is None:
raise ValueError(
'When importing a pretrained model without readout layer, '
'`num_out` argument must be given.'
)
return _import_pretrained(original, num_out)
return _import_pretrained(original)
if class_ == 'Wav2VecEncoder':
return _import_finetuned(original)
raise ValueError(
......@@ -192,14 +181,14 @@ def import_fairseq_model(
def _import_finetuned(original: Module) -> Wav2Vec2Model:
config = _parse_config(original.w2v_model, original.proj.out_features)
model = _get_model(**config)
config = _parse_config(original.w2v_model)
model = _get_model(**config, aux_num_out=original.proj.out_features)
model.load_state_dict(_convert_state_dict(original.state_dict()))
return model
def _import_pretrained(original: Module, num_out: int) -> Wav2Vec2Model:
config = _parse_config(original, num_out)
model = _get_model(**config)
def _import_pretrained(original: Module) -> Wav2Vec2Model:
config = _parse_config(original)
model = _get_model(**config, aux_num_out=None)
model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
return model
......@@ -26,18 +26,21 @@ def _get_config(cfg):
'encoder_dropout': cfg.hidden_dropout,
'encoder_layer_norm_first': cfg.do_stable_layer_norm,
'encoder_layer_drop': cfg.layerdrop,
'aux_num_out': cfg.vocab_size,
}
return config
def _build(config, original):
if original.__class__.__name__ == 'Wav2Vec2ForCTC':
aux_num_out = original.config.vocab_size
wav2vec2 = original.wav2vec2
else:
_LG.warning(
'The model is not an instance of Wav2Vec2ForCTC. '
'"lm_head" module is not imported.')
aux_num_out = None
wav2vec2 = original
imported = _get_model(**config)
imported = _get_model(**config, aux_num_out=aux_num_out)
imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
......@@ -67,8 +70,6 @@ def import_huggingface_model(original: Module) -> Wav2Vec2Model:
.. _Transformers: https://huggingface.co/transformers/
"""
_LG.info('Importing model.')
if original.__class__.__name__ != 'Wav2Vec2ForCTC':
_LG.warning('The model is not an instance of Wav2Vec2ForCTC')
_LG.info('Loading model configuration.')
config = _get_config(original.config)
_LG.debug(' - config: %s', config)
......