Unverified commit 4151fbb4, authored by Patrick von Platen and committed by GitHub
Browse files

[Whisper] Add sequential longform decoding (#27492)

* [Whisper] Add seq gen

* [Whisper] Add seq gen

* more debug

* Fix whisper logit processor

* Improve whisper code further

* Fix more

* more debug

* more debug

* Improve further

* Add tests

* Prep for batch size > 1

* Get batch_size>1 working

* Correct more

* Add extensive tests

* more debug

* more debug

* more debug

* add more tests

* more debug

* Apply suggestions from code review

* more debug

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* Add more examples

* add comments to explain the code better

* fix more

* add comments to explain the code better

* add comments to explain the code better

* correct

* correct

* finalize

* Apply suggestions from code review

* Apply suggestions from code review
parent b2c63c79
......@@ -1487,6 +1487,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
Examples:
``` python
......@@ -1517,29 +1518,35 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, generate_config): # support for the kwargs
def __init__(
self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
): # support for the kwargs
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
self.begin_index = len(generate_config.forced_decoder_ids) + 2
if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
self.begin_index -= 1
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
# this variable is mostly just used for testing
self._detect_timestamp_from_logprob = (
_detect_timestamp_from_logprob
if _detect_timestamp_from_logprob is not None
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
)
self.begin_index = (
len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
)
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# suppress <|notimestamps|> which is handled by without_timestamps
scores[:, self.no_timestamps_token_id] = -float("inf")
if input_ids.shape[1] == self.begin_index - 1:
scores[:, :] = -float("inf")
scores[:, self.timestamp_begin] = 0
return scores
# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
for k in range(input_ids.shape[0]):
seq = list(input_ids[k, self.begin_index :].tolist())
sampled_tokens = input_ids[k, self.begin_index :]
seq = list(sampled_tokens.tolist())
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
......@@ -1549,8 +1556,23 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
else: # cannot be normal text tokens
scores[k, : self.eos_token_id] = -float("inf")
timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
if timestamps.numel() > 0:
# `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
# The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
if last_was_timestamp and not penultimate_was_timestamp:
timestamp_last = timestamps[-1]
else:
# Avoid emitting <|0.00|> again
timestamp_last = timestamps[-1] + 1
scores[k, self.timestamp_begin : timestamp_last] = -float("inf")
# apply the `max_initial_timestamp` option
if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None:
if input_ids.shape[1] == self.begin_index:
scores[:, : self.timestamp_begin] = -float("inf")
if self.max_initial_timestamp_index is not None:
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
scores[:, last_allowed + 1 :] = -float("inf")
......@@ -1559,7 +1581,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
for k in range(input_ids.shape[0]):
timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
scores[k, : self.timestamp_begin] = -float("inf")
return scores
......
......@@ -507,10 +507,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
):
yield item
else:
if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
processed = self.feature_extractor(
inputs,
sampling_rate=self.feature_extractor.sampling_rate,
truncation=False,
padding="longest",
return_tensors="pt",
)
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if self.torch_dtype is not None:
processed = processed.to(dtype=self.torch_dtype)
if stride is not None:
......@@ -551,8 +561,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if stride is not None:
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
generate_kwargs["input_features"] = inputs
else:
generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
tokens = self.model.generate(
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
attention_mask=attention_mask,
**generate_kwargs,
)
......
......@@ -16,7 +16,7 @@ import unittest
import numpy as np
import pytest
from datasets import load_dataset
from datasets import Audio, load_dataset
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import (
......@@ -329,16 +329,16 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(
res,
{
"text": " Conquered returned to its place amidst the tents.",
"chunks": [
{"text": " Conquered", "timestamp": (0.5, 1.2)},
{"text": " returned", "timestamp": (1.2, 1.64)},
{"text": " to", "timestamp": (1.64, 1.84)},
{"text": " its", "timestamp": (1.84, 2.02)},
{"text": " place", "timestamp": (2.02, 2.28)},
{"text": " amidst", "timestamp": (2.28, 2.78)},
{"text": " the", "timestamp": (2.78, 2.96)},
{"text": " tents.", "timestamp": (2.96, 3.48)},
'text': ' Conquered returned to its place amidst the tents.',
'chunks': [
{'text': ' Conquered', 'timestamp': (0.5, 1.2)},
{'text': ' returned', 'timestamp': (1.2, 1.64)},
{'text': ' to', 'timestamp': (1.64, 1.84)},
{'text': ' its', 'timestamp': (1.84, 2.02)},
{'text': ' place', 'timestamp': (2.02, 2.28)},
{'text': ' amidst', 'timestamp': (2.28, 2.8)},
{'text': ' the', 'timestamp': (2.8, 2.98)},
{'text': ' tents.', 'timestamp': (2.98, 3.48)},
],
},
)
......@@ -776,27 +776,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(
output,
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"chunks": [
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
'text': ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
'chunks': [
{'text': ' Mr.', 'timestamp': (0.38, 1.04)},
{'text': ' Quilter', 'timestamp': (1.04, 1.18)},
{'text': ' is', 'timestamp': (1.18, 1.44)},
{'text': ' the', 'timestamp': (1.44, 1.58)},
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
{'text': ' of', 'timestamp': (1.98, 2.3)},
{'text': ' the', 'timestamp': (2.3, 2.46)},
{'text': ' of', 'timestamp': (1.98, 2.32)},
{'text': ' the', 'timestamp': (2.32, 2.46)},
{'text': ' middle', 'timestamp': (2.46, 2.56)},
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
{'text': ' and', 'timestamp': (3.38, 3.52)},
{'text': ' we', 'timestamp': (3.52, 3.6)},
{'text': ' are', 'timestamp': (3.6, 3.72)},
{'text': ' classes,', 'timestamp': (2.56, 3.4)},
{'text': ' and', 'timestamp': (3.4, 3.54)},
{'text': ' we', 'timestamp': (3.54, 3.62)},
{'text': ' are', 'timestamp': (3.62, 3.72)},
{'text': ' glad', 'timestamp': (3.72, 4.0)},
{'text': ' to', 'timestamp': (4.0, 4.26)},
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
{'text': ' his', 'timestamp': (4.54, 4.92)},
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
],
},
{'text': ' welcome', 'timestamp': (4.26, 4.56)},
{'text': ' his', 'timestamp': (4.56, 4.92)},
{'text': ' gospel.', 'timestamp': (4.92, 5.84)}
]
}
)
# fmt: on
......@@ -1087,6 +1087,34 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
# Slow integration test: builds an "automatic-speech-recognition" pipeline around
# the "openai/whisper-tiny.en" checkpoint, transcribes the first sample of the
# "distil-whisper/meanwhile" test split and compares the text to a fixed reference.
# NOTE(review): the model is moved to "cuda" below, but only @require_torch is
# applied here — confirm whether a GPU-requiring decorator is needed.
@require_torch
@slow
def test_whisper_longform(self):
# Reference transcription, compared verbatim (including the leading space) at the end.
# fmt: off
EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out of fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct denny's, set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!"""
# fmt: on
# The processor supplies both the tokenizer and the feature extractor handed to the pipeline.
processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model = model.to("cuda")
# No chunking arguments are passed to the pipeline here.
# presumably the audio is long enough to exercise the long-form path this test is named for — TODO confirm
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
device="cuda:0",
)
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
# Cast the audio column to a 16000 Hz sampling rate before reading samples.
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
# Slice (rather than index) so the pipeline receives a batch; only the first result is checked.
audio = ds[:1]["audio"]
result = pipe(audio)[0]["text"]
assert result == EXPECTED_RESULT
@require_torch
@slow
def test_chunking_and_timestamps(self):
......@@ -1355,7 +1383,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
out,
{
"chunks": [
{"text": "", "timestamp": (18.94, 0.0)},
{"text": "", "timestamp": (18.94, 0.02)},
{"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
],
"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment