Unverified Commit 89575b56 authored by Kamil Akesbi's avatar Kamil Akesbi Committed by GitHub
Browse files

Support generating with fallback for short form audio in Whisper (#30984)



* remove is_shortform

* adapt _retrieve_max_frames_and_seek for short_form

* return bos token in short and long form

* add decoder_input_ids to short form audios

* add eos token for  short form

* handle short form token_timestamps

* no need to return scores

* add is_shortform conditions

* handle when max_new_tokens is None - short form

* handle assistant decoding

* fix

* handle return_dict_in_generate

* handle split_by_batch for encoder_attentions attribute

* handle num_beams>1

* handle num_return_sequences>1 in generate_with_fallback

* handle num_return_sequences>1 with return_dict_in_generate=True

* raise error if max_new_tokens + decoder_inputs_ids > max_target_pos

* fix

* apply review suggestions

* fix

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fix

* logits for both short form and long form

* handle if logits_processor is None

* test

* apply review changes to num_return_sequences

* add _expand_variables_for_generation

* remove short form commented section

* update comments

* uncomment num_beams line in generate_with_fallback

* update assistant decoding

* handle return_segment with short form generation

* up

* fix output format is_shortform

* overwrite beam_sample test

* update _set_return_timestamps

* apply review suggestions

* apply review suggestions

* remove seek_outputs_short_form

* fix _stack_split_outputs

* fix stack dim in _stack_split_outputs

* update tests

* fix past_key_values + beam tests

* fix

* clean _expand_variables_for_generation

* make style

* fix slow tests

* make style

* max_length condition

* make style

* add slow tests for shortform fallback

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* apply review changes

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* up

* fix slow tests

* apply review suggestions

* update test

* make style

* small fix

* fix

* fix test_new_cache_format

* fix past_key_values

* fix

* make style

* fix slow tests

* fix

---------
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 46835ec6
......@@ -65,6 +65,15 @@ if is_torch_available():
WhisperProcessor,
set_seed,
)
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
PhrasalConstraint,
)
from transformers.generation.logits_process import LogitsProcessor
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
......@@ -1539,6 +1548,241 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_longform_generate_multi_batch_cond_prev(self):
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
def test_beam_sample_generate_dict_output(self):
# We overwrite test_beam_sample_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = WhisperForConditionalGeneration(config).to(torch_device).eval()
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
beam_kwargs = self._get_beam_kwargs()
# With Whisper, we can only perform a beam search if the temperature is set to 0.
logits_warper_kwargs["temperature"] = 0
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
output_generate = self._beam_sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"])
def test_beam_search_generate_dict_output(self):
# We overwrite test_beam_search_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()
# With Whisper, we can only perform a beam search if the temperature is set to 0.
logits_process_kwargs["temperature"] = 0
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
def test_beam_search_generate_dict_outputs_use_cache(self):
# We overwrite test_beam_search_generate_dict_outputs_use_cache in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
# enable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self._check_outputs(
output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
)
def test_group_beam_search_generate_dict_output(self):
# We overwrite test_group_beam_search_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_diverse_beam_kwargs()
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
output_generate = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
# Sample constraints
min_id = 3
max_id = model.config.vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs = self._get_constrained_beam_kwargs()
output_generate = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
)
def test_custom_4d_attention_mask(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
......@@ -2680,6 +2924,55 @@ class WhisperModelIntegrationTests(unittest.TestCase):
assert decoded == EXPECTED_TEXT
@slow
def test_whisper_shortform_single_batch_prev_cond(self):
# fmt: off
EXPECTED_TEXT = [" Folks, I spend a lot of time right over there, night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing and the most topical antilock breaks and power steering pain, Stakingly stitching, leather seating so soft, it would make JD power and her associate blush. If you were to create the luxury sedan that is my nightly model, but sometimes— you're sometimes, folks— I lurched the consciousness and the back of an abandoned school bus"]
EXPECTED_TEXT1 = [" Folks, I spend a lot of time right over there night after night after, actually. Carefully selecting for you the day's noisiest, most aerodynamic headlines, stress testing, and the most topical, anti-lock breaks and power steering, painstakingly stitching, leather seating, so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school"]
# fmt: on
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model = model.to(torch_device)
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
dataset = ds.cast_column("audio", Audio(sampling_rate=16000))
one_audio = dataset[1]["audio"]["array"]
input_features = processor(one_audio, return_tensors="pt", sampling_rate=16_000)["input_features"]
input_features = input_features.to(device=torch_device)
gen_kwargs = {
"return_timestamps": True,
"no_speech_threshold": 0.6,
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
"compression_ratio_threshold": 1.35,
"condition_on_prev_tokens": True,
"logprob_threshold": -1.0,
}
torch.manual_seed(0)
result = model.generate(input_features, **gen_kwargs)
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
assert decoded == EXPECTED_TEXT
gen_kwargs = {
"return_timestamps": True,
"no_speech_threshold": 0.3,
"temperature": (0.0, 0.2),
"compression_ratio_threshold": 1,
"condition_on_prev_tokens": False,
"logprob_threshold": -1.0,
}
torch.manual_seed(0)
result = model.generate(input_features, **gen_kwargs)
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
assert decoded == EXPECTED_TEXT1
@slow
def test_whisper_longform_single_batch_beam(self):
# fmt: off
......@@ -2931,6 +3224,57 @@ class WhisperModelIntegrationTests(unittest.TestCase):
elif isinstance(EXPECTED_TEXT[i], tuple):
assert decoded_all[i] in EXPECTED_TEXT[i]
@slow
def test_whisper_shortform_multi_batch_hard_prev_cond(self):
# Without this set here, this test may fail if it is run with other tests (say, `test_tiny_*`). It's unclear
# why other tests may affect this tests: it seems some random operations are beyond the scene.
set_seed(0)
# fmt: off
EXPECTED_TEXT = [
' Mr. Kfilter is the apostle of the Middle Classes and we are glad to welcome his gospel.',
" Nor is Mr. Qilter's manner less interesting than his matter.",
' He tells us that at this festive season of the year, with Christmas and roce beef, looming before us, similarly drawn from eating and its results occur most readily to the mind.',
' He has grabbed those with her surfered trigger late and his work is really a great after all, and can discover it in it but little of Rocky Ithaka.',
" L'Neile's pictures are a sort of upguards and add-um paintings, and Maessin's exquisite Itals are a national as a jingo poem. Mr. Birkett Foster's landscapes smiled at one much in the same way that Mr. Carcher used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slapper in the back, before he says,",
' It is obviously unnecessary for us, to point out how luminous these criticisms are, how delicate and expression.',
' On the general principles of art and Mr. Kriltor rights with equal lucidity.',
' Painting, he tells us is of a different quality to mathematics and finish in art is adding more effect.',
]
# fmt: on
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model = model.to(torch_device)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
num_samples = 8
audio = ds[:num_samples]["audio"]
audios = [x["array"] for x in audio]
inputs = processor(
audios,
return_tensors="pt",
sampling_rate=16_000,
)
inputs = inputs.to(device=torch_device)
gen_kwargs = {
"return_timestamps": True,
"no_speech_threshold": 0.6,
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
"compression_ratio_threshold": 1.35,
"condition_on_prev_tokens": True,
"logprob_threshold": -1.0,
}
result = model.generate(**inputs, **gen_kwargs)
decoded_all = processor.batch_decode(result.sequences, skip_special_tokens=True)
for i in range(num_samples):
if isinstance(EXPECTED_TEXT[i], str):
assert decoded_all[i] == EXPECTED_TEXT[i]
@slow
def test_whisper_longform_no_speech_detection(self):
# fmt: off
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment