Unverified Commit 5d3cb760 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[Whisper] Fix pipeline after timestamp merges (#21198)

* pass return_timestamps to pre-process

* add a test to test it

* test does not need device 0

* remove failing bit

* update test
parent 5326460f
...@@ -400,6 +400,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -400,6 +400,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
" only 1 version" " only 1 version"
) )
forward_params["generate_kwargs"].update(generate_kwargs) forward_params["generate_kwargs"].update(generate_kwargs)
if return_timestamps is not None:
forward_params["generate_kwargs"]["return_timestamps"] = return_timestamps
postprocess_params = {} postprocess_params = {}
if decoder_kwargs is not None: if decoder_kwargs is not None:
...@@ -523,6 +525,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -523,6 +525,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
generate_kwargs = {} generate_kwargs = {}
is_last = model_inputs.pop("is_last") is_last = model_inputs.pop("is_last")
return_timestamps = generate_kwargs.pop("return_timestamps", False)
if self.type == "seq2seq": if self.type == "seq2seq":
encoder = self.model.get_encoder() encoder = self.model.get_encoder()
# Consume values so we can let extra information flow freely through # Consume values so we can let extra information flow freely through
...@@ -552,7 +556,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -552,7 +556,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
stride = model_inputs.pop("stride", None) stride = model_inputs.pop("stride", None)
tokens = self.model.generate( tokens = self.model.generate(
input_features=model_inputs.pop("input_features"), input_features=model_inputs.pop("input_features"),
logits_processor=[WhisperTimeStampLogitsProcessor()], logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None,
**generate_kwargs, **generate_kwargs,
) )
out = {"tokens": tokens} out = {"tokens": tokens}
......
...@@ -291,6 +291,29 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -291,6 +291,29 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = speech_recognizer(filename) output = speech_recognizer(filename)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"}) self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
@require_torch
def test_return_timestamps_in_preprocess(self):
    """Smoke-test that `return_timestamps=True` flows through preprocessing and
    yields timestamped chunks, while the default call returns plain text."""
    # Build a chunked Whisper ASR pipeline (chunking exercises the stride/merge path).
    asr = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-tiny",
        chunk_length_s=8,
        stride_length_s=1,
    )
    # Stream the dataset so only the first sample is downloaded.
    streamed = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
    first_sample = next(iter(streamed))
    # Force English transcription so the expected string is deterministic.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(language="en", task="transcribe")

    # Without timestamps: output is text only.
    output = asr(first_sample["audio"]["array"])
    self.assertEqual(output, {"text": " Conquered returned to its place amidst the tents."})

    # With timestamps: output additionally carries per-chunk (start, end) pairs.
    output = asr(first_sample["audio"]["array"], return_timestamps=True)
    self.assertEqual(
        output,
        {
            "text": " Conquered returned to its place amidst the tents.",
            "chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
        },
    )
@require_torch @require_torch
@slow @slow
def test_torch_whisper(self): def test_torch_whisper(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment