Unverified Commit e9cab8f8 authored by moto, committed by GitHub

Fix HF model integration (#1781)

* Fix HF model integration

Previously, when testing wav2vec2 models from HF transformers, all the models were
instantiated as the `Wav2Vec2ForCTC` class, while some of them were supposed to be
`Wav2Vec2Model`.

Fixing this revealed that the model importer could not correctly handle importing `Wav2Vec2Model`.

This PR fixes these issues.
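
As an illustration (not part of this commit), here is a minimal sketch of how the fixed
import path is exercised end to end. The checkpoint name is only illustrative, and the
sketch assumes the import utility location of this era
(`torchaudio.models.wav2vec2.utils.import_huggingface_model`):

    import torch
    from transformers import Wav2Vec2Model  # checkpoint class without a CTC head
    from torchaudio.models.wav2vec2.utils import import_huggingface_model

    # Load a pretrained encoder-only model; before this fix the importer
    # assumed every input was a `Wav2Vec2ForCTC` instance.
    original = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')  # illustrative checkpoint
    imported = import_huggingface_model(original).eval()

    # The imported torchaudio model takes a waveform and returns (output, lengths).
    waveform = torch.randn(1, 16000)
    output, _ = imported(waveform)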
parent 1b4b82e0
@@ -39,14 +39,14 @@ HF_LARGE_LV60_SELF_960H = _load_config('facebook', 'wav2vec2-large-960h-lv60-self')
 HF_LARGE_XLSR_DE = _load_config('facebook', 'wav2vec2-large-xlsr-53-german')
 # Config and corresponding factory functions
-HF_CONFIGS = parameterized.expand([
-    # pretrained
+PRETRAIN_CONFIGS = parameterized.expand([
     (HF_BASE, wav2vec2_base),
     (HF_LARGE, wav2vec2_large),
     (HF_LARGE_LV60, wav2vec2_large_lv60k),
     (HF_LARGE_XLSR_53, wav2vec2_large_lv60k),
     (HF_BASE_10K_VOXPOPULI, wav2vec2_base),
-    # finetuned
+], name_func=_name_func)
+FINETUNE_CONFIGS = parameterized.expand([
     (HF_BASE_960H, wav2vec2_base),
     (HF_LARGE_960H, wav2vec2_large),
     (HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
@@ -72,34 +72,34 @@ class TestHFIntegration(TorchaudioTestCase):
         # the actual tests are started.
         from transformers.models.wav2vec2 import (
             Wav2Vec2Config,
+            Wav2Vec2Model,
             Wav2Vec2ForCTC,
         )
-        return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
+        if config['architectures'] == ['Wav2Vec2Model']:
+            return Wav2Vec2Model(Wav2Vec2Config(**config))
+        if config['architectures'] == ['Wav2Vec2ForCTC']:
+            return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
+        raise ValueError(f'Unexpected arch: {config["architectures"]}')
 
-    @HF_CONFIGS
-    def test_import(self, config, _):
-        """wav2vec2 models from HF transformers can be imported and yields the same results"""
-        original = self._get_model(config).eval()
-        imported = import_huggingface_model(original).eval()
+    def _test_import_pretrain(self, original, imported, config):
         torch.manual_seed(0)
         # FeatureExtractor
         x = torch.randn(3, 1024)
-        ref = original.wav2vec2.feature_extractor(x).transpose(1, 2)
+        ref = original.feature_extractor(x).transpose(1, 2)
         hyp, _ = imported.feature_extractor(x, None)
         self.assertEqual(ref, hyp)
         # Feature projection
         x = torch.randn(3, 10, config['conv_dim'][-1])
-        ref = original.wav2vec2.feature_projection(x)[0]
+        ref = original.feature_projection(x)[0]
         hyp = imported.encoder.feature_projection(x)
         self.assertEqual(ref, hyp)
         # Convolutional Positional Encoder
         x = torch.randn(3, 256, config['hidden_size'])
-        ref = original.wav2vec2.encoder.pos_conv_embed(x)
+        ref = original.encoder.pos_conv_embed(x)
         hyp = imported.encoder.transformer.pos_conv_embed(x)
         self.assertEqual(ref, hyp)
         # Encoder Transformer Layer
-        for original_, imported_ in zip(original.wav2vec2.encoder.layers, imported.encoder.transformer.layers):
+        for original_, imported_ in zip(original.encoder.layers, imported.encoder.transformer.layers):
             b, l, e = 16, 3, config["hidden_size"]
             x = torch.randn(b, l, e)
             mask = torch.randn(b, 1, l, l)
@@ -110,9 +110,11 @@ class TestHFIntegration(TorchaudioTestCase):
         # The whole Encoder Transformer
         b, l, e = 16, 3, config["hidden_size"]
         x = torch.randn(b, l, e)
-        ref = original.wav2vec2.encoder(x).last_hidden_state
+        ref = original.encoder(x).last_hidden_state
         hyp = imported.encoder.transformer(x)
         self.assertEqual(ref, hyp)
+
+    def _test_import_finetune(self, original, imported, config):
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
         ref = original.lm_head(x)
@@ -142,15 +144,22 @@ class TestHFIntegration(TorchaudioTestCase):
         for i, l in enumerate(output_lengths):
             self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...])
 
-    @HF_CONFIGS
-    def test_recreate(self, config, factory_func):
-        """Imported models can be recreated via a factory function without Hugging Face transformers."""
-        imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
-        reloaded.load_state_dict(imported.state_dict())
-        reloaded.eval()
+    @PRETRAIN_CONFIGS
+    def test_import_pretrain(self, config, _):
+        """wav2vec2 models from HF transformers can be imported and yields the same results"""
+        original = self._get_model(config).eval()
+        imported = import_huggingface_model(original).eval()
+        self._test_import_pretrain(original, imported, config)
+
+    @FINETUNE_CONFIGS
+    def test_import_finetune(self, config, _):
+        """wav2vec2 models from HF transformers can be imported and yields the same results"""
+        original = self._get_model(config).eval()
+        imported = import_huggingface_model(original).eval()
+        self._test_import_pretrain(original.wav2vec2, imported, config)
+        self._test_import_finetune(original, imported, config)
+
+    def _test_recreate(self, imported, reloaded, config):
         torch.manual_seed(0)
         # FeatureExtractor
         x = torch.randn(3, 1024)
@@ -194,3 +203,21 @@ class TestHFIntegration(TorchaudioTestCase):
         ref, _ = imported(x)
         hyp, _ = reloaded(x)
         self.assertEqual(ref, hyp)
+
+    @PRETRAIN_CONFIGS
+    def test_recreate_pretrain(self, config, factory_func):
+        """Imported models can be recreated via a factory function without Hugging Face transformers."""
+        imported = import_huggingface_model(self._get_model(config)).eval()
+        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded.load_state_dict(imported.state_dict())
+        reloaded.eval()
+        self._test_recreate(imported, reloaded, config)
+
+    @FINETUNE_CONFIGS
+    def test_recreate_finetune(self, config, factory_func):
+        """Imported models can be recreated via a factory function without Hugging Face transformers."""
+        imported = import_huggingface_model(self._get_model(config)).eval()
+        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded.load_state_dict(imported.state_dict())
+        reloaded.eval()
+        self._test_recreate(imported, reloaded, config)
@@ -32,11 +32,17 @@ def _get_config(cfg):
 def _build(config, original):
+    if original.__class__.__name__ == 'Wav2Vec2ForCTC':
+        wav2vec2 = original.wav2vec2
+    else:
+        wav2vec2 = original
     imported = _get_model(**config)
-    imported.feature_extractor.load_state_dict(original.wav2vec2.feature_extractor.state_dict())
-    imported.encoder.feature_projection.load_state_dict(original.wav2vec2.feature_projection.state_dict())
-    imported.encoder.transformer.load_state_dict(original.wav2vec2.encoder.state_dict())
-    imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
+    imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
+    imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
+    imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
+    if original.__class__.__name__ == 'Wav2Vec2ForCTC':
+        imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
     return imported