Unverified Commit c8239c64 authored by moto, committed by GitHub

Add wav2vec2 HuggingFace importer (#1530)

parent e6886a4d
@@ -56,5 +56,5 @@ fi
(
set -x
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
-  pip install kaldi-io SoundFile coverage pytest pytest-cov scipy
+  pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers
)
@@ -44,5 +44,5 @@ fi
(
set -x
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
-  pip install kaldi-io SoundFile coverage pytest pytest-cov scipy
+  pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers
)
@@ -55,6 +55,14 @@ Factory Functions
.. autofunction:: wav2vec2_large_lv60k

.. currentmodule:: torchaudio.models.wav2vec2.utils

Utility Functions
-----------------

.. autofunction:: import_huggingface_model
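
A minimal usage sketch (the checkpoint name mirrors the example in the
function's docstring; the ``transformers`` package must be installed)::

    import torchaudio
    from transformers import Wav2Vec2ForCTC
    from torchaudio.models.wav2vec2.utils import import_huggingface_model

    original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    model = import_huggingface_model(original)

    waveform, _ = torchaudio.load("audio.wav")
    logits, _ = model(waveform)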

.. currentmodule:: torchaudio.models

:hidden:`WaveRNN`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
{
"activation_dropout": 0.1,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2Model"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": false,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": false,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "group",
"feat_proj_dropout": 0.1,
"final_dropout": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.05,
"model_type": "wav2vec2",
"num_attention_heads": 12,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 12,
"pad_token_id": 0,
"transformers_version": "4.5.1",
"vocab_size": 32
}
{
"activation_dropout": 0.1,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2ForCTC"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": false,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": false,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "group",
"feat_proj_dropout": 0.1,
"final_dropout": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.05,
"model_type": "wav2vec2",
"num_attention_heads": 12,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 12,
"pad_token_id": 0,
"transformers_version": "4.5.1",
"vocab_size": 32
}
{
"activation_dropout": 0.0,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2Model"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": false,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": false,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_norm": "group",
"feat_proj_dropout": 0.1,
"final_dropout": 0.0,
"freeze_feat_extract_train": true,
"gradient_checkpointing": true,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"layerdrop": 0.05,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_space": 1,
"mask_time_other": 0.0,
"mask_time_prob": 0.05,
"mask_time_selection": "static",
"model_type": "wav2vec2",
"no_mask_channel_overlap": false,
"no_mask_time_overlap": false,
"num_attention_heads": 12,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 12,
"pad_token_id": 0,
"transformers_version": "4.5.1",
"vocab_size": 32
}
{
"activation_dropout": 0.1,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2ForCTC"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"final_dropout": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.05,
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 0,
"transformers_version": "4.5.1",
"vocab_size": 32
}
{
"activation_dropout": 0.1,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2ForCTC"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"final_dropout": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.05,
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 0,
"transformers_version": "4.5.1",
"vocab_size": 32
}
{
"activation_dropout": 0.1,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2ForCTC"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": false,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": false,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "group",
"feat_proj_dropout": 0.1,
"final_dropout": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.05,
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 0,
"transformers_version": "4.5.1",
"vocab_size": 32
}
{
"activation_dropout": 0.1,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2Model"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"final_dropout": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.05,
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 0,
"transformers_version": "4.5.1",
"vocab_size": 32
}
{
"activation_dropout": 0.1,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2ForCTC"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"final_dropout": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.05,
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 0,
"transformers_version": "4.5.1",
"vocab_size": 36
}
{
"activation_dropout": 0.0,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2Model"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"final_dropout": 0.0,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_space": 1,
"mask_time_other": 0.0,
"mask_time_prob": 0.075,
"mask_time_selection": "static",
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 0,
"transformers_version": "4.5.1",
"vocab_size": 32
}
{
"activation_dropout": 0.1,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2Model"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": false,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": false,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "group",
"feat_proj_dropout": 0.1,
"final_dropout": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.05,
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 0,
"transformers_version": "4.5.1",
"vocab_size": 32
}
import os
import json

from transformers import Wav2Vec2Model

_THIS_DIR = os.path.dirname(os.path.abspath(__file__))


def _main():
    keys = [
        # pretrained
        "facebook/wav2vec2-base",
        "facebook/wav2vec2-large",
        "facebook/wav2vec2-large-lv60",
        "facebook/wav2vec2-base-10k-voxpopuli",
        "facebook/wav2vec2-large-xlsr-53",
        # finetuned
        "facebook/wav2vec2-base-960h",
        "facebook/wav2vec2-large-960h",
        "facebook/wav2vec2-large-960h-lv60",
        "facebook/wav2vec2-large-960h-lv60-self",
        "facebook/wav2vec2-large-xlsr-53-german",
    ]
    for key in keys:
        path = os.path.join(_THIS_DIR, f'{key}.json')
        print('Generating ', path)
        cfg = Wav2Vec2Model.from_pretrained(key).config
        cfg = json.loads(cfg.to_json_string())
        del cfg['_name_or_path']
        with open(path, 'w') as file_:
            file_.write(json.dumps(cfg, indent=4, sort_keys=True))
            file_.write('\n')


if __name__ == '__main__':
    _main()
import json

import torch
from torchaudio.models.wav2vec2 import (
    wav2vec2_base,
    wav2vec2_large,
    wav2vec2_large_lv60k,
)
from torchaudio.models.wav2vec2.utils import import_huggingface_model
from parameterized import parameterized

from torchaudio_unittest.common_utils import (
    get_asset_path,
    skipIfNoModule,
    TorchaudioTestCase,
)


def _load_config(*paths):
    with open(f'{get_asset_path("wav2vec2", "huggingface", *paths)}.json', 'r') as file_:
        return json.load(file_)


# Pretrained
HF_BASE = _load_config('facebook', 'wav2vec2-base')
HF_LARGE = _load_config('facebook', 'wav2vec2-large')
HF_LARGE_LV60 = _load_config('facebook', 'wav2vec2-large-lv60')
HF_LARGE_XLSR_53 = _load_config('facebook', 'wav2vec2-large-xlsr-53')
HF_BASE_10K_VOXPOPULI = _load_config('facebook', 'wav2vec2-base-10k-voxpopuli')
# Finetuned
HF_BASE_960H = _load_config('facebook', 'wav2vec2-base-960h')
HF_LARGE_960H = _load_config('facebook', 'wav2vec2-large-960h')
HF_LARGE_LV60_960H = _load_config('facebook', 'wav2vec2-large-960h-lv60')
HF_LARGE_LV60_SELF_960H = _load_config('facebook', 'wav2vec2-large-960h-lv60-self')
HF_LARGE_XLSR_DE = _load_config('facebook', 'wav2vec2-large-xlsr-53-german')

# Config and corresponding factory functions
HF_CONFIGS = [
    # pretrained
    (HF_BASE, wav2vec2_base),
    (HF_LARGE, wav2vec2_large),
    (HF_LARGE_LV60, wav2vec2_large_lv60k),
    (HF_LARGE_XLSR_53, wav2vec2_large_lv60k),
    (HF_BASE_10K_VOXPOPULI, wav2vec2_base),
    # finetuned
    (HF_BASE_960H, wav2vec2_base),
    (HF_LARGE_960H, wav2vec2_large),
    (HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
    (HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k),
    (HF_LARGE_XLSR_DE, wav2vec2_large_lv60k),
]
@skipIfNoModule('transformers')
class TestHFIntegration(TorchaudioTestCase):
    """Test the process of importing the models from Hugging Face Transformers.

    Test methods in this test suite check the following things:
    1. Models loaded with Hugging Face Transformers can be imported.
    2. The same model can be recreated without Hugging Face Transformers.
    """
    def _get_model(self, config):
        # Helper function to avoid importing transformers at module scope.
        # Normally, we use the `is_module_available` helper function to check if
        # the library is available, and import it at module scope if available.
        # However, somehow, once "transformers" is imported, `is_module_available`
        # starts to fail. Therefore, we defer importing "transformers" until
        # the actual tests are started.
        from transformers.models.wav2vec2 import (
            Wav2Vec2Config,
            Wav2Vec2ForCTC,
        )
        return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
    @parameterized.expand([cfg[:1] for cfg in HF_CONFIGS])
    def test_import(self, config):
        """wav2vec2 models from HF transformers can be imported and yield the same results"""
        original = self._get_model(config).eval()
        imported = import_huggingface_model(original).eval()

        torch.manual_seed(0)
        # FeatureExtractor
        x = torch.randn(3, 1024)
        ref = original.wav2vec2.feature_extractor(x).transpose(1, 2)
        hyp, _ = imported.feature_extractor(x, None)
        self.assertEqual(ref, hyp)
        # Feature projection
        x = torch.randn(3, 10, config['conv_dim'][-1])
        ref = original.wav2vec2.feature_projection(x)
        hyp = imported.encoder.feature_projection(x)
        self.assertEqual(ref, hyp)
        # Convolutional Positional Encoder
        x = torch.randn(3, 256, config['hidden_size'])
        ref = original.wav2vec2.encoder.pos_conv_embed(x)
        hyp = imported.encoder.transformer.pos_conv_embed(x)
        self.assertEqual(ref, hyp)
        # Encoder Transformer Layer
        for original_, imported_ in zip(original.wav2vec2.encoder.layers, imported.encoder.transformer.layers):
            b, l, e = 16, 3, config["hidden_size"]
            x = torch.randn(b, l, e)
            mask = torch.randn(b, 1, l, l)
            ref, = original_(x, attention_mask=mask, output_attentions=False)
            hyp = imported_(x, mask)
            self.assertEqual(ref, hyp)
        # The whole Encoder Transformer
        b, l, e = 16, 3, config["hidden_size"]
        x = torch.randn(b, l, e)
        ref = original.wav2vec2.encoder(x).last_hidden_state
        hyp = imported.encoder.transformer(x)
        self.assertEqual(ref, hyp)
        # Readout
        x = torch.randn(3, 10, config["hidden_size"])
        ref = original.lm_head(x)
        hyp = imported.encoder.readout(x)
        self.assertEqual(ref, hyp)
        # The whole model without mask
        x = torch.randn(3, 1024)
        ref = original(x).logits
        hyp, _ = imported(x)
        self.assertEqual(ref, hyp)
        # The whole model without mask
        batch_size, num_frames = 3, 1024
        x = torch.randn(batch_size, num_frames)
        ref = original(x).logits
        hyp, _ = imported(x)
        self.assertEqual(ref, hyp)
        # The whole model with mask
        batch_size, num_frames = 3, 1024
        x = torch.randn(batch_size, num_frames)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
        mask = torch.arange(num_frames).expand(batch_size, num_frames) < lengths[:, None]
        ref = original(x, attention_mask=mask).logits
        hyp, output_lengths = imported(x, lengths)
        for i, l in enumerate(output_lengths):
            self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...])
    @parameterized.expand(HF_CONFIGS)
    def test_recreate(self, config, factory_func):
        """Imported models can be recreated via a factory function without Hugging Face transformers."""
        imported = import_huggingface_model(self._get_model(config)).eval()

        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
        reloaded.load_state_dict(imported.state_dict())
        reloaded.eval()

        torch.manual_seed(0)
        # FeatureExtractor
        x = torch.randn(3, 1024)
        ref, _ = imported.feature_extractor(x, None)
        hyp, _ = reloaded.feature_extractor(x, None)
        self.assertEqual(ref, hyp)
        # Feature projection
        x = torch.randn(3, 10, config['conv_dim'][-1])
        ref = imported.encoder.feature_projection(x)
        hyp = reloaded.encoder.feature_projection(x)
        self.assertEqual(ref, hyp)
        # Convolutional Positional Encoder
        x = torch.randn(3, 256, config['hidden_size'])
        ref = imported.encoder.transformer.pos_conv_embed(x)
        hyp = reloaded.encoder.transformer.pos_conv_embed(x)
        self.assertEqual(ref, hyp)
        # Encoder Transformer Layer
        for imported_, reloaded_ in zip(imported.encoder.transformer.layers, reloaded.encoder.transformer.layers):
            b, l, e = 16, 3, config["hidden_size"]
            x = torch.randn(b, l, e)
            mask = torch.randn(b, 1, l, l)
            ref = imported_(x, mask)
            hyp = reloaded_(x, mask)
            self.assertEqual(ref, hyp)
        # The whole Encoder Transformer
        # TODO: Add mask pattern. Expected mask shapes and values are different.
        b, l, e = 16, 3, config["hidden_size"]
        x = torch.randn(b, l, e)
        mask = torch.randn(b, 1, l, l)
        ref = imported.encoder.transformer(x)
        hyp = reloaded.encoder.transformer(x)
        self.assertEqual(ref, hyp)
        # Readout
        x = torch.randn(3, 10, config["hidden_size"])
        ref = imported.encoder.readout(x)
        hyp = reloaded.encoder.readout(x)
        self.assertEqual(ref, hyp)
        # The whole model
        x = torch.randn(3, 1024)
        ref, _ = imported(x)
        hyp, _ = reloaded(x)
        self.assertEqual(ref, hyp)
@@ -4,10 +4,12 @@ from .model import (
    wav2vec2_large,
    wav2vec2_large_lv60k,
)
from . import utils

__all__ = [
    'Wav2Vec2Model',
    'wav2vec2_base',
    'wav2vec2_large',
    'wav2vec2_large_lv60k',
    'utils',
]
from .import_huggingface import import_huggingface_model

__all__ = [
    'import_huggingface_model',
]
"""Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format.
"""
import logging
from torch.nn import Module
from ..model import Wav2Vec2Model, _get_model
_LG = logging.getLogger(__name__)
def _get_config(cfg):
config = {
'extractor_mode': f'{cfg.feat_extract_norm}_norm',
'extractor_conv_layer_config': list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
'extractor_conv_bias': cfg.conv_bias,
'encoder_embed_dim': cfg.hidden_size,
'encoder_projection_dropout': cfg.feat_proj_dropout,
'encoder_pos_conv_kernel': cfg.num_conv_pos_embeddings,
'encoder_pos_conv_groups': cfg.num_conv_pos_embedding_groups,
'encoder_num_layers': cfg.num_hidden_layers,
'encoder_num_heads': cfg.num_attention_heads,
'encoder_attention_dropout': cfg.attention_dropout,
'encoder_ff_interm_features': cfg.intermediate_size,
'encoder_ff_interm_dropout': cfg.activation_dropout,
'encoder_dropout': cfg.hidden_dropout,
'encoder_layer_norm_first': cfg.do_stable_layer_norm,
'encoder_layer_drop': cfg.layerdrop,
'encoder_num_out': cfg.vocab_size,
}
return config
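
# Illustration (editor's note, not part of the importer): for the "facebook/wav2vec2-base"
# config included in the test assets above, `_get_config` zips `conv_dim`, `conv_kernel`
# and `conv_stride` into the feature-extractor layer spec:
#
#   >>> list(zip([512] * 7, [10, 3, 3, 3, 3, 2, 2], [5, 2, 2, 2, 2, 2, 2]))
#   [(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 2, 2), (512, 2, 2)]
#
# so the base model maps to extractor_mode='group_norm', that layer spec,
# encoder_embed_dim=768, encoder_num_layers=12, encoder_num_heads=12, and
# encoder_num_out=32 (the CTC vocabulary size).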

def _build(config, original):
    imported = _get_model(**config)
    imported.feature_extractor.load_state_dict(original.wav2vec2.feature_extractor.state_dict())
    imported.encoder.feature_projection.load_state_dict(original.wav2vec2.feature_projection.state_dict())
    imported.encoder.transformer.load_state_dict(original.wav2vec2.encoder.state_dict())
    imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
    return imported

def import_huggingface_model(original: Module) -> Wav2Vec2Model:
    """Import wav2vec2 model from Hugging Face's `Transformers`_.

    Args:
        original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``.

    Returns:
        Wav2Vec2Model: Imported model.

    Example
        >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
        >>> model = import_huggingface_model(original)
        >>>
        >>> waveforms, _ = torchaudio.load("audio.wav")
        >>> logits, _ = model(waveforms)

    .. _Transformers: https://huggingface.co/transformers/
    """
    _LG.info('Importing model.')
    if original.__class__.__name__ != 'Wav2Vec2ForCTC':
        _LG.warning('The model is not an instance of Wav2Vec2ForCTC')
    _LG.info('Loading model configuration.')
    config = _get_config(original.config)
    _LG.debug(' - config: %s', config)
    _LG.info('Building model.')
    imported = _build(config, original)
    return imported
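Taken together, a hedged end-to-end sketch of the workflow this patch enables (the checkpoint name comes from the docstring example; the output path and dummy waveform are illustrative, and the transformers package must be installed): import a Hugging Face checkpoint, save the state dict, then rebuild the same model through a torchaudio factory function so transformers is no longer needed at load time.

import torch
from transformers import Wav2Vec2ForCTC
from torchaudio.models.wav2vec2 import wav2vec2_base
from torchaudio.models.wav2vec2.utils import import_huggingface_model

# Import the fine-tuned Hugging Face checkpoint into torchaudio's Wav2Vec2Model.
original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
imported = import_huggingface_model(original).eval()

# Persist only the state dict; "wav2vec2.pt" is an illustrative path.
torch.save(imported.state_dict(), "wav2vec2.pt")

# Later, recreate the model without transformers, as test_recreate does.
model = wav2vec2_base(num_out=imported.encoder.readout.out_features)
model.load_state_dict(torch.load("wav2vec2.pt"))
model.eval()

# Run inference: the model returns (logits, lengths), as in the importer's docstring.
waveform = torch.randn(1, 16000)  # placeholder for a real 16 kHz waveform
logits, _ = model(waveform)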