Unverified Commit 240e1062 authored by Ondřej Cífka's avatar Ondřej Cífka Committed by GitHub
Browse files

Fix probability computation in `WhisperNoSpeechDetection` when recomputing scores (#29248)

* Fix is_scores_logprobs in WhisperNoSpeechDetection

* Add test_whisper_longform_no_speech_detection

* Fix typo
parent bcd42c4a
...@@ -1930,6 +1930,8 @@ class WhisperNoSpeechDetection(LogitsProcessor): ...@@ -1930,6 +1930,8 @@ class WhisperNoSpeechDetection(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
is_scores_logprobs = self.is_scores_logprobs
if input_ids.shape[1] == self.begin_index: if input_ids.shape[1] == self.begin_index:
if self.start_of_trans_offset > 1: if self.start_of_trans_offset > 1:
with torch.no_grad(): with torch.no_grad():
...@@ -1937,10 +1939,11 @@ class WhisperNoSpeechDetection(LogitsProcessor): ...@@ -1937,10 +1939,11 @@ class WhisperNoSpeechDetection(LogitsProcessor):
no_speech_index = self.begin_index - self.start_of_trans_offset no_speech_index = self.begin_index - self.start_of_trans_offset
no_speech_scores = logits[:, no_speech_index] no_speech_scores = logits[:, no_speech_index]
is_scores_logprobs = False
else: else:
no_speech_scores = scores no_speech_scores = scores
if self.is_scores_logprobs: if is_scores_logprobs:
probs = no_speech_scores.exp() probs = no_speech_scores.exp()
else: else:
probs = no_speech_scores.float().softmax(dim=-1) probs = no_speech_scores.float().softmax(dim=-1)
......
...@@ -2670,6 +2670,59 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -2670,6 +2670,59 @@ class WhisperModelIntegrationTests(unittest.TestCase):
for i in range(num_samples): for i in range(num_samples):
assert decoded_all[i] == EXPECTED_TEXT[i] assert decoded_all[i] == EXPECTED_TEXT[i]
@slow
def test_whisper_longform_no_speech_detection(self):
    """Long-form generation should survive chunks flagged as no-speech.

    A stretch of every input waveform is zeroed out so the no-speech
    detector sees silent chunks; the decoded transcripts are then
    compared against fixed reference strings.
    """
    # fmt: off
    EXPECTED_TEXT = [
        " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories. Developing the central headline pawns, definitely maneuvering and also topical night to F6.",
        " Folks, I spent a lot of time right over there night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing",
        ' Ladies and gentlemen, you know, I spent a lot of time right over there raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their joke swollen teats',
        ' Folks, you watched this show, you know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the',
        " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's big stories right over there. meticulously selecting the most topical chakra affirming scented candles, using Feng Shui,",
        ' You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest.',
        " Folks, if you watch this show, you know I spend most of my time right over there, carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most...",
        " Folks, if you watch the show and I hope you do, I spent a lot of time right over there. Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines.",
    ]
    # fmt: on

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(torch_device)

    dataset = load_dataset("distil-whisper/meanwhile", "default")["test"]
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    num_samples = 8
    waveforms = [sample["array"] for sample in dataset[:num_samples]["audio"]]

    # Zero out seconds 15-60 of every waveform so the second chunk is silent.
    for waveform in waveforms:
        waveform[15 * 16000 : 60 * 16000] = 0.0

    inputs = processor(
        waveforms, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
    ).to(device=torch_device)

    gen_kwargs = {
        "return_timestamps": True,
        "no_speech_threshold": 0.2,
        "temperature": (0.0,),
        "compression_ratio_threshold": 1.35,
        "condition_on_prev_tokens": True,
        "logprob_threshold": 0.0,  # Ignore logprob, use only no-speech prob
        "num_beams": 5,
    }

    torch.manual_seed(0)
    result = model.generate(**inputs, **gen_kwargs)
    decoded_all = processor.batch_decode(result, skip_special_tokens=True)

    for transcript, expected in zip(decoded_all, EXPECTED_TEXT):
        assert transcript == expected
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None: if head_mask is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment