Unverified Commit 5dfd407b authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[MMS] Scaling Speech Technology to 1,000+ Languages | Add attention adapter to Wav2Vec2 (#23813)



* add fine-tuned with adapter layer

* Add set_target_lang to tokenizer

* Implement load adapter

* add tests

* make style

* Apply suggestions from code review

* Update src/transformers/models/wav2vec2/tokenization_wav2vec2.py

* make fix-copies

* Apply suggestions from code review

* make fix-copies

* make style again

* mkae style again

* fix doc string

* Update tests/models/wav2vec2/test_tokenization_wav2vec2.py

* Apply suggestions from code review

* fix

* Correct wav2vec2 adapter

* mkae style

* Update src/transformers/models/wav2vec2/modeling_wav2vec2.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* add more nice docs

* finish

* finish

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review

* all finish

---------
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent f49a3453
...@@ -1600,7 +1600,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): ...@@ -1600,7 +1600,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
) )
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel): class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config): def __init__(self, config, target_lang=None):
super().__init__(config) super().__init__(config)
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
...@@ -1618,6 +1618,13 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel): ...@@ -1618,6 +1618,13 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
) )
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
......
...@@ -1268,7 +1268,7 @@ class WavLMModel(WavLMPreTrainedModel): ...@@ -1268,7 +1268,7 @@ class WavLMModel(WavLMPreTrainedModel):
) )
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForCTC(WavLMPreTrainedModel): class WavLMForCTC(WavLMPreTrainedModel):
def __init__(self, config): def __init__(self, config, target_lang=None):
super().__init__(config) super().__init__(config)
self.wavlm = WavLMModel(config) self.wavlm = WavLMModel(config)
...@@ -1286,6 +1286,13 @@ class WavLMForCTC(WavLMPreTrainedModel): ...@@ -1286,6 +1286,13 @@ class WavLMForCTC(WavLMPreTrainedModel):
) )
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
......
...@@ -54,6 +54,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin ...@@ -54,6 +54,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
import torch import torch
from safetensors.torch import save_file as safe_save_file
from transformers import ( from transformers import (
Wav2Vec2FeatureExtractor, Wav2Vec2FeatureExtractor,
...@@ -67,6 +68,8 @@ if is_torch_available(): ...@@ -67,6 +68,8 @@ if is_torch_available():
Wav2Vec2Processor, Wav2Vec2Processor,
) )
from transformers.models.wav2vec2.modeling_wav2vec2 import ( from transformers.models.wav2vec2.modeling_wav2vec2 import (
WAV2VEC2_ADAPTER_PT_FILE,
WAV2VEC2_ADAPTER_SAFE_FILE,
Wav2Vec2GumbelVectorQuantizer, Wav2Vec2GumbelVectorQuantizer,
_compute_mask_indices, _compute_mask_indices,
_sample_negative_indices, _sample_negative_indices,
...@@ -290,6 +293,17 @@ class Wav2Vec2ModelTester: ...@@ -290,6 +293,17 @@ class Wav2Vec2ModelTester:
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size), (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
) )
def create_and_check_model_with_attn_adapter(self, config, input_values, attention_mask):
config.adapter_attn_dim = 16
model = Wav2Vec2ForCTC(config=config)
self.parent.assertIsNotNone(model._adapters)
model.to(torch_device)
model.eval()
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size))
def create_and_check_batch_inference(self, config, input_values, *args): def create_and_check_batch_inference(self, config, input_values, *args):
# test does not pass for models making use of `group_norm` # test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227 # check: https://github.com/pytorch/fairseq/issues/3227
...@@ -844,6 +858,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -844,6 +858,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs) self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
def test_model_with_attn_adapter(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_attn_adapter(*config_and_inputs)
def test_batched_inference(self): def test_batched_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_batch_inference(*config_and_inputs) self.model_tester.create_and_check_batch_inference(*config_and_inputs)
...@@ -1098,6 +1116,85 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -1098,6 +1116,85 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
def test_feed_forward_chunking(self): def test_feed_forward_chunking(self):
pass pass
def test_load_attn_adapter(self):
processor = Wav2Vec2Processor.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
)
def get_logits(model, input_features):
model = model.to(torch_device)
batch = processor(
input_features,
padding=True,
sampling_rate=processor.feature_extractor.sampling_rate,
return_tensors="pt",
)
with torch.no_grad():
logits = model(
input_values=batch["input_values"].to(torch_device),
attention_mask=batch["attention_mask"].to(torch_device),
).logits
return logits
input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]]
model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2", adapter_attn_dim=16)
with tempfile.TemporaryDirectory() as tempdir:
model.save_pretrained(tempdir)
model = Wav2Vec2ForCTC.from_pretrained(tempdir)
logits = get_logits(model, input_features)
adapter_weights = model._adapters
# save safe weights
safe_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_SAFE_FILE.format("eng"))
safe_save_file(adapter_weights, safe_filepath, metadata={"format": "pt"})
model.load_adapter("eng")
model.load_adapter("eng", use_safetensors=True)
with self.assertRaises(OSError):
model.load_adapter("eng", use_safetensors=False)
with self.assertRaises(Exception):
model.load_adapter("ita", use_safetensors=True)
logits_2 = get_logits(model, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
with tempfile.TemporaryDirectory() as tempdir:
model.save_pretrained(tempdir)
model = Wav2Vec2ForCTC.from_pretrained(tempdir)
logits = get_logits(model, input_features)
adapter_weights = model._adapters
# save pt weights
pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng"))
torch.save(adapter_weights, pt_filepath)
model.load_adapter("eng")
model.load_adapter("eng", use_safetensors=False)
with self.assertRaises(OSError):
model.load_adapter("eng", use_safetensors=True)
logits_2 = get_logits(model, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
logits = get_logits(model, input_features)
model.load_adapter("eng")
model.load_adapter("eng", use_safetensors=False)
model.load_adapter("eng", use_safetensors=True)
logits_2 = get_logits(model, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
...@@ -1768,3 +1865,45 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -1768,3 +1865,45 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# TODO: update the tolerance after the CI moves to torch 1.10 # TODO: update the tolerance after the CI moves to torch 1.10
self.assertAlmostEqual(outputs.loss.item(), 17.7963, 2) self.assertAlmostEqual(outputs.loss.item(), 17.7963, 2)
@require_torchaudio
def test_inference_mms_1b_all(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all").to(torch_device)
processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")
LANG_MAP = {"it": "ita", "es": "spa", "fr": "fra", "en": "eng"}
def run_model(lang):
ds = load_dataset("common_voice", lang, split="test", streaming=True)
sample = next(iter(ds))
wav2vec2_lang = LANG_MAP[lang]
model.load_adapter(wav2vec2_lang)
processor.tokenizer.set_target_lang(wav2vec2_lang)
resampled_audio = torchaudio.functional.resample(
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
).numpy()
inputs = processor(resampled_audio, sampling_rate=16_000, return_tensors="pt")
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask).logits
ids = torch.argmax(outputs, dim=-1)[0]
transcription = processor.decode(ids)
return transcription
TRANSCRIPTIONS = {
"it": "mi hanno fatto un'offerta che non potevo proprio rifiutare",
"es": "bien y qué regalo vas a abrir primero",
"fr": "un vrai travail intéressant va enfin être mené sur ce sujet",
"en": "twas the time of day and olof spen slept during the summer",
}
for lang in LANG_MAP.keys():
assert run_model(lang) == TRANSCRIPTIONS[lang]
...@@ -772,3 +772,48 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -772,3 +772,48 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
output = tokenizer.convert_tokens_to_string(tokens) output = tokenizer.convert_tokens_to_string(tokens)
self.assertIsInstance(output["text"], str) self.assertIsInstance(output["text"], str)
def test_nested_vocab(self):
eng_vocab = {"a": 7, "b": 8}
spa_vocab = {"a": 23, "c": 88}
ita_vocab = {"a": 6, "d": 9}
nested_vocab = {"eng": eng_vocab, "spa": spa_vocab, "ita": ita_vocab}
def check_tokenizer(tokenizer, check_ita_first=False):
if check_ita_first:
self.assertEqual(tokenizer.decode([6, 9, 9]), "ad")
self.assertEqual(tokenizer.encoder, ita_vocab)
tokenizer.set_target_lang("eng")
self.assertEqual(tokenizer.encoder, eng_vocab)
self.assertEqual(tokenizer.decode([7, 8, 7]), "aba")
tokenizer.set_target_lang("spa")
self.assertEqual(tokenizer.decode([23, 88, 23]), "aca")
self.assertEqual(tokenizer.encoder, spa_vocab)
tokenizer.set_target_lang("eng")
self.assertEqual(tokenizer.encoder, eng_vocab)
self.assertEqual(tokenizer.decode([7, 7, 8]), "ab")
tokenizer.set_target_lang("ita")
self.assertEqual(tokenizer.decode([6, 9, 9]), "ad")
self.assertEqual(tokenizer.encoder, ita_vocab)
with tempfile.TemporaryDirectory() as tempdir:
tempfile_path = os.path.join(tempdir, "vocab.json")
with open(tempfile_path, "w") as temp_file:
json.dump(nested_vocab, temp_file)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tempdir, target_lang="eng")
check_tokenizer(tokenizer)
with tempfile.TemporaryDirectory() as tempdir:
# should have saved target lang as "ita" since it was last one
tokenizer.save_pretrained(tempdir)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tempdir)
self.assertEqual(tokenizer.target_lang, "ita")
check_tokenizer(tokenizer, check_ita_first=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment