Unverified Commit b2e9f1e4 authored by moto, committed by GitHub

[BC-Breaking] Split pretraining and finetuning factory functions (#1783)

* [BC-Breaking] Split pretraining and finetuning factory functions

Previously, the wav2vec2 factory functions only generated the fine-tuning
architecture used in the wav2vec 2.0 paper for the ASR task, that is, the
pre-training architecture plus a Linear module, and they did not provide a
straightforward way to generate architectures for pre-training.

The goal of the original implementation was to allow inference of wav2vec2
in non-Python environments via TorchScript. Now we would like to expand it to
pre-training/fine-tuning and to the HuBERT model as well.

Therefore, we need factory functions for both pre-training and fine-tuning.
This commit introduces new factory functions and separates the functions for
pre-training and fine-tuning.

1. New functions for ASR fine-tuning.

We introduce `wav2vec2_asr_XXX` functions, which generate the architecture
used for the fine-tuning task in the wav2vec 2.0 paper. *1

2. Re-purpose the old functions

The existing functions, `wav2vec2_XXX`, now generate the pre-training
architecture only (no Linear module). See the sketch below.
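
A minimal usage sketch of the split, assuming the factory signatures introduced
by this commit (`wav2vec2_base()` takes no arguments; `wav2vec2_asr_base` takes
`num_out`); the input shape and `num_out=32` below are illustrative only:

```python
import torch
from torchaudio.models import wav2vec2_base, wav2vec2_asr_base

# Pre-training architecture: no auxiliary Linear module is attached.
pretrain_model = wav2vec2_base().eval()

# ASR fine-tuning architecture: a Linear module with `num_out` labels is appended.
asr_model = wav2vec2_asr_base(num_out=32).eval()

waveforms = torch.randn(1, 16000)         # dummy 1-second batch at 16 kHz
features, _ = pretrain_model(waveforms)   # encoder features, e.g. (1, frames, 768)
logits, _ = asr_model(waveforms)          # per-frame label logits, (1, frames, 32)
```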

Note
*1 This architecture is just one way to define a fine-tuning architecture,
not a universal definition. The new `wav2vec2_asr_XXX` functions are designed
to provide these specific fine-tuning configurations; they are not meant to
support generic architectures for downstream tasks.
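
Because of *1, a downstream task that needs a different setup is expected to
compose its own head on top of the pre-training architecture. The following is
a sketch only, assuming the pre-training factory from this commit; the pooling,
head module, and class count are illustrative and not part of torchaudio:

```python
import torch
from torchaudio.models import wav2vec2_base

class SpeakerClassifier(torch.nn.Module):
    """Illustrative downstream head on top of the pre-training architecture."""
    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = wav2vec2_base()                # pre-training architecture only
        self.head = torch.nn.Linear(768, num_classes)  # 768 = "base" encoder dim

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        features, _ = self.backbone(waveforms)         # (batch, frames, 768)
        pooled = features.mean(dim=1)                  # simple mean pooling over frames
        return self.head(pooled)

model = SpeakerClassifier(num_classes=10)
logits = model(torch.randn(2, 16000))                  # -> shape (2, 10)
```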
parent cf0adb28
...@@ -91,6 +91,21 @@ wav2vec2_large_lv60k ...@@ -91,6 +91,21 @@ wav2vec2_large_lv60k
.. autofunction:: wav2vec2_large_lv60k .. autofunction:: wav2vec2_large_lv60k
wav2vec2_asr_base
^^^^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_asr_base
wav2vec2_asr_large
^^^^^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_asr_large
wav2vec2_asr_large_lv60k
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_asr_large_lv60k
.. currentmodule:: torchaudio.models.wav2vec2.utils .. currentmodule:: torchaudio.models.wav2vec2.utils
Utility Functions Utility Functions
......
import json import json
import sys
import torch import torch
from torchaudio.models.wav2vec2 import ( from torchaudio.models.wav2vec2 import (
wav2vec2_asr_base,
wav2vec2_asr_large,
wav2vec2_asr_large_lv60k,
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
wav2vec2_large_lv60k, wav2vec2_large_lv60k,
...@@ -27,7 +31,7 @@ def _name_func(testcase_func, i, param): ...@@ -27,7 +31,7 @@ def _name_func(testcase_func, i, param):
return f'{testcase_func.__name__}_{i}_{param[0][1].__name__}' return f'{testcase_func.__name__}_{i}_{param[0][1].__name__}'
# Pretrined (not fine-tuned) models # Pretraining (not fine-tuned) models
BASE = _load_config('wav2vec_small') BASE = _load_config('wav2vec_small')
LARGE = _load_config('libri960_big') LARGE = _load_config('libri960_big')
LARGE_LV60K = _load_config('wav2vec_vox_new') LARGE_LV60K = _load_config('wav2vec_vox_new')
...@@ -39,17 +43,17 @@ LARGE_LV60K_960H = _load_config('wav2vec_large_lv60k_960h') ...@@ -39,17 +43,17 @@ LARGE_LV60K_960H = _load_config('wav2vec_large_lv60k_960h')
LARGE_LV60K_SELF_960H = _load_config('wav2vec_large_lv60k_self_960h') LARGE_LV60K_SELF_960H = _load_config('wav2vec_large_lv60k_self_960h')
# Config and corresponding factory functions # Config and corresponding factory functions
PRETRAINED_CONFIGS = parameterized.expand([ PRETRAINING_CONFIGS = parameterized.expand([
(BASE, wav2vec2_base), (BASE, wav2vec2_base),
(LARGE, wav2vec2_large), (LARGE, wav2vec2_large),
(LARGE_LV60K, wav2vec2_large_lv60k), (LARGE_LV60K, wav2vec2_large_lv60k),
(XLSR_53_56K, wav2vec2_large_lv60k), (XLSR_53_56K, wav2vec2_large_lv60k),
], name_func=_name_func) ], name_func=_name_func)
FINETUNED_CONFIGS = parameterized.expand([ FINETUNED_CONFIGS = parameterized.expand([
(BASE_960H, wav2vec2_base), (BASE_960H, wav2vec2_asr_base),
(LARGE_960H, wav2vec2_large), (LARGE_960H, wav2vec2_asr_large),
(LARGE_LV60K_960H, wav2vec2_large_lv60k), (LARGE_LV60K_960H, wav2vec2_asr_large_lv60k),
(LARGE_LV60K_SELF_960H, wav2vec2_large_lv60k), (LARGE_LV60K_SELF_960H, wav2vec2_asr_large_lv60k),
], name_func=_name_func) ], name_func=_name_func)
...@@ -61,7 +65,7 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -61,7 +65,7 @@ class TestFairseqIntegration(TorchaudioTestCase):
1. Models loaded with fairseq can be imported. 1. Models loaded with fairseq can be imported.
2. The same model can be recreated without fairseq. 2. The same model can be recreated without fairseq.
""" """
def _get_model(self, config, num_out): def _get_model(self, config, num_out=None):
import copy import copy
from omegaconf import OmegaConf from omegaconf import OmegaConf
from fairseq.models.wav2vec.wav2vec2 import ( from fairseq.models.wav2vec.wav2vec2 import (
...@@ -81,31 +85,36 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -81,31 +85,36 @@ class TestFairseqIntegration(TorchaudioTestCase):
return Wav2Vec2Model(Wav2Vec2Config(**config)) return Wav2Vec2Model(Wav2Vec2Config(**config))
raise ValueError(f'Unexpected configuration: {config["_name"]}') raise ValueError(f'Unexpected configuration: {config["_name"]}')
@PRETRAINED_CONFIGS @PRETRAINING_CONFIGS
def test_import_pretrained_model(self, config, _): def test_import_pretraining_model(self, config, _):
"""Pretrained wav2vec2 models from fairseq can be imported and yields the same results""" """Wav2vec2 pretraining models from fairseq can be imported and yields the same results"""
num_out = 28
batch_size, num_frames = 3, 1024 batch_size, num_frames = 3, 1024
atol = 1.1e-05 if sys.platform == "darwin" else 1e-05
original = self._get_model(config, num_out).eval() # macOS CI jobs fail due to a very small discrepancy
imported = import_fairseq_model(original, 28).eval() # AssertionError: False is not true : Tensors failed to compare as equal!
# With rtol=1.3e-06 and atol=1e-05, found 1 element(s) (out of 6144)
# whose difference(s) exceeded the margin of error (including 0 nan comparisons).
# The greatest difference was 1.0967254638671875e-05 (-0.12493154406547546 vs.
# -0.12494251132011414), which occurred at index (1, 1, 169).
original = self._get_model(config).eval()
imported = import_fairseq_model(original).eval()
x = torch.randn(batch_size, num_frames) x = torch.randn(batch_size, num_frames)
hyp, _ = imported.extract_features(x) hyp, _ = imported.extract_features(x)
refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1) refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1)
for i, (ref, _) in enumerate(refs['layer_results']): for i, (ref, _) in enumerate(refs['layer_results']):
self.assertEqual(hyp[i], ref.transpose(0, 1)) self.assertEqual(hyp[i], ref.transpose(0, 1), atol=atol, rtol=1.3e-06)
@PRETRAINED_CONFIGS @PRETRAINING_CONFIGS
def test_recreate_pretrained_model(self, config, factory_func): def test_recreate_pretraining_model(self, config, factory_func):
"""Imported pretrained models can be recreated via a factory function without fairseq.""" """Imported pretraining models can be recreated via a factory function without fairseq."""
num_out = 28
batch_size, num_frames = 3, 1024 batch_size, num_frames = 3, 1024
original = self._get_model(config, num_out).eval() original = self._get_model(config).eval()
imported = import_fairseq_model(original, 28).eval() imported = import_fairseq_model(original).eval()
reloaded = factory_func(num_out=num_out) reloaded = factory_func()
reloaded.load_state_dict(imported.state_dict()) reloaded.load_state_dict(imported.state_dict())
reloaded.eval() reloaded.eval()
......
...@@ -2,6 +2,9 @@ import json ...@@ -2,6 +2,9 @@ import json
import torch import torch
from torchaudio.models.wav2vec2 import ( from torchaudio.models.wav2vec2 import (
wav2vec2_asr_base,
wav2vec2_asr_large,
wav2vec2_asr_large_lv60k,
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
wav2vec2_large_lv60k, wav2vec2_large_lv60k,
...@@ -47,11 +50,11 @@ PRETRAIN_CONFIGS = parameterized.expand([ ...@@ -47,11 +50,11 @@ PRETRAIN_CONFIGS = parameterized.expand([
(HF_BASE_10K_VOXPOPULI, wav2vec2_base), (HF_BASE_10K_VOXPOPULI, wav2vec2_base),
], name_func=_name_func) ], name_func=_name_func)
FINETUNE_CONFIGS = parameterized.expand([ FINETUNE_CONFIGS = parameterized.expand([
(HF_BASE_960H, wav2vec2_base), (HF_BASE_960H, wav2vec2_asr_base),
(HF_LARGE_960H, wav2vec2_large), (HF_LARGE_960H, wav2vec2_asr_large),
(HF_LARGE_LV60_960H, wav2vec2_large_lv60k), (HF_LARGE_LV60_960H, wav2vec2_asr_large_lv60k),
(HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k), (HF_LARGE_LV60_SELF_960H, wav2vec2_asr_large_lv60k),
(HF_LARGE_XLSR_DE, wav2vec2_large_lv60k), (HF_LARGE_XLSR_DE, wav2vec2_asr_large_lv60k),
], name_func=_name_func) ], name_func=_name_func)
...@@ -81,7 +84,7 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -81,7 +84,7 @@ class TestHFIntegration(TorchaudioTestCase):
return Wav2Vec2ForCTC(Wav2Vec2Config(**config)) return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
raise ValueError(f'Unexpected arch: {config["architectures"]}') raise ValueError(f'Unexpected arch: {config["architectures"]}')
def _test_import_pretrain(self, original, imported, config, ): def _test_import_pretrain(self, original, imported, config):
torch.manual_seed(0) torch.manual_seed(0)
# FeatureExtractor # FeatureExtractor
x = torch.randn(3, 1024) x = torch.randn(3, 1024)
...@@ -115,7 +118,7 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -115,7 +118,7 @@ class TestHFIntegration(TorchaudioTestCase):
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
def _test_import_finetune(self, original, imported, config): def _test_import_finetune(self, original, imported, config):
# Readout # Aux
x = torch.randn(3, 10, config["hidden_size"]) x = torch.randn(3, 10, config["hidden_size"])
ref = original.lm_head(x) ref = original.lm_head(x)
hyp = imported.aux(x) hyp = imported.aux(x)
...@@ -193,7 +196,8 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -193,7 +196,8 @@ class TestHFIntegration(TorchaudioTestCase):
ref = imported.encoder.transformer(x) ref = imported.encoder.transformer(x)
hyp = reloaded.encoder.transformer(x) hyp = reloaded.encoder.transformer(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# Readout # Aux
if imported.aux is not None:
x = torch.randn(3, 10, config["hidden_size"]) x = torch.randn(3, 10, config["hidden_size"])
ref = imported.aux(x) ref = imported.aux(x)
hyp = reloaded.aux(x) hyp = reloaded.aux(x)
...@@ -208,7 +212,7 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -208,7 +212,7 @@ class TestHFIntegration(TorchaudioTestCase):
def test_recreate_pretrain(self, config, factory_func): def test_recreate_pretrain(self, config, factory_func):
"""Imported models can be recreated via a factory function without Hugging Face transformers.""" """Imported models can be recreated via a factory function without Hugging Face transformers."""
imported = import_huggingface_model(self._get_model(config)).eval() imported = import_huggingface_model(self._get_model(config)).eval()
reloaded = factory_func(num_out=imported.aux.out_features) reloaded = factory_func()
reloaded.load_state_dict(imported.state_dict()) reloaded.load_state_dict(imported.state_dict())
reloaded.eval() reloaded.eval()
self._test_recreate(imported, reloaded, config) self._test_recreate(imported, reloaded, config)
......
...@@ -2,6 +2,9 @@ import torch ...@@ -2,6 +2,9 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torchaudio.models.wav2vec2 import ( from torchaudio.models.wav2vec2 import (
wav2vec2_asr_base,
wav2vec2_asr_large,
wav2vec2_asr_large_lv60k,
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
wav2vec2_large_lv60k, wav2vec2_large_lv60k,
...@@ -19,16 +22,22 @@ def _name_func(testcase_func, i, param): ...@@ -19,16 +22,22 @@ def _name_func(testcase_func, i, param):
return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}" return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"
factory_funcs = parameterized.expand([ pretrain_factory_funcs = parameterized.expand([
(wav2vec2_base, ), (wav2vec2_base, ),
(wav2vec2_large, ), (wav2vec2_large, ),
(wav2vec2_large_lv60k, ), (wav2vec2_large_lv60k, ),
], name_func=_name_func) ], name_func=_name_func)
finetune_factory_funcs = parameterized.expand([
(wav2vec2_asr_base, ),
(wav2vec2_asr_large, ),
(wav2vec2_asr_large_lv60k, ),
], name_func=_name_func)
class TestWav2Vec2Model(TorchaudioTestCase): class TestWav2Vec2Model(TorchaudioTestCase):
def _smoke_test(self, device, dtype): def _smoke_test(self, model, device, dtype):
model = wav2vec2_base(num_out=32)
model = model.to(device=device, dtype=dtype) model = model.to(device=device, dtype=dtype)
model = model.eval() model = model.eval()
...@@ -44,19 +53,23 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -44,19 +53,23 @@ class TestWav2Vec2Model(TorchaudioTestCase):
@parameterized.expand([(torch.float32, ), (torch.float64, )]) @parameterized.expand([(torch.float32, ), (torch.float64, )])
def test_cpu_smoke_test(self, dtype): def test_cpu_smoke_test(self, dtype):
self._smoke_test(torch.device('cpu'), dtype) model = wav2vec2_base()
self._smoke_test(model, torch.device('cpu'), dtype)
model = wav2vec2_asr_base(num_out=32)
self._smoke_test(model, torch.device('cpu'), dtype)
@parameterized.expand([(torch.float32, ), (torch.float64, )]) @parameterized.expand([(torch.float32, ), (torch.float64, )])
@skipIfNoCuda @skipIfNoCuda
def test_cuda_smoke_test(self, dtype): def test_cuda_smoke_test(self, dtype):
self._smoke_test(torch.device('cuda'), dtype) model = wav2vec2_base()
self._smoke_test(model, torch.device('cuda'), dtype)
model = wav2vec2_asr_base(num_out=32)
self._smoke_test(model, torch.device('cuda'), dtype)
@factory_funcs def _feature_extractor_test(self, model):
def test_feature_extractor_test(self, factory_func):
"""`extract_features` method does not fail"""
batch_size, num_frames = 3, 1024 batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval() model.eval()
num_layers = len(model.encoder.transformer.layers) num_layers = len(model.encoder.transformer.layers)
torch.manual_seed(0) torch.manual_seed(0)
...@@ -80,14 +93,19 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -80,14 +93,19 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self.assertEqual(all_features[i], features[i]) self.assertEqual(all_features[i], features[i])
assert lengths_.shape == torch.Size([batch_size]) assert lengths_.shape == torch.Size([batch_size])
@factory_funcs @pretrain_factory_funcs
def test_batch_consistency(self, factory_func): def test_pretrain_feature_extractor_test(self, factory_func):
"""Results from sigle process and batched process should be reasonably close """`extract_features` method does not fail"""
""" self._feature_extractor_test(factory_func())
batch_size, max_frames = 5, 5 * 1024
model = factory_func(num_out=32).eval() @finetune_factory_funcs
def test_finetune_feature_extractor_test(self, factory_func):
"""`extract_features` method does not fail"""
self._feature_extractor_test(factory_func(num_out=32))
def _test_batch_consistency(self, model):
model.eval()
batch_size, max_frames = 5, 5 * 1024
torch.manual_seed(0) torch.manual_seed(0)
waveforms = torch.randn(batch_size, max_frames) waveforms = torch.randn(batch_size, max_frames)
input_lengths = torch.tensor([i * 3200 for i in range(1, 6)]) input_lengths = torch.tensor([i * 3200 for i in range(1, 6)])
...@@ -105,24 +123,43 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -105,24 +123,43 @@ class TestWav2Vec2Model(TorchaudioTestCase):
# We allow max atol=0.005 -> 0.5% # We allow max atol=0.005 -> 0.5%
self.assertEqual(single_prob, batch_prob, atol=0.005, rtol=0) self.assertEqual(single_prob, batch_prob, atol=0.005, rtol=0)
@factory_funcs @pretrain_factory_funcs
def test_zero_length(self, factory_func): def test_pretrain_batch_consistency(self, factory_func):
"""Passing zero length should not fail""" """Results from single process and batched process should be reasonably close
model = factory_func(num_out=32).eval() """
self._test_batch_consistency(factory_func())
@finetune_factory_funcs
def test_finetune_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close
"""
self._test_batch_consistency(factory_func(num_out=32))
def _test_zero_length(self, model):
model.eval()
torch.manual_seed(0) torch.manual_seed(0)
batch_size = 3 batch_size = 3
waveforms = torch.randn(batch_size, 1024) waveforms = torch.randn(batch_size, 1024)
input_lengths = torch.zeros(batch_size) input_lengths = torch.zeros(batch_size)
_, output_lengths = model(waveforms, input_lengths) _, output_lengths = model(waveforms, input_lengths)
self.assertEqual(torch.zeros_like(output_lengths), output_lengths) self.assertEqual(torch.zeros_like(output_lengths), output_lengths)
_, output_lengths = model.extract_features(waveforms, input_lengths)
self.assertEqual(torch.zeros_like(output_lengths), output_lengths)
@factory_funcs @pretrain_factory_funcs
def test_torchscript(self, factory_func): def test_pretrain_zero_length(self, factory_func):
"""Wav2Vec2Model should be scriptable""" """Passing zero length should not fail"""
batch_size, num_frames = 3, 1024 self._test_zero_length(factory_func())
@finetune_factory_funcs
def test_finetune_zero_length(self, factory_func):
"""Passing zero length should not fail"""
self._test_zero_length(factory_func(num_out=32))
model = factory_func(num_out=32).eval() def _test_torchscript(self, model):
model.eval()
batch_size, num_frames = 3, 1024
torch.manual_seed(0) torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames) waveforms = torch.randn(batch_size, num_frames)
...@@ -137,13 +174,19 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -137,13 +174,19 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self.assertEqual(hyp_out, ref_out) self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len) self.assertEqual(hyp_len, ref_len)
@factory_funcs @pretrain_factory_funcs
@skipIfNoQengine def test_pretrain_torchscript(self, factory_func):
def test_quantize(self, factory_func): """Wav2Vec2Model should be scriptable"""
"""Wav2Vec2Model should support basic quantization""" self._test_torchscript(factory_func())
batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval() @finetune_factory_funcs
def test_finetune_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable"""
self._test_torchscript(factory_func(num_out=32))
def _test_quantize_smoke_test(self, model):
model.eval()
batch_size, num_frames = 3, 1024
# Remove the weight normalization forward hook # Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__() model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
...@@ -159,13 +202,22 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -159,13 +202,22 @@ class TestWav2Vec2Model(TorchaudioTestCase):
_, _ = quantized(waveforms, lengths) _, _ = quantized(waveforms, lengths)
@factory_funcs @pretrain_factory_funcs
@skipIfNoQengine @skipIfNoQengine
def test_quantize_torchscript(self, factory_func): def test_pretrain_quantize(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable""" """Wav2Vec2Model should support basic quantization"""
batch_size, num_frames = 3, 1024 self._test_quantize_smoke_test(factory_func())
model = factory_func(num_out=32).eval() @finetune_factory_funcs
@skipIfNoQengine
def test_finetune_quantize(self, factory_func):
"""Wav2Vec2Model should support basic quantization"""
self._test_quantize_smoke_test(factory_func(num_out=32))
def _test_quantize_torchscript(self, model):
model.eval()
batch_size, num_frames = 3, 1024
# Remove the weight normalization forward hook # Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__() model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
...@@ -188,3 +240,15 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -188,3 +240,15 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self.assertEqual(hyp_out, ref_out) self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len) self.assertEqual(hyp_len, ref_len)
@pretrain_factory_funcs
@skipIfNoQengine
def test_pretrain_quantize_torchscript(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable"""
self._test_quantize_torchscript(factory_func())
@finetune_factory_funcs
@skipIfNoQengine
def test_finetune_quantize_torchscript(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable"""
self._test_quantize_torchscript(factory_func(num_out=32))
...@@ -5,6 +5,9 @@ from .deepspeech import DeepSpeech ...@@ -5,6 +5,9 @@ from .deepspeech import DeepSpeech
from .tacotron2 import Tacotron2, tacotron2 from .tacotron2 import Tacotron2, tacotron2
from .wav2vec2 import ( from .wav2vec2 import (
Wav2Vec2Model, Wav2Vec2Model,
wav2vec2_asr_base,
wav2vec2_asr_large,
wav2vec2_asr_large_lv60k,
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
wav2vec2_large_lv60k, wav2vec2_large_lv60k,
...@@ -18,6 +21,9 @@ __all__ = [ ...@@ -18,6 +21,9 @@ __all__ = [
'ConvTasNet', 'ConvTasNet',
'DeepSpeech', 'DeepSpeech',
'Wav2Vec2Model', 'Wav2Vec2Model',
'wav2vec2_asr_base',
'wav2vec2_asr_large',
'wav2vec2_asr_large_lv60k',
'wav2vec2_base', 'wav2vec2_base',
'wav2vec2_large', 'wav2vec2_large',
'wav2vec2_large_lv60k', 'wav2vec2_large_lv60k',
......
from .model import ( from .model import (
Wav2Vec2Model, Wav2Vec2Model,
wav2vec2_asr_base,
wav2vec2_asr_large,
wav2vec2_asr_large_lv60k,
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
wav2vec2_large_lv60k, wav2vec2_large_lv60k,
...@@ -8,6 +11,9 @@ from . import utils ...@@ -8,6 +11,9 @@ from . import utils
__all__ = [ __all__ = [
'Wav2Vec2Model', 'Wav2Vec2Model',
'wav2vec2_asr_base',
'wav2vec2_asr_large',
'wav2vec2_asr_large_lv60k',
'wav2vec2_base', 'wav2vec2_base',
'wav2vec2_large', 'wav2vec2_large',
'wav2vec2_large_lv60k', 'wav2vec2_large_lv60k',
......
...@@ -617,8 +617,6 @@ def _get_encoder( ...@@ -617,8 +617,6 @@ def _get_encoder(
Probability to drop each encoder layer during training. Probability to drop each encoder layer during training.
This option corresponds to "layerdrop" from fairseq. This option corresponds to "layerdrop" from fairseq.
Expected values are 0.1 for both Base and Large arch. Expected values are 0.1 for both Base and Large arch.
num_out (int):
The dimension of the output. The number of labels.
See Also: See Also:
* "encoder_embed_dim" * "encoder_embed_dim"
......
...@@ -116,7 +116,7 @@ def _get_model( ...@@ -116,7 +116,7 @@ def _get_model(
encoder_dropout: float, encoder_dropout: float,
encoder_layer_norm_first: bool, encoder_layer_norm_first: bool,
encoder_layer_drop: float, encoder_layer_drop: float,
aux_num_out: int, aux_num_out: Optional[int],
) -> Wav2Vec2Model: ) -> Wav2Vec2Model:
if extractor_conv_layer_config is None: if extractor_conv_layer_config is None:
extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
...@@ -138,34 +138,53 @@ def _get_model( ...@@ -138,34 +138,53 @@ def _get_model(
layer_norm_first=encoder_layer_norm_first, layer_norm_first=encoder_layer_norm_first,
layer_drop=encoder_layer_drop, layer_drop=encoder_layer_drop,
) )
aux = torch.nn.Linear( aux = None
in_features=encoder_embed_dim, if aux_num_out is not None:
out_features=aux_num_out, aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
)
return Wav2Vec2Model(feature_extractor, encoder, aux) return Wav2Vec2Model(feature_extractor, encoder, aux)
def wav2vec2_base(num_out: int) -> Wav2Vec2Model: def wav2vec2_base() -> Wav2Vec2Model:
"""Build wav2vec2.0 model with "Base" configuration from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]. """Build wav2vec2 model with "base" configuration
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for pretraining.
Returns:
Wav2Vec2Model:
"""
return _get_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=768,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=12,
encoder_num_heads=12,
encoder_attention_dropout=0.1,
encoder_ff_interm_features=3072,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.1,
encoder_layer_norm_first=False,
encoder_layer_drop=0.1,
aux_num_out=None,
)
def wav2vec2_asr_base(num_out: int) -> Wav2Vec2Model:
"""Build "base" wav2vec2 with an extra linear module
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
Args: Args:
num_out: int num_out: int
The number of output labels. The number of output labels.
Returns: Returns:
Wav2Vec2Model: The resulting model. Wav2Vec2Model:
Example - Reload fine-tuned model from Hugging Face:
>>> # Session 1 - Convert pretrained model from Hugging Face and save the parameters.
>>> from torchaudio.models.wav2vec2.utils import import_huggingface_model
>>>
>>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = import_huggingface_model(original)
>>> torch.save(model.state_dict(), "wav2vec2-base-960h.pt")
>>>
>>> # Session 2 - Load model and the parameters
>>> model = wav2vec2_base(num_out=32)
>>> model.load_state_dict(torch.load("wav2vec2-base-960h.pt"))
""" """
return _get_model( return _get_model(
extractor_mode="group_norm", extractor_mode="group_norm",
...@@ -187,27 +206,47 @@ def wav2vec2_base(num_out: int) -> Wav2Vec2Model: ...@@ -187,27 +206,47 @@ def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
) )
def wav2vec2_large(num_out: int) -> Wav2Vec2Model: def wav2vec2_large() -> Wav2Vec2Model:
"""Build wav2vec2.0 model with "Large" configuration from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]. """Build wav2vec2 model with "large" configuration
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for pretraining.
Returns:
Wav2Vec2Model:
"""
return _get_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1024,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=0.1,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.1,
encoder_layer_norm_first=False,
encoder_layer_drop=0.1,
aux_num_out=None,
)
def wav2vec2_asr_large(num_out: int) -> Wav2Vec2Model:
"""Build "large" wav2vec2.0 model with an extra linear module
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
Args: Args:
num_out: int num_out: int
The number of output labels. The number of output labels.
Returns: Returns:
Wav2Vec2Model: The resulting model. Wav2Vec2Model:
Example - Reload fine-tuned model from Hugging Face:
>>> # Session 1 - Convert pretrained model from Hugging Face and save the parameters.
>>> from torchaudio.models.wav2vec2.utils import import_huggingface_model
>>>
>>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
>>> model = import_huggingface_model(original)
>>> torch.save(model.state_dict(), "wav2vec2-large-960h.pt")
>>>
>>> # Session 2 - Load model and the parameters
>>> model = wav2vec2_asr_large(num_out=32)
>>> model.load_state_dict(torch.load("wav2vec2-large-960h.pt"))
""" """
return _get_model( return _get_model(
extractor_mode="group_norm", extractor_mode="group_norm",
...@@ -229,8 +268,40 @@ def wav2vec2_large(num_out: int) -> Wav2Vec2Model: ...@@ -229,8 +268,40 @@ def wav2vec2_large(num_out: int) -> Wav2Vec2Model:
) )
def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model: def wav2vec2_large_lv60k() -> Wav2Vec2Model:
"""Build wav2vec2.0 model with "Large LV-60k" configuration from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]. """Build wav2vec2.0 model with "Large LV-60k" configuration
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for pretraining.
Returns:
Wav2Vec2Model: The resulting model.
"""
return _get_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=True,
encoder_embed_dim=1024,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=0.0,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.0,
encoder_layer_norm_first=True,
encoder_layer_drop=0.1,
aux_num_out=None,
)
def wav2vec2_asr_large_lv60k(num_out: int) -> Wav2Vec2Model:
"""Build "Large LV-60k" wav2vec2.0 with an extra linear module
This is one of the model architectures used in *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.
Args: Args:
num_out: int num_out: int
...@@ -238,18 +309,6 @@ def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model: ...@@ -238,18 +309,6 @@ def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model:
Returns: Returns:
Wav2Vec2Model: The resulting model. Wav2Vec2Model: The resulting model.
Example - Reload fine-tuned model from Hugging Face:
>>> # Session 1 - Convert pretrained model from Hugging Face and save the parameters.
>>> from torchaudio.models.wav2vec2.utils import import_huggingface_model
>>>
>>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
>>> model = import_huggingface_model(original)
>>> torch.save(model.state_dict(), "wav2vec2-base-960h.pt")
>>>
>>> # Session 2 - Load model and the parameters
>>> model = wav2vec2_large_lv60k(num_out=32)
>>> model.load_state_dict(torch.load("wav2vec2-base-960h.pt"))
""" """
return _get_model( return _get_model(
extractor_mode="layer_norm", extractor_mode="layer_norm",
......
...@@ -3,14 +3,13 @@ ...@@ -3,14 +3,13 @@
For this module to work, you need `fairseq`. For this module to work, you need `fairseq`.
""" """
import re import re
from typing import Optional
from torch.nn import Module from torch.nn import Module
from ..model import Wav2Vec2Model, _get_model from ..model import Wav2Vec2Model, _get_model
def _parse_config(w2v_model, num_out): def _parse_config(w2v_model):
encoder = w2v_model.encoder encoder = w2v_model.encoder
conv_layers = w2v_model.feature_extractor.conv_layers conv_layers = w2v_model.feature_extractor.conv_layers
...@@ -46,7 +45,6 @@ def _parse_config(w2v_model, num_out): ...@@ -46,7 +45,6 @@ def _parse_config(w2v_model, num_out):
'encoder_dropout': encoder.layers[0].dropout3.p, 'encoder_dropout': encoder.layers[0].dropout3.p,
'encoder_layer_norm_first': encoder.layer_norm_first, 'encoder_layer_norm_first': encoder.layer_norm_first,
'encoder_layer_drop': encoder.layerdrop, 'encoder_layer_drop': encoder.layerdrop,
'aux_num_out': num_out,
} }
return config return config
...@@ -108,7 +106,8 @@ def _map_key(key): ...@@ -108,7 +106,8 @@ def _map_key(key):
if match: if match:
return f"encoder.transformer.layers.{match.group(1)}.final_layer_norm.{match.group(2)}" return f"encoder.transformer.layers.{match.group(1)}.final_layer_norm.{match.group(2)}"
match = re.match(r"proj\.(weight|bias)", key) match = re.match(r"proj\.(weight|bias)", key)
# Encoder - Readout layer # Auxiliary Module
# Only relevant when loading fine-tuned models
if match: if match:
return f"aux.{match.group(1)}" return f"aux.{match.group(1)}"
raise ValueError(f'Unexpected key: {key_}') raise ValueError(f'Unexpected key: {key_}')
...@@ -123,9 +122,7 @@ def _convert_state_dict(state_dict): ...@@ -123,9 +122,7 @@ def _convert_state_dict(state_dict):
return converted return converted
def import_fairseq_model( def import_fairseq_model(original: Module) -> Wav2Vec2Model:
original: Module,
num_out: Optional[int] = None) -> Wav2Vec2Model:
"""Build Wav2Vec2Model from pretrained parameters published by `fairseq`_. """Build Wav2Vec2Model from pretrained parameters published by `fairseq`_.
Args: Args:
...@@ -133,9 +130,6 @@ def import_fairseq_model( ...@@ -133,9 +130,6 @@ def import_fairseq_model(
An instance of fairseq's Wav2Vec2.0 model class. An instance of fairseq's Wav2Vec2.0 model class.
Either ``fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder`` or Either ``fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder`` or
``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model``. ``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model``.
num_out (int or None, optional):
The number of output labels. Required only when the original model is
an instance of ``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model``.
Returns: Returns:
Wav2Vec2Model: Imported model. Wav2Vec2Model: Imported model.
...@@ -147,7 +141,7 @@ def import_fairseq_model( ...@@ -147,7 +141,7 @@ def import_fairseq_model(
>>> model_file = 'wav2vec_small.pt' >>> model_file = 'wav2vec_small.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file]) >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0] >>> original = model[0]
>>> imported = import_fairseq_model(original, num_out=28) >>> imported = import_fairseq_model(original)
>>> >>>
>>> # Perform feature extraction >>> # Perform feature extraction
>>> waveform, _ = torchaudio.load('audio.wav') >>> waveform, _ = torchaudio.load('audio.wav')
...@@ -179,12 +173,7 @@ def import_fairseq_model( ...@@ -179,12 +173,7 @@ def import_fairseq_model(
""" """
class_ = original.__class__.__name__ class_ = original.__class__.__name__
if class_ == 'Wav2Vec2Model': if class_ == 'Wav2Vec2Model':
if num_out is None: return _import_pretrained(original)
raise ValueError(
'When importing a pretrained model without readout layer, '
'`num_out` argument must be given.'
)
return _import_pretrained(original, num_out)
if class_ == 'Wav2VecEncoder': if class_ == 'Wav2VecEncoder':
return _import_finetuned(original) return _import_finetuned(original)
raise ValueError( raise ValueError(
...@@ -192,14 +181,14 @@ def import_fairseq_model( ...@@ -192,14 +181,14 @@ def import_fairseq_model(
def _import_finetuned(original: Module) -> Wav2Vec2Model: def _import_finetuned(original: Module) -> Wav2Vec2Model:
config = _parse_config(original.w2v_model, original.proj.out_features) config = _parse_config(original.w2v_model)
model = _get_model(**config) model = _get_model(**config, aux_num_out=original.proj.out_features)
model.load_state_dict(_convert_state_dict(original.state_dict())) model.load_state_dict(_convert_state_dict(original.state_dict()))
return model return model
def _import_pretrained(original: Module, num_out: int) -> Wav2Vec2Model: def _import_pretrained(original: Module) -> Wav2Vec2Model:
config = _parse_config(original, num_out) config = _parse_config(original)
model = _get_model(**config) model = _get_model(**config, aux_num_out=None)
model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False) model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
return model return model
...@@ -26,18 +26,21 @@ def _get_config(cfg): ...@@ -26,18 +26,21 @@ def _get_config(cfg):
'encoder_dropout': cfg.hidden_dropout, 'encoder_dropout': cfg.hidden_dropout,
'encoder_layer_norm_first': cfg.do_stable_layer_norm, 'encoder_layer_norm_first': cfg.do_stable_layer_norm,
'encoder_layer_drop': cfg.layerdrop, 'encoder_layer_drop': cfg.layerdrop,
'aux_num_out': cfg.vocab_size,
} }
return config return config
def _build(config, original): def _build(config, original):
if original.__class__.__name__ == 'Wav2Vec2ForCTC': if original.__class__.__name__ == 'Wav2Vec2ForCTC':
aux_num_out = original.config.vocab_size
wav2vec2 = original.wav2vec2 wav2vec2 = original.wav2vec2
else: else:
_LG.warning(
'The model is not an instance of Wav2Vec2ForCTC. '
'"lm_head" module is not imported.')
aux_num_out = None
wav2vec2 = original wav2vec2 = original
imported = _get_model(**config, aux_num_out=aux_num_out)
imported = _get_model(**config)
imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict()) imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict()) imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict()) imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
...@@ -67,8 +70,6 @@ def import_huggingface_model(original: Module) -> Wav2Vec2Model: ...@@ -67,8 +70,6 @@ def import_huggingface_model(original: Module) -> Wav2Vec2Model:
.. _Transformers: https://huggingface.co/transformers/ .. _Transformers: https://huggingface.co/transformers/
""" """
_LG.info('Importing model.') _LG.info('Importing model.')
if original.__class__.__name__ != 'Wav2Vec2ForCTC':
_LG.warning('The model is not an instance of Wav2Vec2ForCTC')
_LG.info('Loading model configuration.') _LG.info('Loading model configuration.')
config = _get_config(original.config) config = _get_config(original.config)
_LG.debug(' - config: %s', config) _LG.debug(' - config: %s', config)
......