Unverified Commit 5658e749 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[whisper] fix short-form output type (#32178)

* [whisper] fix short-form output type

* add test

* make style

* update long-form tests

* fixes

* last fix

* finalise test
parent 85a1269e
......@@ -498,7 +498,7 @@ class WhisperGenerationMixin:
# 3. Make sure generation config is correctly set
# Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
self._set_return_outputs(
return_dict_in_generate = self._set_return_outputs(
return_dict_in_generate=return_dict_in_generate,
return_token_timestamps=return_token_timestamps,
logprob_threshold=logprob_threshold,
......@@ -732,7 +732,7 @@ class WhisperGenerationMixin:
else:
outputs = sequences
if generation_config.return_dict_in_generate:
if return_dict_in_generate and generation_config.return_dict_in_generate:
dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
if num_return_sequences > 1:
......@@ -1109,18 +1109,20 @@ class WhisperGenerationMixin:
def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
if return_dict_in_generate is None:
return_dict_in_generate = generation_config.return_dict_in_generate
else:
generation_config.return_dict_in_generate = return_dict_in_generate
generation_config.return_token_timestamps = return_token_timestamps
if return_token_timestamps:
return_dict_in_generate = True
generation_config.return_dict_in_generate = True
generation_config.output_attentions = True
generation_config.output_scores = True
if logprob_threshold is not None:
return_dict_in_generate = True
generation_config.return_dict_in_generate = True
generation_config.output_scores = True
generation_config.return_dict_in_generate = return_dict_in_generate
return return_dict_in_generate
def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
if not is_shortform:
......
......@@ -26,6 +26,7 @@ import unittest
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from parameterized import parameterized
import transformers
from transformers import WhisperConfig
......@@ -72,6 +73,7 @@ if is_torch_available():
BeamSearchEncoderDecoderOutput,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateEncoderDecoderOutput,
PhrasalConstraint,
)
from transformers.generation.logits_process import LogitsProcessor
......@@ -1820,6 +1822,26 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@parameterized.expand([(True,), (False,)])
def test_generate_output_type(self, return_dict_in_generate):
expected_output_type = GenerateEncoderDecoderOutput if return_dict_in_generate else torch.Tensor
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs()
model = model_class(config).to(torch_device).eval()
# short-form generation without fallback
pred_ids = model.generate(**inputs, return_dict_in_generate=return_dict_in_generate)
assert isinstance(pred_ids, expected_output_type)
# short-form generation with fallback
pred_ids = model.generate(
**inputs,
logprob_threshold=-1.0,
temperature=[0.0, 0.1],
return_dict_in_generate=return_dict_in_generate,
)
assert isinstance(pred_ids, expected_output_type)
@require_torch
@require_torchaudio
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment