Unverified Commit 4151fbb4 authored by Patrick von Platen, committed by GitHub

[Whisper] Add sequential longform decoding (#27492)

* [Whisper] Add seq gen

* [Whisper] Add seq gen

* more debug

* Fix whisper logit processor

* Improve whisper code further

* Fix more

* more debug

* more debug

* Improve further

* Add tests

* Prep for batch size > 1

* Get batch_size>1 working

* Correct more

* Add extensive tests

* more debug

* more debug

* more debug

* add more tests

* more debug

* Apply suggestions from code review

* more debug

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* Add more examples

* add comments to explain the code better

* fix more

* add comments to explain the code better

* add comments to explain the code better

* correct

* correct

* finalize

* Apply suggestions from code review

* Apply suggestions from code review
parent b2c63c79
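What the commit enables, in short: Whisper can now transcribe audio longer than its 30-second window in a single `generate` call, decoding one 30-second segment at a time and using the last predicted timestamp of each segment to decide where the next one starts. Below is a minimal sketch of that usage; the model choice and the silent dummy waveform are assumptions for illustration, not part of the diff.

```python
import numpy as np
from transformers import AutoProcessor, WhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# assumed input: 60 s of 16 kHz mono audio (silence, just to have valid shapes)
raw_audio = np.zeros(60 * 16_000, dtype=np.float32)

# do not truncate to 30 s; long-form decoding needs the full feature sequence
inputs = processor(
    raw_audio,
    sampling_rate=16_000,
    truncation=False,
    padding="longest",
    return_tensors="pt",
)

# timestamps tell `generate` where each 30 s segment ends and the next one starts
pred_ids = model.generate(inputs.input_features, return_timestamps=True)
text = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
```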
src/transformers/generation/logits_process.py

@@ -1487,6 +1487,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
         max_initial_timestamp_index (`int`, *optional*, defaults to 1):
             Used to set the maximum value of the initial timestamp. This is used to prevent the model from
             predicting timestamps that are too far in the future.
+        _detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.

     Examples:
     ``` python
@@ -1517,29 +1518,35 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
     ```
     """

-    def __init__(self, generate_config):  # support for the kwargs
+    def __init__(
+        self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
+    ):  # support for the kwargs
         self.eos_token_id = generate_config.eos_token_id
         self.no_timestamps_token_id = generate_config.no_timestamps_token_id
         self.timestamp_begin = generate_config.no_timestamps_token_id + 1

-        self.begin_index = len(generate_config.forced_decoder_ids) + 2
-        if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
-            self.begin_index -= 1
-        self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
+        # this variable is mostly just used for testing
+        self._detect_timestamp_from_logprob = (
+            _detect_timestamp_from_logprob
+            if _detect_timestamp_from_logprob is not None
+            else getattr(generate_config, "_detect_timestamp_from_logprob", True)
+        )
+
+        self.begin_index = (
+            len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
+        )
+        self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # suppress <|notimestamps|> which is handled by without_timestamps
         scores[:, self.no_timestamps_token_id] = -float("inf")

-        if input_ids.shape[1] == self.begin_index - 1:
-            scores[:, :] = -float("inf")
-            scores[:, self.timestamp_begin] = 0
-            return scores
-
         # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
         for k in range(input_ids.shape[0]):
-            seq = list(input_ids[k, self.begin_index :].tolist())
+            sampled_tokens = input_ids[k, self.begin_index :]
+            seq = list(sampled_tokens.tolist())

             last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
             penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
@@ -1549,8 +1556,23 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
             else:  # cannot be normal text tokens
                 scores[k, : self.eos_token_id] = -float("inf")

+            timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
+            if timestamps.numel() > 0:
+                # `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
+                # The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
+                if last_was_timestamp and not penultimate_was_timestamp:
+                    timestamp_last = timestamps[-1]
+                else:
+                    # Avoid to emit <|0.00|> again
+                    timestamp_last = timestamps[-1] + 1
+
+                scores[k, self.timestamp_begin : timestamp_last] = -float("inf")
+
         # apply the `max_initial_timestamp` option
-        if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None:
-            last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
-            scores[:, last_allowed + 1 :] = -float("inf")
+        if input_ids.shape[1] == self.begin_index:
+            scores[:, : self.timestamp_begin] = -float("inf")
+
+            if self.max_initial_timestamp_index is not None:
+                last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
+                scores[:, last_allowed + 1 :] = -float("inf")
@@ -1559,7 +1581,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
         for k in range(input_ids.shape[0]):
             timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
             max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
-            if timestamp_logprob > max_text_token_logprob:
+            if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
                 scores[k, : self.timestamp_begin] = -float("inf")

         return scores
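Taken together, the logits-processor changes do three things: `begin_index` is now derived defensively from `forced_decoder_ids` (which may be absent), predicted timestamps are forced to be non-decreasing within a window, and the logprob-based timestamp detection can be switched off for testing via the private `_detect_timestamp_from_logprob` flag. A small sketch of the monotonicity rule in action; the dummy prompt, the text-token id `20_000`, and the uniform scores are assumptions (token ids match `openai/whisper-tiny.en`):

```python
import torch
from transformers import GenerationConfig, WhisperTimeStampLogitsProcessor

gen_config = GenerationConfig.from_pretrained("openai/whisper-tiny.en")
ts_processor = WhisperTimeStampLogitsProcessor(gen_config)

vocab_size = 51864                                # whisper-tiny.en vocabulary size
ts_begin = gen_config.no_timestamps_token_id + 1  # id of <|0.00|>

# assumed decoder history: forced prompt, one text token, then an unpaired
# timestamp <|1.00|> (ts_begin + 50) that still needs its closing partner
prompt = [gen_config.decoder_start_token_id] * ts_processor.begin_index
input_ids = torch.tensor([prompt + [20_000, ts_begin + 50]])
scores = torch.zeros(1, vocab_size)

masked = ts_processor(input_ids, scores)

# timestamps may not decrease: everything before <|1.00|> is now forbidden
assert torch.isinf(masked[0, ts_begin : ts_begin + 50]).all()
# and the unpaired timestamp forces the next token to be a timestamp, not text
assert torch.isinf(masked[0, : gen_config.eos_token_id]).all()
```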
src/transformers/pipelines/automatic_speech_recognition.py

@@ -508,9 +508,19 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             ):
                 yield item
         else:
-            processed = self.feature_extractor(
-                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-            )
+            if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
+                processed = self.feature_extractor(
+                    inputs,
+                    sampling_rate=self.feature_extractor.sampling_rate,
+                    truncation=False,
+                    padding="longest",
+                    return_tensors="pt",
+                )
+            else:
+                processed = self.feature_extractor(
+                    inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
+                )
+
             if self.torch_dtype is not None:
                 processed = processed.to(dtype=self.torch_dtype)
             if stride is not None:
@@ -551,8 +561,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             if stride is not None:
                 generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length

+            if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
+                generate_kwargs["input_features"] = inputs
+            else:
+                generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
+
             tokens = self.model.generate(
-                encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                 attention_mask=attention_mask,
                 **generate_kwargs,
             )
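In other words, the pipeline now has a second path for Whisper inputs longer than one 30-second window: features are extracted without truncation and later routed to `generate` as raw `input_features` instead of pre-computed `encoder_outputs`. A sketch of the feature-extraction branch in isolation, assuming a silent 45-second, 16 kHz waveform:

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny.en")

# assumed input: 45 s of 16 kHz mono audio, i.e. longer than one 30 s window
waveform = np.zeros(45 * 16_000, dtype=np.float32)

if waveform.shape[0] > feature_extractor.n_samples:  # n_samples = 30 s * 16 kHz = 480_000
    # long input: keep everything, pad to the longest sample instead of truncating
    processed = feature_extractor(
        waveform,
        sampling_rate=feature_extractor.sampling_rate,
        truncation=False,
        padding="longest",
        return_tensors="pt",
    )
else:
    # short input: the usual fixed 30 s mel spectrogram
    processed = feature_extractor(
        waveform, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt"
    )

# more than nb_max_frames (3000) mel frames, so the pipeline routes this tensor
# to `generate` as `input_features` rather than encoding it up front
print(processed["input_features"].shape)  # torch.Size([1, 80, 4500])
```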
tests/pipelines/test_pipelines_automatic_speech_recognition.py

@@ -16,7 +16,7 @@
 import unittest

 import numpy as np
 import pytest

-from datasets import load_dataset
+from datasets import Audio, load_dataset
 from huggingface_hub import hf_hub_download, snapshot_download

 from transformers import (
@@ -329,16 +329,16 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         self.assertEqual(
             res,
             {
-                "text": " Conquered returned to its place amidst the tents.",
-                "chunks": [
-                    {"text": " Conquered", "timestamp": (0.5, 1.2)},
-                    {"text": " returned", "timestamp": (1.2, 1.64)},
-                    {"text": " to", "timestamp": (1.64, 1.84)},
-                    {"text": " its", "timestamp": (1.84, 2.02)},
-                    {"text": " place", "timestamp": (2.02, 2.28)},
-                    {"text": " amidst", "timestamp": (2.28, 2.78)},
-                    {"text": " the", "timestamp": (2.78, 2.96)},
-                    {"text": " tents.", "timestamp": (2.96, 3.48)},
+                'text': ' Conquered returned to its place amidst the tents.',
+                'chunks': [
+                    {'text': ' Conquered', 'timestamp': (0.5, 1.2)},
+                    {'text': ' returned', 'timestamp': (1.2, 1.64)},
+                    {'text': ' to', 'timestamp': (1.64, 1.84)},
+                    {'text': ' its', 'timestamp': (1.84, 2.02)},
+                    {'text': ' place', 'timestamp': (2.02, 2.28)},
+                    {'text': ' amidst', 'timestamp': (2.28, 2.8)},
+                    {'text': ' the', 'timestamp': (2.8, 2.98)},
+                    {'text': ' tents.', 'timestamp': (2.98, 3.48)},
                 ],
             },
         )
@@ -776,27 +776,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         self.assertEqual(
             output,
             {
-                "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
-                "chunks": [
-                    {'text': ' Mr.', 'timestamp': (0.0, 1.02)},
-                    {'text': ' Quilter', 'timestamp': (1.02, 1.18)},
-                    {'text': ' is', 'timestamp': (1.18, 1.44)},
-                    {'text': ' the', 'timestamp': (1.44, 1.58)},
-                    {'text': ' apostle', 'timestamp': (1.58, 1.98)},
-                    {'text': ' of', 'timestamp': (1.98, 2.3)},
-                    {'text': ' the', 'timestamp': (2.3, 2.46)},
-                    {'text': ' middle', 'timestamp': (2.46, 2.56)},
-                    {'text': ' classes,', 'timestamp': (2.56, 3.38)},
-                    {'text': ' and', 'timestamp': (3.38, 3.52)},
-                    {'text': ' we', 'timestamp': (3.52, 3.6)},
-                    {'text': ' are', 'timestamp': (3.6, 3.72)},
-                    {'text': ' glad', 'timestamp': (3.72, 4.0)},
-                    {'text': ' to', 'timestamp': (4.0, 4.26)},
-                    {'text': ' welcome', 'timestamp': (4.26, 4.54)},
-                    {'text': ' his', 'timestamp': (4.54, 4.92)},
-                    {'text': ' gospel.', 'timestamp': (4.92, 6.66)},
-                ],
-            },
+                'text': ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
+                'chunks': [
+                    {'text': ' Mr.', 'timestamp': (0.38, 1.04)},
+                    {'text': ' Quilter', 'timestamp': (1.04, 1.18)},
+                    {'text': ' is', 'timestamp': (1.18, 1.44)},
+                    {'text': ' the', 'timestamp': (1.44, 1.58)},
+                    {'text': ' apostle', 'timestamp': (1.58, 1.98)},
+                    {'text': ' of', 'timestamp': (1.98, 2.32)},
+                    {'text': ' the', 'timestamp': (2.32, 2.46)},
+                    {'text': ' middle', 'timestamp': (2.46, 2.56)},
+                    {'text': ' classes,', 'timestamp': (2.56, 3.4)},
+                    {'text': ' and', 'timestamp': (3.4, 3.54)},
+                    {'text': ' we', 'timestamp': (3.54, 3.62)},
+                    {'text': ' are', 'timestamp': (3.62, 3.72)},
+                    {'text': ' glad', 'timestamp': (3.72, 4.0)},
+                    {'text': ' to', 'timestamp': (4.0, 4.26)},
+                    {'text': ' welcome', 'timestamp': (4.26, 4.56)},
+                    {'text': ' his', 'timestamp': (4.56, 4.92)},
+                    {'text': ' gospel.', 'timestamp': (4.92, 5.84)}
+                ]
+            }
         )
         # fmt: on
@@ -1087,6 +1087,34 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         self.assertEqual(output, [{"text": ANY(str)}])
         self.assertEqual(output[0]["text"][:6], "<s> <s")

+    @require_torch
+    @slow
+    def test_whisper_longform(self):
+        # fmt: off
+        EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out of fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct denny's, set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!"""
+        # fmt: on
+
+        processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        model = model.to("cuda")
+
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            max_new_tokens=128,
+            device="cuda:0",
+        )
+
+        ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
+        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+        audio = ds[:1]["audio"]
+
+        result = pipe(audio)[0]["text"]
+        assert result == EXPECTED_RESULT
+
     @require_torch
     @slow
     def test_chunking_and_timestamps(self):
@@ -1355,7 +1383,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             out,
             {
                 "chunks": [
-                    {"text": "", "timestamp": (18.94, 0.0)},
+                    {"text": "", "timestamp": (18.94, 0.02)},
                     {"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
                 ],
                 "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
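For context, the word-level `chunks` asserted in the tests above come from calling the pipeline with `return_timestamps="word"`. A minimal sketch; the model checkpoint and the local audio file are assumptions:

```python
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en")

# assumed local file; any audio decodable to 16 kHz works
out = pipe("sample.flac", return_timestamps="word")
# e.g. {'text': ' Mr. Quilter is ...',
#       'chunks': [{'text': ' Mr.', 'timestamp': (0.38, 1.04)}, ...]}
```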