Unverified commit 74fb524e, authored by Sanchit Gandhi and committed by GitHub
Browse files

[Whisper] Fix decoder ids methods (#20599)

* [Whisper] Fix decoder ids methods

* enum property
parent ef0f85cd
......@@ -583,5 +583,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
return input_ids
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
    """Return the decoder prompt as a list of ``(position, token)`` pairs.

    The prefix tokens are refreshed through ``set_prefix_tokens`` for the
    requested ``task`` and ``language``, then each one is paired with its
    1-based position in the prefix.

    Args:
        task: Forwarded unchanged to ``set_prefix_tokens``.
        language: Forwarded unchanged to ``set_prefix_tokens``.
        no_timestamps: When true, timestamp prediction is switched off.

    Returns:
        A list of ``(position, token)`` tuples, one per entry of
        ``self.prefix_tokens``, with positions starting at 1.
    """
    # `set_prefix_tokens` takes the opposite flag (`predict_timestamps`),
    # so `no_timestamps` has to be negated before it is passed along.
    self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
    # Pair every prefix token with its 1-based rank instead of returning the
    # bare token list.
    forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(self.prefix_tokens)]
    return forced_decoder_ids
......@@ -26,6 +26,11 @@ if is_speech_available():
from transformers import WhisperFeatureExtractor, WhisperProcessor
# Values that test_get_decoder_prompt_ids expects, in this order, as the last
# element of each entry returned by
# get_decoder_prompt_ids(task="transcribe", no_timestamps=True).
START_OF_TRANSCRIPT = 50257
TRANSCRIBE = 50358
NOTIMESTAMPS = 50362
@require_torch
@require_torchaudio
@require_sentencepiece
......@@ -128,3 +133,17 @@ class WhisperProcessorTest(unittest.TestCase):
feature_extractor.model_input_names,
msg="`processor` and `feature_extractor` model input names do not match",
)
def test_get_decoder_prompt_ids(self):
    """The processor's decoder prompt is a list of sequences ending in the expected ids."""
    processor = WhisperProcessor(
        feature_extractor=self.get_feature_extractor(),
        tokenizer=self.get_tokenizer(),
    )
    prompt_entries = processor.get_decoder_prompt_ids(task="transcribe", no_timestamps=True)

    # Container shape: a list whose entries are themselves lists or tuples.
    self.assertIsInstance(prompt_entries, list)
    for entry in prompt_entries:
        self.assertIsInstance(entry, (list, tuple))

    # Only the last element of each entry is compared with the expected ids.
    wanted = [START_OF_TRANSCRIPT, TRANSCRIBE, NOTIMESTAMPS]
    trailing_ids = [entry[-1] for entry in prompt_entries]
    self.assertListEqual(trailing_ids, wanted)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment