Unverified Commit e9cab8f8 authored by moto, committed by GitHub

Fix HF model integration (#1781)

* Fix HF model integration

Previously, when testing wav2vec2 models from HF transformers, all the models were
instantiated as the `Wav2Vec2ForCTC` class, while some of them were supposed to be
`Wav2Vec2Model`.

Fixing this revealed that the model importer could not correctly handle `Wav2Vec2Model` imports.

This PR fixes these issues.
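
The fix on the test side dispatches on the `architectures` field of the Hugging Face config instead of always building `Wav2Vec2ForCTC`. A minimal sketch of that logic, mirroring the `_get_model` change in the diff below (`build_hf_model` is a hypothetical standalone name, and `config` is assumed to be a plain dict loaded from a checkpoint's `config.json`):

```python
from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Model, Wav2Vec2ForCTC

def build_hf_model(config: dict):
    # `architectures` names the class the checkpoint was saved from:
    # ['Wav2Vec2Model'] for pretrained, ['Wav2Vec2ForCTC'] for fine-tuned.
    if config['architectures'] == ['Wav2Vec2Model']:
        return Wav2Vec2Model(Wav2Vec2Config(**config))
    if config['architectures'] == ['Wav2Vec2ForCTC']:
        return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
    raise ValueError(f'Unexpected arch: {config["architectures"]}')
```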
parent 1b4b82e0
@@ -39,14 +39,14 @@ HF_LARGE_LV60_SELF_960H = _load_config('facebook', 'wav2vec2-large-960h-lv60-sel
 HF_LARGE_XLSR_DE = _load_config('facebook', 'wav2vec2-large-xlsr-53-german')
 # Config and corresponding factory functions
-HF_CONFIGS = parameterized.expand([
-    # pretrained
+PRETRAIN_CONFIGS = parameterized.expand([
     (HF_BASE, wav2vec2_base),
     (HF_LARGE, wav2vec2_large),
     (HF_LARGE_LV60, wav2vec2_large_lv60k),
     (HF_LARGE_XLSR_53, wav2vec2_large_lv60k),
     (HF_BASE_10K_VOXPOPULI, wav2vec2_base),
-    # finetuned
+], name_func=_name_func)
+FINETUNE_CONFIGS = parameterized.expand([
     (HF_BASE_960H, wav2vec2_base),
     (HF_LARGE_960H, wav2vec2_large),
     (HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
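
Note that `parameterized.expand` returns a decorator, which is why the two config lists can be stored in module-level constants and applied to several test methods. A self-contained sketch of that pattern (the lambda is a stand-in for the test module's actual `_name_func` helper):

```python
import unittest
from parameterized import parameterized

# Stand-in for the test module's _name_func; builds a unique test name per case.
_name_func = lambda func, idx, param: f'{func.__name__}_{idx}'

SMALL_CONFIGS = parameterized.expand([(2,), (3,)], name_func=_name_func)

class DemoTest(unittest.TestCase):
    @SMALL_CONFIGS  # the stored decorator is applied like @parameterized.expand(...)
    def test_square_grows(self, n):
        self.assertGreaterEqual(n * n, n)
```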
@@ -72,34 +72,34 @@ class TestHFIntegration(TorchaudioTestCase):
         # the actual tests are started.
         from transformers.models.wav2vec2 import (
             Wav2Vec2Config,
+            Wav2Vec2Model,
             Wav2Vec2ForCTC,
         )
-        return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
+        if config['architectures'] == ['Wav2Vec2Model']:
+            return Wav2Vec2Model(Wav2Vec2Config(**config))
+        if config['architectures'] == ['Wav2Vec2ForCTC']:
+            return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
+        raise ValueError(f'Unexpected arch: {config["architectures"]}')
 
-    @HF_CONFIGS
-    def test_import(self, config, _):
-        """wav2vec2 models from HF transformers can be imported and yields the same results"""
-        original = self._get_model(config).eval()
-        imported = import_huggingface_model(original).eval()
-
+    def _test_import_pretrain(self, original, imported, config):
         torch.manual_seed(0)
         # FeatureExtractor
         x = torch.randn(3, 1024)
-        ref = original.wav2vec2.feature_extractor(x).transpose(1, 2)
+        ref = original.feature_extractor(x).transpose(1, 2)
         hyp, _ = imported.feature_extractor(x, None)
         self.assertEqual(ref, hyp)
         # Feature projection
         x = torch.randn(3, 10, config['conv_dim'][-1])
-        ref = original.wav2vec2.feature_projection(x)[0]
+        ref = original.feature_projection(x)[0]
         hyp = imported.encoder.feature_projection(x)
         self.assertEqual(ref, hyp)
         # Convolutional Positional Encoder
         x = torch.randn(3, 256, config['hidden_size'])
-        ref = original.wav2vec2.encoder.pos_conv_embed(x)
+        ref = original.encoder.pos_conv_embed(x)
         hyp = imported.encoder.transformer.pos_conv_embed(x)
         self.assertEqual(ref, hyp)
         # Encoder Transformer Layer
-        for original_, imported_ in zip(original.wav2vec2.encoder.layers, imported.encoder.transformer.layers):
+        for original_, imported_ in zip(original.encoder.layers, imported.encoder.transformer.layers):
             b, l, e = 16, 3, config["hidden_size"]
             x = torch.randn(b, l, e)
             mask = torch.randn(b, 1, l, l)
@@ -110,9 +110,11 @@ class TestHFIntegration(TorchaudioTestCase):
         # The whole Encoder Transformer
         b, l, e = 16, 3, config["hidden_size"]
         x = torch.randn(b, l, e)
-        ref = original.wav2vec2.encoder(x).last_hidden_state
+        ref = original.encoder(x).last_hidden_state
         hyp = imported.encoder.transformer(x)
         self.assertEqual(ref, hyp)
+
+    def _test_import_finetune(self, original, imported, config):
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
         ref = original.lm_head(x)
@@ -142,15 +144,22 @@ class TestHFIntegration(TorchaudioTestCase):
         for i, l in enumerate(output_lengths):
             self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...])
 
-    @HF_CONFIGS
-    def test_recreate(self, config, factory_func):
-        """Imported models can be recreated via a factory function without Hugging Face transformers."""
-        imported = import_huggingface_model(self._get_model(config)).eval()
-
-        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
-        reloaded.load_state_dict(imported.state_dict())
-        reloaded.eval()
-
+    @PRETRAIN_CONFIGS
+    def test_import_pretrain(self, config, _):
+        """wav2vec2 models from HF transformers can be imported and yield the same results"""
+        original = self._get_model(config).eval()
+        imported = import_huggingface_model(original).eval()
+        self._test_import_pretrain(original, imported, config)
+
+    @FINETUNE_CONFIGS
+    def test_import_finetune(self, config, _):
+        """wav2vec2 models from HF transformers can be imported and yield the same results"""
+        original = self._get_model(config).eval()
+        imported = import_huggingface_model(original).eval()
+        self._test_import_pretrain(original.wav2vec2, imported, config)
+        self._test_import_finetune(original, imported, config)
+
+    def _test_recreate(self, imported, reloaded, config):
         torch.manual_seed(0)
         # FeatureExtractor
         x = torch.randn(3, 1024)
@@ -194,3 +203,21 @@ class TestHFIntegration(TorchaudioTestCase):
         ref, _ = imported(x)
         hyp, _ = reloaded(x)
         self.assertEqual(ref, hyp)
+
+    @PRETRAIN_CONFIGS
+    def test_recreate_pretrain(self, config, factory_func):
+        """Imported models can be recreated via a factory function without Hugging Face transformers."""
+        imported = import_huggingface_model(self._get_model(config)).eval()
+        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded.load_state_dict(imported.state_dict())
+        reloaded.eval()
+        self._test_recreate(imported, reloaded, config)
+
+    @FINETUNE_CONFIGS
+    def test_recreate_finetune(self, config, factory_func):
+        """Imported models can be recreated via a factory function without Hugging Face transformers."""
+        imported = import_huggingface_model(self._get_model(config)).eval()
+        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded.load_state_dict(imported.state_dict())
+        reloaded.eval()
+        self._test_recreate(imported, reloaded, config)
@@ -32,11 +32,17 @@ def _get_config(cfg):
 
 def _build(config, original):
+    if original.__class__.__name__ == 'Wav2Vec2ForCTC':
+        wav2vec2 = original.wav2vec2
+    else:
+        wav2vec2 = original
     imported = _get_model(**config)
-    imported.feature_extractor.load_state_dict(original.wav2vec2.feature_extractor.state_dict())
-    imported.encoder.feature_projection.load_state_dict(original.wav2vec2.feature_projection.state_dict())
-    imported.encoder.transformer.load_state_dict(original.wav2vec2.encoder.state_dict())
-    imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
+    imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
+    imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
+    imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
+    if original.__class__.__name__ == 'Wav2Vec2ForCTC':
+        imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
     return imported
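
Taken together, the fixed importer handles both checkpoint flavors. A hedged end-to-end sketch (the checkpoint IDs are Hugging Face Hub names the test configs are loaded from; the dummy input shape is an assumption):

```python
import torch
from transformers import Wav2Vec2Model, Wav2Vec2ForCTC
from torchaudio.models.wav2vec2.utils import import_huggingface_model

# Fine-tuned checkpoint: Wav2Vec2ForCTC, so the readout (lm_head) is copied too.
ctc = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h').eval()
imported_ctc = import_huggingface_model(ctc).eval()

# Pretrained, encoder-only checkpoint: the new else-branch in _build() is
# taken and no readout weights are copied.
pretrained = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base').eval()
imported_enc = import_huggingface_model(pretrained).eval()

waveform = torch.randn(1, 16000)  # dummy 1-second input at 16 kHz
with torch.no_grad():
    emissions, _ = imported_ctc(waveform)
```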