Commit 77382e91 (unverified), authored by Sanchit Gandhi, committed by GitHub
Browse files

[Whisper] Fix forced decoder ids (#20652)

* [Whisper] Fix forced decoder ids

* fix test
parent 7c5eaf9e
...@@ -584,5 +584,10 @@ class WhisperTokenizer(PreTrainedTokenizer): ...@@ -584,5 +584,10 @@ class WhisperTokenizer(PreTrainedTokenizer):
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True): def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps) self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(self.prefix_tokens)] # prefix tokens are of the form: <|startoftranscript|> <|lang_id|> <|task|> <|notimestamps|>
# we don't want to force the bos token at position 1, as this is the starting token
# when we generate, so we slice the prefix tokens to: <|lang_id|> <|task|> <|notimestamps|>
# to get the forced tokens
forced_tokens = self.prefix_tokens[1:]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
return forced_decoder_ids return forced_decoder_ids
...@@ -26,7 +26,6 @@ if is_speech_available(): ...@@ -26,7 +26,6 @@ if is_speech_available():
from transformers import WhisperFeatureExtractor, WhisperProcessor from transformers import WhisperFeatureExtractor, WhisperProcessor
START_OF_TRANSCRIPT = 50257
TRANSCRIBE = 50358 TRANSCRIBE = 50358
NOTIMESTAMPS = 50362 NOTIMESTAMPS = 50362
...@@ -145,5 +144,5 @@ class WhisperProcessorTest(unittest.TestCase): ...@@ -145,5 +144,5 @@ class WhisperProcessorTest(unittest.TestCase):
for ids in forced_decoder_ids: for ids in forced_decoder_ids:
self.assertIsInstance(ids, (list, tuple)) self.assertIsInstance(ids, (list, tuple))
expected_ids = [START_OF_TRANSCRIPT, TRANSCRIBE, NOTIMESTAMPS] expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids) self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment