Unverified Commit 4151fbb4 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Whisper] Add sequential longform decoding (#27492)

* [Whisper] Add seq gen

* [Whisper] Add seq gen

* more debug

* Fix whisper logit processor

* Improve whisper code further

* Fix more

* more debug

* more debug

* Improve further

* Add tests

* Prep for batch size > 1

* Get batch_size>1 working

* Correct more

* Add extensive tests

* more debug

* more debug

* more debug

* add more tests

* more debug

* Apply suggestions from code review

* more debug

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* Add more examples

* add comments to explain the code better

* fix more

* add comments to explain the code better

* add comments to explain the code better

* correct

* correct

* finalize

* Apply suggestions from code review

* Apply suggestions from code review
parent b2c63c79
......@@ -1487,6 +1487,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
Examples:
``` python
......@@ -1517,29 +1518,35 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, generate_config): # support for the kwargs
def __init__(
self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
): # support for the kwargs
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
self.begin_index = len(generate_config.forced_decoder_ids) + 2
if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
self.begin_index -= 1
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
# this variable is mostly just used for testing
self._detect_timestamp_from_logprob = (
_detect_timestamp_from_logprob
if _detect_timestamp_from_logprob is not None
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
)
self.begin_index = (
len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
)
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# suppress <|notimestamps|> which is handled by without_timestamps
scores[:, self.no_timestamps_token_id] = -float("inf")
if input_ids.shape[1] == self.begin_index - 1:
scores[:, :] = -float("inf")
scores[:, self.timestamp_begin] = 0
return scores
# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
for k in range(input_ids.shape[0]):
seq = list(input_ids[k, self.begin_index :].tolist())
sampled_tokens = input_ids[k, self.begin_index :]
seq = list(sampled_tokens.tolist())
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
......@@ -1549,8 +1556,23 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
else: # cannot be normal text tokens
scores[k, : self.eos_token_id] = -float("inf")
# apply the `max_initial_timestamp` option
if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None:
timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
if timestamps.numel() > 0:
# `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
# The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
if last_was_timestamp and not penultimate_was_timestamp:
timestamp_last = timestamps[-1]
else:
# Avoid to emit <|0.00|> again
timestamp_last = timestamps[-1] + 1
scores[k, self.timestamp_begin : timestamp_last] = -float("inf")
# apply the `max_initial_timestamp` option
if input_ids.shape[1] == self.begin_index:
scores[:, : self.timestamp_begin] = -float("inf")
if self.max_initial_timestamp_index is not None:
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
scores[:, last_allowed + 1 :] = -float("inf")
......@@ -1559,7 +1581,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
for k in range(input_ids.shape[0]):
timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
scores[k, : self.timestamp_begin] = -float("inf")
return scores
......
......@@ -15,6 +15,7 @@
""" PyTorch Whisper model."""
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
......@@ -1111,6 +1112,13 @@ class WhisperEncoder(WhisperPreTrainedModel):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
if input_features.shape[-1] != expected_seq_length:
raise ValueError(
f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -1723,7 +1731,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
def generate(
self,
inputs: Optional[torch.Tensor] = None,
input_features: Optional[torch.Tensor] = None,
generation_config=None,
logits_processor=None,
stopping_criteria=None,
......@@ -1734,12 +1742,16 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
language=None,
is_multilingual=None,
prompt_ids: Optional[torch.Tensor] = None,
return_token_timestamps=None,
num_segment_frames: Optional[int] = None,
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
attention_mask: Optional[torch.Tensor] = None,
time_precision: int = 0.02,
return_dict_in_generate: Optional[bool] = None,
**kwargs,
):
"""
Generates sequences of token ids for models with a language modeling head.
Transcribes or translates passed mel input features to a sequence of token ids.
<Tip warning={true}>
......@@ -1801,46 +1813,162 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
Whether to return token-level timestamps with the text. This can be used with or without the
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
words.
return_segments (`bool`, *optional*, defaults to `False`):
Whether to additionally return a list of all segments. Note that this option can only be enabled
when doing long-form transcription.
attention_mask (`torch.Tensor`, *optional*):
`attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
time_precision (`int`, *optional*, defaults to 0.02):
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
for 20 ms.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
`return_segments` is set True. In this case the generation outputs of each segment is added to each
segment.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
[`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
- [`~generation.GreedySearchDecoderOnlyOutput`],
- [`~generation.SampleDecoderOnlyOutput`],
- [`~generation.BeamSearchDecoderOnlyOutput`],
- [`~generation.BeamSampleDecoderOnlyOutput`]
If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned.
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
else only the generated output sequence ids are returned.
Example:
- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset, Audio
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> model.cuda()
>>> # load audios > 30 seconds
>>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
>>> # resample to 16kHz
>>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
>>> # take first 8 audios and retrieve array
>>> audio = ds[:8]["audio"]
>>> audio = [x["array"] for x in audio]
>>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
>>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
>>> inputs = inputs.to("cuda", torch.float32)
>>> # transcribe audio to ids
>>> generated_ids = model.generate(**inputs)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
>>> transcription[0]
' Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!'
```
- *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(inputs=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
"""
if "inputs" in kwargs:
input_features = kwargs.pop("inputs")
warnings.warn(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)
return_dict_in_generate = (
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
if generation_config is None:
generation_config = self.generation_config
if return_timestamps is not None:
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
if num_segment_frames is None:
num_segment_frames = input_stride * self.config.max_source_positions
# 1. Check whether we're in shortform or longform mode
if input_features is not None:
total_input_frames = input_features.shape[-1]
elif "encoder_outputs" in kwargs:
encoder_outputs_shape = (
kwargs["encoder_outputs"][0].shape
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
else kwargs["encoder_outputs"].shape
)
total_input_frames = encoder_outputs_shape[1] * input_stride
else:
raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
is_shortform = total_input_frames <= num_segment_frames
# 2. Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
if return_timestamps is True:
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
generation_config.return_timestamps = return_timestamps
elif not is_shortform:
if return_timestamps is False:
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
)
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the generation config to have `no_timestamps_token_id` correctly. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
"or make sure to pass no more than 3000 mel input features."
)
logger.info("Setting `return_timestamps=True` for long-form generation.")
generation_config.return_timestamps = True
else:
generation_config.return_timestamps = False
# 3. Make sure to correctly set language-related parameters
if is_multilingual is not None:
if not hasattr(generation_config, "is_multilingual"):
raise ValueError(
......@@ -1875,8 +2003,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
)
generation_config.task = task
# 4. Add forced decoder ids depending on passed `language`, `task`,`prompt_ids`, `return_token_timestamps` and `return_timestamps`
forced_decoder_ids = None
# Legacy code for backward compatibility
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
......@@ -1961,12 +2089,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
generation_config.forced_decoder_ids = forced_decoder_ids
if generation_config.return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
if return_token_timestamps:
kwargs["output_attentions"] = True
kwargs["return_dict_in_generate"] = True
return_dict_in_generate = True
if getattr(generation_config, "task", None) == "translate":
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
......@@ -1979,23 +2104,267 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
if kwargs.get("num_frames") is not None:
generation_config.num_frames = kwargs.pop("num_frames")
outputs = super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
**kwargs,
)
if generation_config.return_timestamps is True:
last_forced_decoder_ids = (
generation_config.forced_decoder_ids[-1][-1]
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids
else None
)
if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id:
# remove no_timestamp to be forcefully generated if we want to return timestamps
# this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly
forced_decoder_ids = generation_config.forced_decoder_ids[:-1]
# Make sure that if list is empty we set it to None
generation_config.forced_decoder_ids = None if len(forced_decoder_ids) == 0 else forced_decoder_ids
timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
logits_processor = (
timestamp_processor if logits_processor is None else timestamp_processor + logits_processor
)
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
outputs["token_timestamps"] = self._extract_token_timestamps(
outputs, generation_config.alignment_heads, num_frames=num_frames
# 5. If we're in shortform mode, simple generate the whole input at once and return the output
if is_shortform:
outputs = super().generate(
input_features,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
return outputs
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
outputs["token_timestamps"] = self._extract_token_timestamps(
outputs, generation_config.alignment_heads, num_frames=num_frames
)
return outputs
# 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated
# timestamp tokens
# 6.1 Set running parameters for while loop
if not return_segments and return_dict_in_generate:
raise ValueError(
"Make sure to set `return_segments=True` to return generation outputs as part of the `'segments' key.`"
)
# if input is longer than 30 seconds we default to long-form generation
timestamp_begin = self.generation_config.no_timestamps_token_id + 1
# input stride is mel frames per encoder output vector which is the product of all conv strides
batch_size = input_features.shape[0]
if batch_size > 1 and attention_mask is None:
raise ValueError(
"When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
)
elif batch_size > 1:
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
seek = torch.zeros((batch_size,), dtype=torch.long)
else:
max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames
seek = torch.zeros((1,), dtype=torch.long)
current_segments = [[] for _ in range(batch_size)]
cur_to_prev_index_map = list(range(batch_size))
# batch size can decrease during the run
cur_bsz = prev_bsz = batch_size
# 6.2 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
prev_bsz = cur_bsz
# 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
# in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
# to know which original audio is being decoded
new_cur_to_prev_index_map = []
for i in range(prev_bsz):
prev_i = cur_to_prev_index_map[i]
if seek[prev_i] >= max_frames[prev_i]:
cut_index = i + (cur_bsz - prev_bsz)
cur_bsz -= 1
input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
else:
# cut out index that goes away
new_cur_to_prev_index_map.append(prev_i)
# 6.4 Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
cur_to_prev_index_map = new_cur_to_prev_index_map
time_offset = seek * time_precision / input_stride
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
# 6.5 Make sure that all inputs are padded to the same input length
segment_input = []
for i in range(cur_bsz):
prev_i = cur_to_prev_index_map[i]
segment_input_slice = input_features[
i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]
]
if segment_input_slice.shape[-1] < num_segment_frames:
# pad to 3000 if necessary
segment_input_slice = F.pad(
segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
)
segment_input.append(segment_input_slice)
segment_input = torch.cat(segment_input, dim=0)
# 6.6 Batch generate current chunk
seek_outputs = super().generate(
segment_input,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
)
if return_dict_in_generate:
seek_sequences = seek_outputs["sequences"]
seek_outputs = [
{k: v[i] for k, v in seek_outputs.items()}
for i in range(next(iter(seek_outputs.values())).size(0))
]
else:
seek_sequences = seek_outputs
# 6.7 Loop over each decoded audio individually as each decoding can be of a different length
for i, seek_sequence in enumerate(seek_sequences):
prev_i = cur_to_prev_index_map[i]
# make sure we cut a predicted EOS token if we are not finished with the generation yet
is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
if is_not_final and seek_sequence[-1] == self.generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1]
# remove all padding tokens
if seek_sequence[-1] == self.generation_config.pad_token_id:
num_paddings = (seek_sequence == self.generation_config.pad_token_id).sum()
seek_sequence = seek_sequence[:-num_paddings]
segments, segment_offset = self._retrieve_segment(
seek_sequence=seek_sequence,
seek_outputs=seek_outputs,
time_offset=time_offset,
timestamp_begin=timestamp_begin,
seek_num_frames=seek_num_frames,
cur_bsz=cur_bsz,
time_precision=time_precision,
input_stride=input_stride,
prev_idx=prev_i,
idx=i,
)
current_segments[prev_i] += segments
seek[prev_i] += segment_offset
# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
sequences = []
max_total_length = 0
for current_segment_list in current_segments:
sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1))
max_total_length = max(max_total_length, len(sequences[-1]))
for i in range(batch_size):
sequences[i] = F.pad(
sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id
)
sequences = torch.stack(sequences, dim=0)
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if return_segments:
return {"sequences": sequences, "segments": current_segments}
return sequences
@staticmethod
def _retrieve_segment(
seek_sequence,
seek_outputs,
time_offset,
timestamp_begin,
seek_num_frames,
cur_bsz,
time_precision,
input_stride,
prev_idx,
idx,
):
# find the predicted "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == cur_bsz * [[False, True]]
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
# If whisper predicted a "end of segment" via a timestep token, let's go ever each
# "end of segment" prediction and slice the decoding into segments accordingly
if len(timestamp_segment_indices) > 0:
# if the output contains two consecutive timestamp tokens
slices = timestamp_segment_indices.tolist()
segments = []
if single_timestamp_ending:
slices.append(len(seek_sequence))
last_slice = 0
# Add each segment to list of all segments
for current_slice in slices:
sliced_tokens = seek_sequence[last_slice + 1 : current_slice + 1]
start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
segments.append(
{
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
"end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
"tokens": sliced_tokens,
"result": seek_outputs[idx],
}
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
segment_offset = seek_num_frames[prev_idx]
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
# here we throw away all predictions after the last predicted "end of segment"
# since we are cutting right in the middle of an audio
last_timestamp_pos = seek_sequence[last_slice].item() - timestamp_begin
segment_offset = last_timestamp_pos * input_stride
else:
# If whisper does not predict any "end of segment" token, then
# the whole decoding is considered a segment and we add it to the list of segments
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
last_timestamp_pos = seek_num_frames[prev_idx]
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
segments = [
{
"start": time_offset[prev_idx],
"end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
"tokens": seek_sequence,
"result": seek_outputs[idx],
}
]
segment_offset = seek_num_frames[prev_idx]
return segments, segment_offset
def prepare_inputs_for_generation(
self,
......@@ -2229,7 +2598,7 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
>>> predicted_ids = model.generate(input_features, assistant_model=assistant_model)
>>> # decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'
```"""
......
......@@ -508,9 +508,19 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
):
yield item
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
processed = self.feature_extractor(
inputs,
sampling_rate=self.feature_extractor.sampling_rate,
truncation=False,
padding="longest",
return_tensors="pt",
)
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if self.torch_dtype is not None:
processed = processed.to(dtype=self.torch_dtype)
if stride is not None:
......@@ -551,8 +561,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if stride is not None:
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
generate_kwargs["input_features"] = inputs
else:
generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
tokens = self.model.generate(
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
attention_mask=attention_mask,
**generate_kwargs,
)
......
......@@ -17,6 +17,7 @@
import copy
import inspect
import os
import random
import tempfile
import time
import unittest
......@@ -47,7 +48,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_datasets_available():
import datasets
from datasets import load_dataset
from datasets import Audio, load_dataset
if is_torch_available():
import torch
......@@ -61,8 +62,81 @@ if is_torch_available():
WhisperProcessor,
set_seed,
)
from transformers.generation.logits_process import LogitsProcessor
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
class DummyTimestampLogitProcessor(LogitsProcessor):
"""This processor fakes the correct timestamps tokens pattern [TOK_1] [TOK_2] ... [TOK_N] [TIME_STAMP_TOK_1] [TIME_STAMP_TOK_2] [TOK_N+1] ..."""
def __init__(
self, timestamp_begin, vocab_size, batch_size, max_length, min_space=3, seed=0, is_length_ascending=True
):
self.timestamp_begin = timestamp_begin
self.vocab_size = vocab_size
self.min_space_between_timestamps = min_space
self.timestamp_tokens = torch.arange(self.timestamp_begin, self.vocab_size)
self.timestamp_tokens.to(torch_device)
self.is_length_ascending = is_length_ascending
self.no_time_stamp_counter = batch_size * [0]
self.prev_highest_timestamp = batch_size * [0]
self.batch_size = batch_size
self.max_length = max_length
self.count = 0
self.let_pass = [[] for _ in range(batch_size)]
for k in range(batch_size):
random.seed(seed + k)
for _ in range(10000):
self.let_pass[k].append(random.randint(1, 10) <= 3)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# we don't want to randomely sample timestamp tokens
if input_ids.shape[-1] > 1:
scores[:, self.timestamp_begin :] = -float("inf")
self.no_time_stamp_counter = [x + 1 for x in self.no_time_stamp_counter]
for k in range(input_ids.shape[0]):
# make sure to use correct index if a batch was removed
if self.is_length_ascending and input_ids.shape[0] < self.batch_size:
prev_k = k + self.batch_size - input_ids.shape[0]
else:
prev_k = k
if input_ids[k, -1] == self.timestamp_begin:
self.no_time_stamp_counter[prev_k] = 0
can_produce = self.no_time_stamp_counter[prev_k] > self.min_space_between_timestamps
must_produce = (
input_ids[k][2:].le(self.timestamp_begin).all() and input_ids.shape[-1] == self.max_length - 1
)
# produce timestamp with 30%
if (can_produce and self.let_pass[prev_k][self.count]) or must_produce:
self.no_time_stamp_counter[prev_k] = 0
self.prev_highest_timestamp[prev_k] = max(input_ids[k].max() + 1, self.timestamp_tokens[0].item())
# force a timestamp
scores[k, :] = -float("inf")
scores[k, self.prev_highest_timestamp[prev_k]] = 10.0
if (
input_ids.shape[-1] > 3
and input_ids[k, -1].item() in self.timestamp_tokens
and input_ids[k, -2].item() not in self.timestamp_tokens
):
# force the same as before
scores[k, :] = -float("inf")
scores[k, input_ids[k, -1].item()] = 10.0
self.count += 1
if torch.isinf(scores).all():
raise ValueError("Dummy logit processor is incorrectly set up. Scores should not be all inf.")
return scores
if is_flax_available():
import jax.numpy as jnp
......@@ -1237,6 +1311,133 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model.generate(input_features, max_new_tokens=1, prompt_ids=prompt_ids)
def test_longform_generate_single_batch(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).eval().to(torch_device)
input_features = input_dict["input_features"]
# len = 250 with num_input_frames = 60
long_input_features = torch.cat([input_features.repeat(1, 1, 4), input_features[:, :, :10]], dim=-1)
# force bsz=1
long_input_features = long_input_features[:1]
vocab_size = model.config.vocab_size
batch_size = 1
num_timestamp_tokens = 20
max_length = 16
logits_processor = [
DummyTimestampLogitProcessor(
vocab_size - num_timestamp_tokens,
vocab_size,
batch_size=batch_size,
max_length=max_length,
min_space=4,
)
]
# each chunk should not be longer than 10
model.generation_config.max_length = max_length
# if input features are long can't set return_timestamps to False
with self.assertRaises(ValueError):
_ = model.generate(long_input_features, logits_processor=logits_processor, return_timestamps=False)
# if input features are long need to set generation config
with self.assertRaises(ValueError):
_ = model.generate(long_input_features, logits_processor=logits_processor)
timestamp_begin = vocab_size - num_timestamp_tokens
model.generation_config.no_timestamps_token_id = timestamp_begin - 1
model.generation_config.eos_token_id = None
model.generation_config._detect_timestamp_from_logprob = False
# make sure that we only have the same begin token
model.generation_config.max_initial_timestamp_index = 0
outputs = model.generate(long_input_features, logits_processor=logits_processor, return_segments=True)
segments = outputs["segments"][0]
for i, segment in enumerate(segments):
assert segment["start"] <= segment["end"], "start has to be smaller equal end"
assert (
segment["tokens"][0] == model.generation_config.decoder_start_token_id
or segment["tokens"][0] >= timestamp_begin
), "First segment token should be a timestamp token"
assert any(
s > timestamp_begin for s in segment["tokens"][1:]
), f"At least one segment token should be a timestamp token, but not first., {segment['tokens']}"
assert (
segment["tokens"].shape[-1] <= max_length
), "make sure that no segment is larger than max generation length"
def test_longform_generate_multi_batch(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).eval().to(torch_device)
input_features = input_dict["input_features"].to(torch_device)
# len = 250 with num_input_frames = 60
long_input_features = torch.cat([input_features.repeat(1, 1, 4), input_features[:, :, :10]], dim=-1)
long_input_features[:1, :, :200]
input_features_2 = long_input_features[1:]
attention_mask = torch.ones(
(2, long_input_features.shape[-1]), dtype=input_features.dtype, device=input_features.device
)
attention_mask[0, 200:] = 0
# force bsz=1
vocab_size = model.config.vocab_size
batch_size = 1
num_timestamp_tokens = 20
max_length = 16
timestamp_begin = vocab_size - num_timestamp_tokens
model.generation_config.no_timestamps_token_id = timestamp_begin - 1
model.generation_config.eos_token_id = None
model.generation_config._detect_timestamp_from_logprob = False
# make sure that we only have the same begin token
model.generation_config.max_initial_timestamp_index = 0
logits_processor = [
DummyTimestampLogitProcessor(
vocab_size - num_timestamp_tokens,
vocab_size,
batch_size=batch_size,
max_length=max_length,
min_space=4,
seed=1,
)
]
outputs_2 = model.generate(input_features_2, logits_processor=logits_processor, return_segments=True)
tokens_2 = outputs_2["sequences"][0]
segments_2 = outputs_2["segments"][0]
batch_size = 2
logits_processor = [
DummyTimestampLogitProcessor(
vocab_size - num_timestamp_tokens,
vocab_size,
batch_size=batch_size,
max_length=max_length,
min_space=4,
seed=0,
)
]
outputs = model.generate(
long_input_features, attention_mask=attention_mask, logits_processor=logits_processor, return_segments=True
)
tokens = outputs["sequences"][1]
segments = outputs["segments"][1]
assert tokens_2.tolist() == tokens.tolist()
for seg1, seg2 in zip(segments_2, segments):
assert seg1["start"] == seg2["start"]
assert seg1["end"] == seg2["end"]
assert seg1["tokens"].tolist() == seg2["tokens"].tolist()
@require_torch
@require_torchaudio
......@@ -1831,6 +2032,125 @@ class WhisperModelIntegrationTests(unittest.TestCase):
]
assert total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster"
@slow
def test_whisper_longform_single_batch(self):
# fmt: off
EXPECTED_TEXT = [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter\'s manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton\'s work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell\'s pictures are a sort of up-gards and atom paintings, and Mason\'s exquisite idles are as national as a jingo poem. Mr. Birk at Foster\'s landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampoo or a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes the customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mantelboard. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the Tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon\'s body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered his muscles into complete relaxation. Oli\'s heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, The thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I\'m here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you\'re being a fool. out, through his silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry, and victory to the stronger. man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. had died before during the 20s and death during the last round was in some ways easier than defeat. Breathing deeply, Breon\'s softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent\'s face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that\'s rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with says he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn\'t work too hard, said Shaggy. He doesn\'t work at all. In fact, there\'s nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we\'ve turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The middle forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I\'m quite sure he didn\'t. That\'s funny, remarked Betsy thoughtfully. I don\'t believe Anne knew any magic, or she\'d have worked it before. I do not know, confess Shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Virgato used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Virgato\'s discarded ruby crown and holding in his hand to scepter which reggative head so often thrown at his head.']
# fmt: on
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model = model.to("cuda")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)
input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[
"input_features"
]
input_features = input_features.to(device="cuda")
result = model.generate(input_features, return_timestamps=True)
decoded = processor.batch_decode(result, skip_special_tokens=True)
assert decoded == EXPECTED_TEXT
@slow
def test_whisper_longform_multi_batch(self):
# fmt: off
EXPECTED_TEXT_1 = [" Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Birkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Mix a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing a poster or near the fire, and the ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the Tupper of painting. a Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the tight-wing cloth that was the only germany war. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were, triggered his muscles into complete relaxation. Oily his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenty's he must have drawn his gun, because the intruder said quickly, but that away you're being a fool. Out there was silence then, and still wondering, Breon was once more asleep. Ten seconds he asked the handler who was needing his aching muscles. a red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the twenties had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the twenties and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the mazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue, pre-inscented and new to fifth point was his. Then the powerful twist that's rest of the side, in and under the guard, because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, a cooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, return Calico. Where is my brother now? choir-dshaggy, in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh, no, I'm quite sure he didn't. That's funny, remarked Betsy thoughtfully. I don't believe and knew any magic, or she'd have worked it before. I do not know, confess shaggy. True, a great calico. Calico went to the big gong and pounded on it, just as Virgado used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Virgados discarded Ruby Crown, and holding in his hand to scepter, which Virgado had so often thrown at his head. head."]
EXPECTED_TEXT_2 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Burkett Foster's landscapes smile at one much in the same way that Mr. Carker."]
EXPECTED_TEXT_3 = [" possible. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-guards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath, next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting, he tells us, is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Mix a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire. any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the titling cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even to soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Oily his heart and lungs worked on at a strong measured rate. He was in In reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenty's he must have drawn his gun, because the intruder said quickly, but that away you're being a fool. Out there was silence then, and still wondering, Breon was once more asleep. Ten seconds he asked the handler who was needing his aching muscles. a red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the twenties had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the twenties and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue, re-insunced it and knew the fifth point was his. Then the powerful twist that's rest of the side, in and under the guard, because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, a cooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled and disgraced, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now? quared shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. And that's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confess Shaggy. True, a great calico. Calico went to the big gong and pounded on it, just as we're good to have used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the thrown wearing ruggedos discarded ruby crown and holding in his hand to septor which Ruggato had so often thrown at his head."]
EXPECTED_TEXT_4 = [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter\'s manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton\'s work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell\'s pictures are a sort of up-gards and atom paintings, and Mason\'s exquisite idles are as national as a jingo poem. Mr. Birk at Foster\'s landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampoo or a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes the customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mantelboard. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the Tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon\'s body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered his muscles into complete relaxation. Oli\'s heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, The thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I\'m here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you\'re being a fool. out, through his silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry, and victory to the stronger. man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. had died before during the 20s and death during the last round was in some ways easier than defeat. Breathing deeply, Breon\'s softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent\'s face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that\'s rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with says he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn\'t work too hard, said Shaggy. He doesn\'t work at all. In fact, there\'s nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we\'ve turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The middle forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I\'m quite sure he didn\'t. That\'s funny, remarked Betsy thoughtfully. I don\'t believe Anne knew any magic, or she\'d have worked it before. I do not know, confess Shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Virgato used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Virgato\'s discarded ruby crown and holding in his hand to scepter which reggative head so often thrown at his head.']
# fmt: on
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model = model.to("cuda")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)
audios = []
audios.append(one_audio[110000:])
audios.append(one_audio[:800000])
audios.append(one_audio[80000:])
audios.append(one_audio[:])
decoded_single = []
for audio in audios:
inputs = processor(audio, return_tensors="pt", truncation=False)
inputs = inputs.to(device="cuda")
result = model.generate(**inputs, return_timestamps=True)
decoded_single.append(processor.batch_decode(result, skip_special_tokens=True))
inputs = processor(
audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
)
inputs = inputs.to(device="cuda")
result = model.generate(**inputs, return_timestamps=True)
decoded_all = processor.batch_decode(result, skip_special_tokens=True)
# make sure single & batch is exactly the same
assert decoded_all[0:1] == decoded_single[0]
assert decoded_all[1:2] == decoded_single[1]
assert decoded_all[2:3] == decoded_single[2]
assert decoded_all[3:4] == decoded_single[3]
# exact match
assert decoded_all[0:1] == EXPECTED_TEXT_1
assert decoded_all[1:2] == EXPECTED_TEXT_2
assert decoded_all[2:3] == EXPECTED_TEXT_3
assert decoded_all[3:4] == EXPECTED_TEXT_4
@slow
def test_whisper_longform_multi_batch_hard(self):
# fmt: off
EXPECTED_TEXT = [
" Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!",
" Folks, I spend a lot of time right over there, night after night after night, actually. Carefully selecting for you the day's noosiest, most aerodynamic headlines, stress testing, and those topical anti-lock breaks and power steering, painstakingly stitching, leather seating so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school and slap myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen moon, render a gas tank out of an empty big gulp, fill with white claw and denatured alcohol, then light a match and let her rip and the demented one man soapbox derby of news that is my segment. Me, Guadalupe! No!",
" Ladies and gentlemen, you know, I spent a lot of time right over there Raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their jokes swollen teats Churning the daily stories into the decadent proven-style style triple cream breed that is my nightly monologue But sometimes sometimes folks I stagger home hungry after being released by the police and Root around in the neighbor's trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-donned street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire then hunker down and hallucinate while eating the listeria laden demon custard of news that is my segment. You mean one of them.",
" Folks, if you watch this show, you know I spend most of my time right over there carefully sorting through the day's biggest stories and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ichol Gregoire Ferrandi, who carefully dye them in a palette of bright zesty shades and adorn them in the finest and most topical inlay work using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddles stitching. In line it with bees, wax, coated linen, finely attached a mallet, hammered strap, pearled hardware, and close-shit to create for you the one-of-a-kind hoke couture, Erme's Birkin bag that is my monologue. But sometimes, sometimes folks, sometimes. Sometimes I wake up in the last car of an abandoned roller coaster at Coney Island where I'm I'm hiding from the triads. I have some engine lubricants out of a safe way bag and stagger down the shore to tear the sail off a beach schooner. Then I rip the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel lovely folks. And use it to stitch the sail into a loose pouch like a rock sack. And I stow away in the back of a garbage truck to the junkyard where I pick through to the debris for only the broken toys that make me the saddest until I have loaded for you. The Hobo Fugitives bug out, bindle of news that is my segment. Me one!",
" You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's biggest stories right over there. Meticulously selecting the most topical chakra affirming scented candles, and using Feng Shui to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue. But sometimes just sometimes I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself, and used fry oil, wrap my hands with some double-duct tape by stole from the broken car window. Pound a six-pack of blueberry hard-seltzer and a sack of pills I stole from a parked ambulance. Then arm wrestle a raccoon in the back alley vision quest of news that is my segment. Meanwhile!",
" You know, folks, I spend most of my time right over there. Mining the day's biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels. Then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press-black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards in a faceplate and, finally, using fluted strips of white alloyed molding, I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating, Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes folks. Sometimes, just sometimes, I come into my sense as fully naked on the deck of a pirate besieged melee container ship that picked me up floating on the detached door of a portapotty in the Indian Ocean. Then after a sunstroke-induced realization of the crew of this ship plans to sell me an exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe at a pool chain that accepting my new role as Captain and declaring myself king of the windarc seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create the sopping wet pirate crown of news that is my segment. Meanwhile!",
" Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's Newsiest most topical flower eggs milk and butter and Stranding into a fine batter to make delicate and informative comedy pancakes Then I glaze them in the juice and zest of the most relevant midnight Valencia oranges and douse it all and a fine Dela main de voyage cognac Before prom baying and basting them tables. I deserve for you the James Beard award worthy crepe suzzette That is my nightly monologue, but sometimes just sometimes folks. I wake up in the baggage hold of Greyhound bus. It's being hoisted by the scrap yard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps and busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strain, pair of sweatpants and as oven mitts to extract and serve the demented transience poundcake of news that is my segment. Me, Guadalupe!",
" Folks, if you watched the show and I hope you do, I spent a lot of time right over there. Tiredlessly studying the lineage of the days most important thoroughbred stories and whole-stiner headlines, working with the best trainers, money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen. That is my nightly monologue, but sometimes, sometimes, folks, I break into an unincorporated veterinary genetics lab and grab whatever test tubes I can find and then under a grow light I got from a discarded chia pet. I mixed the pilfered DNA of a horse and whatever was in a tube labeled Keith Colan extra. Slurrying the concoction with caffeine pills and a microwave red bull, I screamed, sang a prayer to Janice, initiator of human life and God of transformation as a half horse, half man, freak. Seizes to life before me and the hideous collection of loose animal parts and corrupted man tissue that is my segment. Meanwhile!",
]
# fmt: on
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model = model.to("cuda")
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
num_samples = 8
audio = ds[:num_samples]["audio"]
audios = [x["array"] for x in audio]
decoded_single = []
for audio in audios:
inputs = processor(audio, return_tensors="pt", truncation=False, sampling_rate=16_000)
inputs = inputs.to(device="cuda")
result = model.generate(**inputs, return_timestamps=True)
decoded_single += processor.batch_decode(result, skip_special_tokens=True)
inputs = processor(
audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
)
inputs = inputs.to(device="cuda")
result = model.generate(**inputs, return_timestamps=True)
decoded_all = processor.batch_decode(result, skip_special_tokens=True)
for i in range(num_samples):
assert decoded_all[i] == decoded_single[i]
assert decoded_all[i] == EXPECTED_TEXT[i]
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None:
......
......@@ -16,7 +16,7 @@ import unittest
import numpy as np
import pytest
from datasets import load_dataset
from datasets import Audio, load_dataset
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import (
......@@ -329,16 +329,16 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(
res,
{
"text": " Conquered returned to its place amidst the tents.",
"chunks": [
{"text": " Conquered", "timestamp": (0.5, 1.2)},
{"text": " returned", "timestamp": (1.2, 1.64)},
{"text": " to", "timestamp": (1.64, 1.84)},
{"text": " its", "timestamp": (1.84, 2.02)},
{"text": " place", "timestamp": (2.02, 2.28)},
{"text": " amidst", "timestamp": (2.28, 2.78)},
{"text": " the", "timestamp": (2.78, 2.96)},
{"text": " tents.", "timestamp": (2.96, 3.48)},
'text': ' Conquered returned to its place amidst the tents.',
'chunks': [
{'text': ' Conquered', 'timestamp': (0.5, 1.2)},
{'text': ' returned', 'timestamp': (1.2, 1.64)},
{'text': ' to', 'timestamp': (1.64, 1.84)},
{'text': ' its', 'timestamp': (1.84, 2.02)},
{'text': ' place', 'timestamp': (2.02, 2.28)},
{'text': ' amidst', 'timestamp': (2.28, 2.8)},
{'text': ' the', 'timestamp': (2.8, 2.98)},
{'text': ' tents.', 'timestamp': (2.98, 3.48)},
],
},
)
......@@ -776,27 +776,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(
output,
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"chunks": [
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
'text': ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
'chunks': [
{'text': ' Mr.', 'timestamp': (0.38, 1.04)},
{'text': ' Quilter', 'timestamp': (1.04, 1.18)},
{'text': ' is', 'timestamp': (1.18, 1.44)},
{'text': ' the', 'timestamp': (1.44, 1.58)},
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
{'text': ' of', 'timestamp': (1.98, 2.3)},
{'text': ' the', 'timestamp': (2.3, 2.46)},
{'text': ' of', 'timestamp': (1.98, 2.32)},
{'text': ' the', 'timestamp': (2.32, 2.46)},
{'text': ' middle', 'timestamp': (2.46, 2.56)},
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
{'text': ' and', 'timestamp': (3.38, 3.52)},
{'text': ' we', 'timestamp': (3.52, 3.6)},
{'text': ' are', 'timestamp': (3.6, 3.72)},
{'text': ' classes,', 'timestamp': (2.56, 3.4)},
{'text': ' and', 'timestamp': (3.4, 3.54)},
{'text': ' we', 'timestamp': (3.54, 3.62)},
{'text': ' are', 'timestamp': (3.62, 3.72)},
{'text': ' glad', 'timestamp': (3.72, 4.0)},
{'text': ' to', 'timestamp': (4.0, 4.26)},
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
{'text': ' his', 'timestamp': (4.54, 4.92)},
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
],
},
{'text': ' welcome', 'timestamp': (4.26, 4.56)},
{'text': ' his', 'timestamp': (4.56, 4.92)},
{'text': ' gospel.', 'timestamp': (4.92, 5.84)}
]
}
)
# fmt: on
......@@ -1087,6 +1087,34 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@slow
def test_whisper_longform(self):
# fmt: off
EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out of fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct denny's, set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!"""
# fmt: on
processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model = model.to("cuda")
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
device="cuda:0",
)
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
audio = ds[:1]["audio"]
result = pipe(audio)[0]["text"]
assert result == EXPECTED_RESULT
@require_torch
@slow
def test_chunking_and_timestamps(self):
......@@ -1355,7 +1383,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
out,
{
"chunks": [
{"text": "", "timestamp": (18.94, 0.0)},
{"text": "", "timestamp": (18.94, 0.02)},
{"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
],
"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment