Unverified Commit 89575b56 authored by Kamil Akesbi, committed by GitHub

Support generating with fallback for short form audio in Whisper (#30984)



* remove is_shortform

* adapt _retrieve_max_frames_and_seek for short_form

* return bos token in short and long form

* add decoder_input_ids to short form audios

* add eos token for short form

* handle short form token_timestamps

* no need to return scores

* add is_shortform conditions

* handle when max_new_tokens is None - short form

* handle assistant decoding

* fix

* handle return_dict_in_generate

* handle split_by_batch for encoder_attentions attribute

* handle num_beams>1

* handle num_return_sequences>1 in generate_with_fallback

* handle num_return_sequences>1 with return_dict_in_generate=True

* raise error if max_new_tokens + decoder_input_ids > max_target_pos

* fix

* apply review suggestions

* fix

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fix

* logits for both short form and long form

* handle if logits_processor is None

* test

* apply review changes to num_return_sequences

* add _expand_variables_for_generation

* remove short form commented section

* update comments

* uncomment num_beams line in generate_with_fallback

* update assistant decoding

* handle return_segment with short form generation

* up

* fix output format is_shortform

* overwrite beam_sample test

* update _set_return_timestamps

* apply review suggestions

* apply review suggestions

* remove seek_outputs_short_form

* fix _stack_split_outputs

* fix stack dim in _stack_split_outputs

* update tests

* fix past_key_values + beam tests

* fix

* clean _expand_variables_for_generation

* make style

* fix slow tests

* make style

* max_length condition

* make style

* add slow tests for shortform fallback

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* apply review changes

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* up

* fix slow tests

* apply review suggestions

* update test

* make style

* small fix

* fix

* fix test_new_cache_format

* fix past_key_values

* fix

* make style

* fix slow tests

* fix

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 46835ec6
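For context, a minimal usage sketch of what this change enables. It mirrors the new slow tests added below (the checkpoint, dataset, and generation kwargs are taken from those tests, nothing here is part of the diff itself): with quality thresholds and a temperature schedule set, a single short-form input is now retried at higher temperatures when a pass fails the checks, just like long-form audio.

```python
# Sketch only: mirrors test_whisper_shortform_single_batch_prev_cond below.
# Requires `transformers` with this change, plus `datasets` for the audio sample.
from datasets import Audio, load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
one_audio = ds[1]["audio"]["array"]  # treated as a single 30s ("short-form") input by the default feature extractor

input_features = processor(one_audio, sampling_rate=16_000, return_tensors="pt")["input_features"]

# With these thresholds set, short-form generation now falls back to the next
# temperature in the tuple whenever a pass fails the quality checks.
result = model.generate(
    input_features,
    return_timestamps=True,
    no_speech_threshold=0.6,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold=1.35,
    condition_on_prev_tokens=True,
    logprob_threshold=-1.0,
)
print(processor.batch_decode(result.sequences, skip_special_tokens=True))
```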
@@ -23,6 +23,8 @@ import torch
 import torch.nn.functional as F
 from torch import nn
+from transformers.cache_utils import EncoderDecoderCache
 from ...generation.configuration_utils import GenerationConfig
 from ...generation.logits_process import (
     LogitsProcessorList,
@@ -116,9 +118,10 @@ def _dynamic_time_warping(matrix: np.ndarray):
 def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
-    logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
-    if logit_processor:
-        return getattr(logit_processor, attribute_name, None)
+    if logits_processor is not None:
+        logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
+        if logit_processor:
+            return getattr(logit_processor, attribute_name, None)
     return None
@@ -493,27 +496,15 @@ class WhisperGenerationMixin:
         )
         is_shortform = total_input_frames <= num_segment_frames
-        if is_shortform:
-            # warn user of ignored inputs
-            self._maybe_warn_unused_inputs(
-                condition_on_prev_tokens=condition_on_prev_tokens,
-                temperature=temperature,
-                compression_ratio_threshold=compression_ratio_threshold,
-                logprob_threshold=logprob_threshold,
-                no_speech_threshold=no_speech_threshold,
-                total_input_frames=total_input_frames,
-            )
         # 3. Make sure generation config is correctly set
         # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
         self._set_return_outputs(
             return_dict_in_generate=return_dict_in_generate,
             return_token_timestamps=return_token_timestamps,
-            is_shortform=is_shortform,
             logprob_threshold=logprob_threshold,
             generation_config=generation_config,
         )
-        self._set_return_timestamps(
+        timestamp_begin = self._set_return_timestamps(
             return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
         )
         self._set_language_and_task(
@@ -554,85 +545,54 @@ class WhisperGenerationMixin:
             generation_config=generation_config,
             logits_processor=logits_processor,
             begin_index=begin_index,  # begin index is index of first generated decoder token
-            is_shortform=is_shortform,
             num_beams=kwargs.get("num_beams", 1),
             device=device,
         )
-        # 5. If we're in shortform mode, simple generate the whole input at once and return the output
-        if is_shortform:
-            if temperature is not None:
-                generation_config.temperature = temperature
-            decoder_input_ids = kwargs.pop("decoder_input_ids", None)
-            if decoder_input_ids is None:
-                decoder_input_ids = init_tokens
-            if prompt_ids is not None:
-                decoder_input_ids = torch.cat(
-                    [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
-                )
-            max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
-            if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
-                raise ValueError(
-                    f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
-                    f"is {max_new_tokens}. Thus, the combined length of "
-                    f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
-                    f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
-                    "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
-                    f"so that their combined length is less than {self.config.max_target_positions}."
-                )
-            outputs = super().generate(
-                input_features,
-                generation_config=generation_config,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
-                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-                synced_gpus=synced_gpus,
-                decoder_input_ids=decoder_input_ids,
-                **kwargs,
-            )
-            if generation_config.return_token_timestamps and hasattr(generation_config, "alignment_heads"):
-                outputs["token_timestamps"] = self._extract_token_timestamps(
-                    outputs, generation_config.alignment_heads, num_frames=generation_config.num_frames
-                )
-            return outputs
-        # 6. Else we're in longform mode which is more complex.
-        # We need to chunk the audio input depending on when the model generates timestamp tokens
-        # 6.1 Set and retrieve global longform generation variables
+        # 4 Set and retrieve global generation variables
         self._set_condition_on_prev_tokens(
             condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
         )
-        timestamp_begin = generation_config.no_timestamps_token_id + 1
         temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
         temperature = temperatures[0]
-        batch_size = input_features.shape[0]
         max_frames, seek = self._retrieve_max_frames_and_seek(
-            batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames
+            batch_size=batch_size,
+            attention_mask=attention_mask,
+            total_input_frames=total_input_frames,
+            is_shortform=is_shortform,
         )
-        # 6.2 Preppare running variables, list for generation
-        cur_bsz = batch_size
-        current_segments = self._prepare_segments(
-            prompt_ids=prompt_ids,
+        # 5 Prepare running variables, list for generation
+        num_return_sequences = generation_config.num_return_sequences
+        (
+            batch_idx_map,
+            cur_bsz,
+            input_features,
+            seek,
+            max_frames,
+            init_tokens,
+            do_condition_on_prev_tokens,
+        ) = self._expand_variables_for_generation(
+            input_features=input_features,
+            seek=seek,
+            max_frames=max_frames,
+            init_tokens=init_tokens,
             batch_size=batch_size,
+            condition_on_prev_tokens=condition_on_prev_tokens,
             generation_config=generation_config,
         )
-        batch_idx_map = list(range(batch_size))
-        do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)]
-        # 6.2 Transcribe audio until we reach the end of all input audios
+        current_segments = self._prepare_segments(
+            prompt_ids=prompt_ids,
+            batch_size=cur_bsz,
+            generation_config=generation_config,
+        )
+        # 6 Transcribe audio until we reach the end of all input audios
         while (seek < max_frames).any():
-            # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
+            # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
             # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
             # to know which original audio is being decoded
             # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
@@ -646,7 +606,7 @@ class WhisperGenerationMixin:
             time_offset = seek * time_precision / input_stride
             seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
-            # 6.4 cut out next 30s segment from input features
+            # 6.2 cut out next 30s segment from input features
             segment_input = self._get_input_segment(
                 input_features=input_features,
                 seek=seek,
@@ -656,10 +616,11 @@ class WhisperGenerationMixin:
                 batch_idx_map=batch_idx_map,
             )
-            # 6.5 prepare decoder input ids
+            # 6.3 prepare decoder input ids
             suppress_tokens = _get_attr_from_logit_processors(
                 logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
             )
             decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
                 cur_bsz=cur_bsz,
                 init_tokens=init_tokens,
@@ -669,25 +630,32 @@ class WhisperGenerationMixin:
                 prompt_ids=prompt_ids,
                 generation_config=generation_config,
                 config=self.config,
-                device=segment_input.device,
+                device=init_tokens.device,
                 suppress_tokens=suppress_tokens,
                 kwargs=kwargs,
             )
-            # 6.6 set max new tokens or max length
+            # 6.4 set max new tokens or max length
             self._set_max_new_tokens_and_length(
                 config=self.config,
                 decoder_input_ids=decoder_input_ids,
                 generation_config=generation_config,
             )
-            # 6.7 Set current `begin_index` for all logit processors
-            for proc in logits_processor:
-                if hasattr(proc, "set_begin_index"):
-                    proc.set_begin_index(decoder_input_ids.shape[-1])
-            # 6.8 Run generate with fallback
-            seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(
+            # 6.5 Set current `begin_index` for all logit processors
+            if logits_processor is not None:
+                for proc in logits_processor:
+                    if hasattr(proc, "set_begin_index"):
+                        proc.set_begin_index(decoder_input_ids.shape[-1])
+            # 6.6 Run generate with fallback
+            (
+                seek_sequences,
+                seek_outputs,
+                should_skip,
+                do_condition_on_prev_tokens,
+                model_output_type,
+            ) = self.generate_with_fallback(
                 segment_input=segment_input,
                 decoder_input_ids=decoder_input_ids,
                 cur_bsz=cur_bsz,
@@ -703,10 +671,11 @@ class WhisperGenerationMixin:
                 synced_gpus=synced_gpus,
                 return_token_timestamps=return_token_timestamps,
                 do_condition_on_prev_tokens=do_condition_on_prev_tokens,
+                is_shortform=is_shortform,
                 kwargs=kwargs,
             )
-            # 6.9 In every generated sequence, split by timestamp tokens and extract segments
+            # 6.7 In every generated sequence, split by timestamp tokens and extract segments
            for i, seek_sequence in enumerate(seek_sequences):
                prev_i = batch_idx_map[i]
@@ -728,7 +697,11 @@ class WhisperGenerationMixin:
                 )
                 current_segments[prev_i] += segments
-                seek[prev_i] += segment_offset
+                if is_shortform:
+                    seek[prev_i] += max_frames[i]
+                else:
+                    seek[prev_i] += segment_offset
         # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
         # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
@@ -737,6 +710,7 @@ class WhisperGenerationMixin:
             if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
             else current_segments
         )
         sequences = _pad_to_max_length(
             final_segments, generation_config.pad_token_id, device=self.device, padding="right"
         )
@@ -745,6 +719,42 @@ class WhisperGenerationMixin:
         if return_segments:
             return {"sequences": sequences, "segments": final_segments}
+        if is_shortform:
+            # add eos token:
+            if generation_config.max_new_tokens is None and generation_config.max_length is None:
+                eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
+                sequences = torch.cat([sequences, eos_tokens], dim=-1)
+            if return_token_timestamps:
+                outputs = {}
+                outputs["sequences"] = sequences
+                outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
+            else:
+                outputs = sequences
+            if generation_config.return_dict_in_generate:
+                dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
+                if num_return_sequences > 1:
+                    if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
+                        dict_outputs.encoder_attentions = tuple(
+                            dict_outputs.encoder_attentions[i][::num_return_sequences]
+                            for i in range(len(dict_outputs.encoder_attentions))
+                        )
+                    if (
+                        hasattr(dict_outputs, "encoder_hidden_states")
+                        and dict_outputs.encoder_hidden_states is not None
+                    ):
+                        dict_outputs.encoder_hidden_states = tuple(
+                            dict_outputs.encoder_hidden_states[i][::num_return_sequences]
+                            for i in range(len(dict_outputs.encoder_hidden_states))
+                        )
+                if return_token_timestamps:
+                    dict_outputs["token_timestamps"] = outputs["token_timestamps"]
+                return dict_outputs
+            return outputs
         return sequences
     def generate_with_fallback(
@@ -764,6 +774,7 @@ class WhisperGenerationMixin:
         synced_gpus,
         return_token_timestamps,
         do_condition_on_prev_tokens,
+        is_shortform,
         kwargs,
     ):
         kwargs = copy.copy(kwargs)
@@ -774,7 +785,6 @@ class WhisperGenerationMixin:
         needs_fallback = [False for _ in range(cur_bsz)]
         should_skip = [False for _ in range(cur_bsz)]
         fallback_index_map = list(range(cur_bsz))
         if generation_config.no_speech_threshold is not None:
             self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs)
@@ -799,12 +809,15 @@ class WhisperGenerationMixin:
                 **generate_kwargs,
             )
+            model_output_type = type(seek_outputs)
             # post-process sequence tokens and outputs to be in list form
             seek_sequences, seek_outputs = self._postprocess_outputs(
                 seek_outputs=seek_outputs,
                 decoder_input_ids=decoder_input_ids,
                 return_token_timestamps=return_token_timestamps,
                 generation_config=generation_config,
+                is_shortform=is_shortform,
             )
             # 6.7 Extract cut sequences from every sequence and check if fallback should be applied
@@ -822,14 +835,14 @@ class WhisperGenerationMixin:
                 # remove eos token id
                 if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
                     seek_sequence = seek_sequence[:-1]
-                    if return_token_timestamps:
+                    if return_token_timestamps and not is_shortform:
                         seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
                 # remove all padding tokens
                 if seek_sequence[-1] == generation_config.pad_token_id:
                     num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
                     seek_sequence = seek_sequence[:-num_paddings]
-                    if return_token_timestamps:
+                    if return_token_timestamps and not is_shortform:
                         seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
                 # check which sequences in batch need fallback & which should be skipped
@@ -871,7 +884,7 @@ class WhisperGenerationMixin:
             if "decoder_attention_mask" in kwargs:
                 kwargs["decoder_attention_mask"] = torch.stack(new_decoder_attention_mask)
-        return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens
+        return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type
     @staticmethod
     def _prepare_segments(prompt_ids, batch_size, generation_config):
@@ -884,10 +897,14 @@ class WhisperGenerationMixin:
         return current_segments
-    def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
+    def _postprocess_outputs(
+        self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform
+    ):
         # remove all previously passed decoder input ids
+        start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
         if isinstance(seek_outputs, torch.Tensor):
-            seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1] :]
+            seek_outputs = seek_outputs[:, start_idx:]
             return seek_outputs, seek_outputs
         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
@@ -895,28 +912,72 @@ class WhisperGenerationMixin:
             seek_outputs["token_timestamps"] = self._extract_token_timestamps(
                 seek_outputs, generation_config.alignment_heads, num_frames=num_frames
             )
-            seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1] :]
-        seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :]
-        def split_by_batch_index(values, key, batch_idx):
-            if key == "scores":
+            seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
+        seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
+        def split_by_batch_index(values, key, batch_idx, is_shortform):
+            if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
                 return [v[batch_idx].cpu() for v in values]
-            elif key == "past_key_values":
-                # we don't save `past_key_values` as this is too costly
-                return None
-            elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]):
+            if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
                 return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
+            elif key == "past_key_values":
+                if not is_shortform:
+                    # we don't save `past_key_values` as this is too costly for longform
+                    return None
+                else:
+                    return tuple(tuple(w[batch_idx][None].cpu() for w in values[v]) for v in range(len(values)))
             return values[batch_idx].cpu()
         sequence_tokens = seek_outputs["sequences"]
+        if hasattr(seek_outputs, "past_key_values") and seek_outputs.past_key_values is not None:
+            if isinstance(seek_outputs["past_key_values"], EncoderDecoderCache):
+                seek_outputs.past_key_values = seek_outputs.past_key_values.to_legacy_cache()
         seek_outputs = [
-            {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
+            {k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
             for i in range(sequence_tokens.shape[0])
         ]
         return sequence_tokens, seek_outputs
+    def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
+        # Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
+        outputs = {}
+        for key in seek_outputs[0].keys():
+            if key == "sequences":
+                outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
+            if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
+                outputs[key] = tuple(
+                    torch.stack([v[key][i] for v in seek_outputs]).to(device) for i in range(len(seek_outputs[0][key]))
+                )
+            if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
+                outputs[key] = tuple(
+                    tuple(
+                        torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
+                        for j in range(len(seek_outputs[0][key][0]))
+                    )
+                    for i in range(len(seek_outputs[0][key]))
+                )
+            if key == "past_key_values":
+                past_key_value_type = kwargs.get("past_key_values")
+                if seek_outputs[0][key] is not None:
+                    outputs[key] = tuple(
+                        tuple(
+                            torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
+                            for j in range(len(seek_outputs[0][key][0]))
+                        )
+                        for i in range(len(seek_outputs[0][key]))
+                    )
+                    if past_key_value_type is not None and isinstance(past_key_value_type, EncoderDecoderCache):
+                        outputs[key] = past_key_value_type.from_legacy_cache(outputs[key])
+                else:
+                    outputs[key] = None
+        return model_output_type(**outputs)
     def _need_fallback(
         self,
         seek_sequence,
@@ -936,7 +997,7 @@ class WhisperGenerationMixin:
                 needs_fallback = True
         if generation_config.logprob_threshold is not None:
-            if "sequences_scores" in seek_outputs[0]:
+            if hasattr(seek_outputs[0], "sequences_scores"):
                 logprobs = [s["sequences_scores"] for s in seek_outputs][index]
             else:
                 scores = seek_outputs[index]["scores"]
@@ -961,6 +1022,33 @@ class WhisperGenerationMixin:
         return needs_fallback, should_skip
+    def _expand_variables_for_generation(
+        self, input_features, seek, max_frames, init_tokens, batch_size, condition_on_prev_tokens, generation_config
+    ):
+        if generation_config.num_return_sequences is not None and generation_config.num_return_sequences > 1:
+            batch_idx_map = list(range(batch_size * generation_config.num_return_sequences))
+            cur_bsz = len(batch_idx_map)
+            do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(len(batch_idx_map))]
+            input_features = input_features.repeat_interleave(generation_config.num_return_sequences, dim=0)
+            seek = seek.repeat_interleave(generation_config.num_return_sequences, dim=0)
+            max_frames = max_frames.repeat_interleave(generation_config.num_return_sequences, dim=0)
+            init_tokens = init_tokens.repeat_interleave(generation_config.num_return_sequences, dim=0)
+            generation_config.num_return_sequences = 1
+        else:
+            cur_bsz = batch_size
+            batch_idx_map = list(range(cur_bsz))
+            do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(cur_bsz)]
+        return (
+            batch_idx_map,
+            cur_bsz,
+            input_features,
+            seek,
+            max_frames,
+            init_tokens,
+            do_condition_on_prev_tokens,
+        )
     @staticmethod
     def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
         set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
@@ -1018,9 +1106,7 @@ class WhisperGenerationMixin:
         )
     @staticmethod
-    def _set_return_outputs(
-        return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config
-    ):
+    def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
         if return_dict_in_generate is None:
             return_dict_in_generate = generation_config.return_dict_in_generate
@@ -1030,14 +1116,13 @@ class WhisperGenerationMixin:
             generation_config.output_attentions = True
             generation_config.output_scores = True
-        if not is_shortform and logprob_threshold is not None:
+        if logprob_threshold is not None:
             return_dict_in_generate = True
             generation_config.output_scores = True
         generation_config.return_dict_in_generate = return_dict_in_generate
-    @staticmethod
-    def _set_return_timestamps(return_timestamps, is_shortform, generation_config):
+    def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
         if not is_shortform:
             if return_timestamps is False:
                 raise ValueError(
@@ -1057,6 +1142,15 @@ class WhisperGenerationMixin:
         generation_config.return_timestamps = return_timestamps
+        if hasattr(generation_config, "no_timestamps_token_id"):
+            timestamp_begin = generation_config.no_timestamps_token_id + 1
+        else:
+            # BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps
+            # We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop
+            timestamp_begin = self.config.vocab_size + 1
+        return timestamp_begin
     @staticmethod
     def _set_language_and_task(language, task, is_multilingual, generation_config):
         if is_multilingual is not None:
@@ -1388,23 +1482,21 @@ class WhisperGenerationMixin:
         generation_config.condition_on_prev_tokens = condition_on_prev_tokens
     @staticmethod
-    def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames):
-        if batch_size > 1 and attention_mask is None:
+    def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames, is_shortform):
+        if batch_size > 1 and not is_shortform and attention_mask is None:
             raise ValueError(
                 "When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
             )
-        elif batch_size > 1:
+        elif batch_size > 1 and not is_shortform:
             max_frames = attention_mask.sum(-1).cpu().to(torch.long)
             seek = torch.zeros((batch_size,), dtype=torch.long)
         else:
-            max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames
-            seek = torch.zeros((1,), dtype=torch.long)
+            max_frames = torch.ones((batch_size,), dtype=torch.long) * total_input_frames
+            seek = torch.zeros((batch_size,), dtype=torch.long)
         return max_frames, seek
-    def _retrieve_logit_processors(
-        self, generation_config, logits_processor, begin_index, is_shortform, num_beams, device
-    ):
+    def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, num_beams, device):
         if generation_config.return_timestamps is True:
             timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
             logits_processor = (
@@ -1431,7 +1523,7 @@ class WhisperGenerationMixin:
             )
             generation_config.begin_suppress_tokens = None
-        if generation_config.no_speech_threshold is not None and not is_shortform:
+        if generation_config.no_speech_threshold is not None:
             no_speech_detector = WhisperNoSpeechDetection(
                 no_speech_token=generation_config.no_timestamps_token_id - 1,
                 begin_index=begin_index,
@@ -1462,6 +1554,9 @@ class WhisperGenerationMixin:
     @staticmethod
     def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map):
+        if input_features is None:
+            return None
         segment_input = []
         for i in range(cur_bsz):
             prev_i = batch_idx_map[i]
@@ -1493,6 +1588,11 @@ class WhisperGenerationMixin:
         suppress_tokens,
         kwargs,
     ):
+        if "decoder_input_ids" in kwargs:
+            decoder_input_ids = kwargs.pop("decoder_input_ids")
+            return decoder_input_ids, kwargs
         cut_off_length = config.max_target_positions // 2 - 1
         decoder_input_ids = init_tokens[batch_idx_map]
@@ -1533,8 +1633,18 @@ class WhisperGenerationMixin:
         return decoder_input_ids, kwargs
-    @staticmethod
-    def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config):
+    def _set_max_new_tokens_and_length(self, config, decoder_input_ids, generation_config):
+        max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
+        if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
+            raise ValueError(
+                f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
+                f"is {max_new_tokens}. Thus, the combined length of "
+                f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
+                f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
+                "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
+                f"so that their combined length is less than {self.config.max_target_positions}."
+            )
         num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)
         # Make sure we don't get larger than `max_length`
...
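The new `_expand_variables_for_generation` helper above expands every per-sample running variable when `num_return_sequences > 1`, so that each returned sequence owns one row of the batched decoding loop. A toy, standalone illustration of that expansion (not taken from the diff; the values are made up):

```python
import torch

# Two input audios, three return sequences each: every running variable is repeated
# with repeat_interleave so the decoding loop sees an "effective" batch of six rows.
batch_size, num_return_sequences = 2, 3

seek = torch.zeros(batch_size, dtype=torch.long)
max_frames = torch.tensor([3000, 1500], dtype=torch.long)

seek = seek.repeat_interleave(num_return_sequences, dim=0)
max_frames = max_frames.repeat_interleave(num_return_sequences, dim=0)
batch_idx_map = list(range(batch_size * num_return_sequences))

print(seek.tolist())        # [0, 0, 0, 0, 0, 0]
print(max_frames.tolist())  # [3000, 3000, 3000, 1500, 1500, 1500]
print(batch_idx_map)        # [0, 1, 2, 3, 4, 5]
```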
@@ -65,6 +65,15 @@ if is_torch_available():
         WhisperProcessor,
         set_seed,
     )
+    from transformers.generation import (
+        BeamSampleDecoderOnlyOutput,
+        BeamSampleEncoderDecoderOutput,
+        BeamSearchDecoderOnlyOutput,
+        BeamSearchEncoderDecoderOutput,
+        GenerateBeamDecoderOnlyOutput,
+        GenerateBeamEncoderDecoderOutput,
+        PhrasalConstraint,
+    )
     from transformers.generation.logits_process import LogitsProcessor
     from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
@@ -1539,6 +1548,241 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     def test_longform_generate_multi_batch_cond_prev(self):
         self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
+    def test_beam_sample_generate_dict_output(self):
+        # We overwrite test_beam_sample_generate_dict_output in test_utils as
+        # we can only perform beam search if the temperature is set to 0 in Whisper.
+        config, input_ids, attention_mask = self._get_input_ids_and_config()
+        # disable cache
+        config.use_cache = False
+        model = WhisperForConditionalGeneration(config).to(torch_device).eval()
+        _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
+        beam_kwargs = self._get_beam_kwargs()
+        # With Whisper, we can only perform a beam search if the temperature is set to 0.
+        logits_warper_kwargs["temperature"] = 0
+        # We will return num_beams sequences per input only if num_return_sequences == num_beams:
+        beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
+        output_generate = self._beam_sample_generate(
+            model=model,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            beam_kwargs=beam_kwargs,
+            logits_warper_kwargs=logits_warper_kwargs,
+            output_scores=True,
+            output_logits=True,
+            output_hidden_states=True,
+            output_attentions=True,
+            return_dict_in_generate=True,
+        )
+        if model.config.is_encoder_decoder:
+            self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+            self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+            # Retrocompatibility check
+            self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
+        else:
+            self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+            self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+            # Retrocompatibility check
+            self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
+        self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"])
+    def test_beam_search_generate_dict_output(self):
+        # We overwrite test_beam_search_generate_dict_output in test_utils as
+        # we can only perform beam search if the temperature is set to 0 in Whisper.
+        for model_class in self.all_generative_model_classes:
+            config, input_ids, attention_mask = self._get_input_ids_and_config()
+            # disable cache
+            config.use_cache = False
+            model = model_class(config).to(torch_device).eval()
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
+                input_ids.shape[-1],
+                config.forced_bos_token_id,
+                config.forced_eos_token_id,
+            )
+            beam_kwargs = self._get_beam_kwargs()
+            # With Whisper, we can only perform a beam search if the temperature is set to 0.
+            logits_process_kwargs["temperature"] = 0
+            # We will return num_beams sequences per input only if num_return_sequences == num_beams:
+            beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
+            output_generate = self._beam_search_generate(
+                model=model,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                beam_kwargs=beam_kwargs,
+                logits_process_kwargs=logits_process_kwargs,
+                output_scores=True,
+                output_logits=True,
+                output_hidden_states=True,
+                output_attentions=True,
+                return_dict_in_generate=True,
+            )
+            if model.config.is_encoder_decoder:
+                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+                # Retrocompatibility check
+                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
+            else:
+                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+                # Retrocompatibility check
+                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
+            self._check_outputs(
+                output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
+            )
+    def test_beam_search_generate_dict_outputs_use_cache(self):
+        # We overwrite test_beam_search_generate_dict_outputs_use_cache in test_utils as
+        # we can only perform beam search if the temperature is set to 0 in Whisper.
+        for model_class in self.all_generative_model_classes:
+            # enable cache
+            config, input_ids, attention_mask = self._get_input_ids_and_config()
+            if not hasattr(config, "use_cache"):
+                self.skipTest("This model doesn't support caching")
+            model = model_class(config).to(torch_device).eval()
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
+                input_ids.shape[-1],
+                config.forced_bos_token_id,
+                config.forced_eos_token_id,
+            )
+            beam_kwargs = self._get_beam_kwargs()
+            # We will return num_beams sequences per input only if num_return_sequences == num_beams:
+            beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
+            config.use_cache = True
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            output_generate = self._beam_search_generate(
+                model=model,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                beam_kwargs=beam_kwargs,
+                logits_process_kwargs=logits_process_kwargs,
+                output_scores=True,
+                output_logits=True,
+                output_hidden_states=True,
+                output_attentions=True,
+                return_dict_in_generate=True,
+            )
+            if model.config.is_encoder_decoder:
+                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+            else:
+                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+            self._check_outputs(
+                output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
+            )
+    def test_group_beam_search_generate_dict_output(self):
+        # We overwrite test_group_beam_search_generate_dict_output in test_utils as
+        # we can only perform beam search if the temperature is set to 0 in Whisper.
+        for model_class in self.all_generative_model_classes:
+            config, input_ids, attention_mask = self._get_input_ids_and_config()
+            config.use_cache = False
+            model = model_class(config).to(torch_device).eval()
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
+                input_ids.shape[-1],
+                config.forced_bos_token_id,
+                config.forced_eos_token_id,
+            )
+            beam_kwargs = self._get_diverse_beam_kwargs()
+            # We will return num_beams sequences per input only if num_return_sequences == num_beams:
+            beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
+            output_generate = self._group_beam_search_generate(
+                model=model,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                beam_kwargs=beam_kwargs,
+                logits_process_kwargs=logits_process_kwargs,
+                output_scores=True,
+                output_logits=True,
+                output_hidden_states=True,
+                output_attentions=True,
+                return_dict_in_generate=True,
+            )
+            if model.config.is_encoder_decoder:
+                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+                # Retrocompatibility check
+                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
+            else:
+                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+                # Retrocompatibility check
+                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
+            self._check_outputs(
+                output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
+            )
+    def test_constrained_beam_search_generate_dict_output(self):
+        for model_class in self.all_generative_model_classes:
+            config, input_ids, attention_mask = self._get_input_ids_and_config()
+            # disable cache
+            config.use_cache = False
+            model = model_class(config).to(torch_device).eval()
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
+                input_ids.shape[-1],
+                config.forced_bos_token_id,
+                config.forced_eos_token_id,
+            )
+            # Sample constraints
+            min_id = 3
+            max_id = model.config.vocab_size
+            force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
+            constraints = [
+                PhrasalConstraint(force_tokens),
+            ]
+            beam_kwargs = self._get_constrained_beam_kwargs()
+            output_generate = self._constrained_beam_search_generate(
+                model=model,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                constraints=constraints,
+                beam_kwargs=beam_kwargs,
+                logits_process_kwargs=logits_process_kwargs,
+                output_scores=True,
+                output_logits=True,
+                output_hidden_states=True,
+                output_attentions=True,
+                return_dict_in_generate=True,
+            )
+            if model.config.is_encoder_decoder:
+                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+                # Retrocompatibility check
+                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
+            else:
+                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+                # Retrocompatibility check
+                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
+            self._check_outputs(
+                output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
+            )
     def test_custom_4d_attention_mask(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
         model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
@@ -2680,6 +2924,55 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         assert decoded == EXPECTED_TEXT
+    @slow
+    def test_whisper_shortform_single_batch_prev_cond(self):
+        # fmt: off
+        EXPECTED_TEXT = [" Folks, I spend a lot of time right over there, night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing and the most topical antilock breaks and power steering pain, Stakingly stitching, leather seating so soft, it would make JD power and her associate blush. If you were to create the luxury sedan that is my nightly model, but sometimes— you're sometimes, folks— I lurched the consciousness and the back of an abandoned school bus"]
+        EXPECTED_TEXT1 = [" Folks, I spend a lot of time right over there night after night after, actually. Carefully selecting for you the day's noisiest, most aerodynamic headlines, stress testing, and the most topical, anti-lock breaks and power steering, painstakingly stitching, leather seating, so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school"]
+        # fmt: on
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        model = model.to(torch_device)
+        ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
+        dataset = ds.cast_column("audio", Audio(sampling_rate=16000))
+        one_audio = dataset[1]["audio"]["array"]
+        input_features = processor(one_audio, return_tensors="pt", sampling_rate=16_000)["input_features"]
+        input_features = input_features.to(device=torch_device)
+        gen_kwargs = {
+            "return_timestamps": True,
+            "no_speech_threshold": 0.6,
+            "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+            "compression_ratio_threshold": 1.35,
+            "condition_on_prev_tokens": True,
+            "logprob_threshold": -1.0,
+        }
+        torch.manual_seed(0)
+        result = model.generate(input_features, **gen_kwargs)
+        decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
+        assert decoded == EXPECTED_TEXT
+        gen_kwargs = {
+            "return_timestamps": True,
+            "no_speech_threshold": 0.3,
+            "temperature": (0.0, 0.2),
+            "compression_ratio_threshold": 1,
+            "condition_on_prev_tokens": False,
+            "logprob_threshold": -1.0,
+        }
+        torch.manual_seed(0)
+        result = model.generate(input_features, **gen_kwargs)
+        decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
+        assert decoded == EXPECTED_TEXT1
     @slow
     def test_whisper_longform_single_batch_beam(self):
         # fmt: off
@@ -2931,6 +3224,57 @@ class WhisperModelIntegrationTests(unittest.TestCase):
             elif isinstance(EXPECTED_TEXT[i], tuple):
                 assert decoded_all[i] in EXPECTED_TEXT[i]
+    @slow
+    def test_whisper_shortform_multi_batch_hard_prev_cond(self):
+        # Without this set here, this test may fail if it is run with other tests (say, `test_tiny_*`). It's unclear
+        # why other tests may affect this tests: it seems some random operations are beyond the scene.
+        set_seed(0)
+        # fmt: off
+        EXPECTED_TEXT = [
+            ' Mr. Kfilter is the apostle of the Middle Classes and we are glad to welcome his gospel.',
+            " Nor is Mr. Qilter's manner less interesting than his matter.",
+            ' He tells us that at this festive season of the year, with Christmas and roce beef, looming before us, similarly drawn from eating and its results occur most readily to the mind.',
+            ' He has grabbed those with her surfered trigger late and his work is really a great after all, and can discover it in it but little of Rocky Ithaka.',
+            " L'Neile's pictures are a sort of upguards and add-um paintings, and Maessin's exquisite Itals are a national as a jingo poem. Mr. Birkett Foster's landscapes smiled at one much in the same way that Mr. Carcher used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slapper in the back, before he says,",
+            ' It is obviously unnecessary for us, to point out how luminous these criticisms are, how delicate and expression.',
+            ' On the general principles of art and Mr. Kriltor rights with equal lucidity.',
+            ' Painting, he tells us is of a different quality to mathematics and finish in art is adding more effect.',
+        ]
+        # fmt: on
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model = model.to(torch_device)
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        num_samples = 8
+        audio = ds[:num_samples]["audio"]
+        audios = [x["array"] for x in audio]
+        inputs = processor(
+            audios,
+            return_tensors="pt",
+            sampling_rate=16_000,
+        )
+        inputs = inputs.to(device=torch_device)
+        gen_kwargs = {
+            "return_timestamps": True,
+            "no_speech_threshold": 0.6,
+            "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+            "compression_ratio_threshold": 1.35,
+            "condition_on_prev_tokens": True,
+            "logprob_threshold": -1.0,
+        }
+        result = model.generate(**inputs, **gen_kwargs)
+        decoded_all = processor.batch_decode(result.sequences, skip_special_tokens=True)
+        for i in range(num_samples):
+            if isinstance(EXPECTED_TEXT[i], str):
+                assert decoded_all[i] == EXPECTED_TEXT[i]
     @slow
     def test_whisper_longform_no_speech_detection(self):
         # fmt: off
...
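One more behavior the diff centralizes in `_set_max_new_tokens_and_length`: the combined length of `decoder_input_ids` and `max_new_tokens` may not exceed the decoder's `max_target_positions`, now for short-form inputs as well. A minimal sketch, assuming the `openai/whisper-tiny.en` checkpoint and a dummy one-second input:

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# One second of silence stands in for a real short-form recording.
waveform = torch.zeros(16_000).numpy()
input_features = processor(waveform, sampling_rate=16_000, return_tensors="pt")["input_features"]

# Whisper decoders emit at most `config.max_target_positions` tokens (448 for the released
# checkpoints), so an oversized `max_new_tokens` budget is rejected with a ValueError.
try:
    model.generate(input_features, max_new_tokens=1_000)
except ValueError as err:
    print(err)
```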