Unverified commit 68cc4ccd, authored by Nicolas Patry, committed by GitHub
Browse files

Pipeline ASR with LM. (#15071)



* Pipeline ASR with LM.

* Revamped into `self.decoder`.

* Fixing.

* 2nd fix.

* Update src/transformers/pipelines/__init__.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fixing.
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 1a00863e
......@@ -611,6 +611,27 @@ def pipeline(
feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
)
# If the checkpoint's processor class ends with "WithLM", the repo ships a
# serialized pyctcdecode beam-search decoder (kenlm language model + alphabet).
# Try to download it and hand it to the pipeline; on any failure — missing
# `pyctcdecode`/`kenlm` or a download error — fall back to raw CTC decoding.
if (
    feature_extractor._processor_class
    and feature_extractor._processor_class.endswith("WithLM")
    and isinstance(model_name, str)
):
    try:
        from pyctcdecode import BeamSearchDecoderCTC

        # Restrict the Hub download to the files the decoder actually needs:
        # everything under the serialized-LM directory plus the alphabet file.
        language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
        alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
        allow_regex = [language_model_glob, alphabet_filename]
        decoder = BeamSearchDecoderCTC.load_from_hf_hub(
            pretrained_model_name_or_path, allow_regex=allow_regex
        )
        kwargs["decoder"] = decoder
    except Exception as e:
        # Broad catch is deliberate: LM decoding is best-effort and must not
        # prevent the pipeline from running. Fix: the message was a plain
        # string, so `{model_name}` and `{e}` were printed literally — it
        # needs the f-string prefix to interpolate.
        logger.warning(
            f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install "
            "`pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, "
            "`pip install https://github.com/kpu/kenlm/archive/master.zip`): "
            f"Error: {e}"
        )
if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params:
if key.startswith("translation"):
......
......@@ -144,7 +144,18 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
self.is_ctc = self.model.__class__ in MODEL_FOR_CTC_MAPPING.values()
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
self.type = "seq2seq"
elif (
self.feature_extractor._processor_class
and self.feature_extractor._processor_class.endswith("WithLM")
and kwargs.get("decoder", None) is not None
):
self.decoder = kwargs["decoder"]
self.type = "ctc_with_lm"
else:
self.type = "ctc"
def __call__(
self,
......@@ -222,8 +233,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
def _forward(self, model_inputs):
is_last = model_inputs.pop("is_last")
model_class = self.model.__class__
if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
if self.type == "seq2seq":
encoder = self.model.get_encoder()
# we need to pass `processed.get("attention_mask")` here since audio encoder
# attention mask length is different from expected text decoder `encoder_attention_mask` length
......@@ -232,7 +242,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
tokens = self.model.generate(
encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask")
)
elif model_class in MODEL_FOR_CTC_MAPPING.values():
out = {"tokens": tokens}
elif self.type == "ctc_with_lm":
outputs = self.model(**model_inputs)
out = {"logits": outputs.logits}
elif self.type == "ctc":
stride = model_inputs.pop("stride", None)
outputs = self.model(**model_inputs)
tokens = outputs.logits.argmax(dim=-1)
......@@ -241,16 +256,22 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
stride = [stride]
apply_stride(tokens, stride)
out = {"tokens": tokens}
else:
logger.warning("This is an unknown class, treating it as CTC.")
outputs = self.model(**model_inputs)
tokens = outputs.logits.argmax(dim=-1)
return {"tokens": tokens, "is_last": is_last}
out = {"tokens": tokens}
return {"is_last": is_last, **out}
def postprocess(self, model_outputs):
    """Merge chunked model outputs and decode them into the final transcription.

    Args:
        model_outputs: iterable of per-chunk dicts produced by `_forward`, each
            holding either a `"logits"` tensor (`ctc_with_lm` pipelines) or a
            `"tokens"` tensor (every other pipeline type). Tensors expose
            `.numpy()` and carry a leading batch dimension of 1.

    Returns:
        `dict`: `{"text": <decoded transcription string>}`.
    """
    # NOTE(review): the diffed span interleaved the pre-change lines
    # (`skip_special_tokens = False if "CTC" in ...`, `recognized_string = ...`)
    # with their replacements; only the post-change logic is kept here.
    if self.type == "ctc_with_lm":
        # Concatenate per-chunk logits along the time axis, drop the batch
        # dimension, and let the pyctcdecode beam-search decoder (backed by
        # the language model) pick the transcription. `decode_beams` returns
        # beams best-first; each beam starts with its decoded text.
        logits = np.concatenate([outputs["logits"].numpy() for outputs in model_outputs], axis=1)
        logits = logits.squeeze(0)
        text = self.decoder.decode_beams(logits)[0][0]
    else:
        # Special tokens are kept only for pure CTC output — presumably the
        # CTC tokenizer needs them to reconstruct the text; confirm against
        # the tokenizer implementation. All other types skip them.
        skip_special_tokens = self.type != "ctc"
        tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
        tokens = tokens.squeeze(0)
        text = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
    return {"text": text}
......@@ -32,6 +32,7 @@ from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_pyctcdecode,
require_tf,
require_torch,
require_torchaudio,
......@@ -97,6 +98,37 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"})
@slow
@require_torch
@require_pyctcdecode
def test_large_model_pt_with_lm(self):
    """End-to-end check of the CTC-with-language-model decoding path.

    Loads a Spanish wav2vec2 checkpoint that ships a kenlm language model,
    verifies the pipeline selected the `ctc_with_lm` path, then forces the
    plain `ctc` path on the same audio and checks the (slightly worse)
    greedy transcription.
    """
    ds = load_dataset("Narsil/asr_dummy")
    audio_file = ds["test"][3]["file"]

    asr = pipeline(
        task="automatic-speech-recognition",
        model="patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm",
        framework="pt",
    )
    self.assertEqual(asr.type, "ctc_with_lm")

    transcription = asr(audio_file)
    self.assertEqual(
        transcription,
        {"text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumaje"},
    )

    # Force the pure-CTC (greedy argmax) path and decode the same file again.
    asr.type = "ctc"
    transcription = asr(audio_file)
    # Greedy CTC makes a spelling mistake the LM corrects: plumajre != plumaje
    self.assertEqual(
        transcription,
        {
            "text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumajre"
        },
    )
@require_tf
def test_small_model_tf(self):
    # No TensorFlow implementation of this pipeline exists yet; skip rather
    # than fail under the TF test matrix.
    self.skipTest("Tensorflow not supported yet.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment