Commit 65a926e8 authored by Patrick von Platen, committed by GitHub

[Whisper] Refactor forced_decoder_ids & prompt ids (#28687)



* up

* Fix more

* Correct more

* Fix more tests

* fix fast tests

* Fix more

* fix more

* push all files

* finish all

* make style

* Fix timestamp wrap

* make style

* make style

* up

* up

* up

* Fix lang detection behavior

* Fix lang detection behavior

* Add lang detection test

* Fix lang detection behavior

* make style

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* better error message

* make style tests

* add warning

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent f9f1f2ac
@@ -1451,6 +1451,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         # Original model wasn't trained with timestamps and has incorrect generation config
         pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
+        # the audio is 4 seconds long
         audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")
         out = pipe(
@@ -1460,11 +1461,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         self.assertEqual(
             out,
             {
-                "chunks": [
-                    {"text": "", "timestamp": (18.94, 0.02)},
-                    {"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
-                ],
                 "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
+                "chunks": [{"timestamp": (0.58, None), "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं"}],
             },
         )
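
Note on the changed expectation: the old expected value included an empty chunk with timestamp (18.94, 0.02), which runs backwards and starts well past the end of a clip that the test comment describes as 4 seconds long. The new expected value has a single chunk starting at 0.58. Below is a minimal sketch of a sanity check that would flag the old value. It assumes only the chunk layout shown in the diff (dicts with "text" and "timestamp" keys); `find_suspicious_chunks` is a hypothetical helper written for illustration, not part of transformers, and treating `None` as "unknown" is an assumption of this sketch.

```python
from typing import Any, Dict, List, Optional, Tuple

Timestamp = Tuple[Optional[float], Optional[float]]


def find_suspicious_chunks(
    chunks: List[Dict[str, Any]], audio_duration: float
) -> List[Dict[str, Any]]:
    """Return chunks whose timestamps run backwards or fall outside the clip.

    Hypothetical helper: a None start or end is treated as "unknown" and is
    not flagged on its own.
    """
    suspicious = []
    for chunk in chunks:
        start, end = chunk["timestamp"]
        # A chunk that ends before it starts cannot be correct.
        backwards = start is not None and end is not None and end < start
        # Any known boundary must lie within [0, audio_duration].
        out_of_range = any(
            t is not None and not (0.0 <= t <= audio_duration) for t in (start, end)
        )
        if backwards or out_of_range:
            suspicious.append(chunk)
    return suspicious


text = "मिर्ची में कितने विभिन्न प्रजातियां हैं"

# Expected chunks before and after this commit, copied from the diff above.
old_chunks = [
    {"text": "", "timestamp": (18.94, 0.02)},
    {"text": text, "timestamp": (None, None)},
]
new_chunks = [{"timestamp": (0.58, None), "text": text}]

old_flagged = find_suspicious_chunks(old_chunks, audio_duration=4.0)
new_flagged = find_suspicious_chunks(new_chunks, audio_duration=4.0)
```

With the 4-second duration from the test comment, only the empty (18.94, 0.02) chunk from the old expectation is flagged; the new expectation passes the check.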