"vscode:/vscode.git/clone" did not exist on "d3eacbb8299161d21e007e7e3d42505dae741282"
Unverified Commit e7e6d181 authored by Sanchit Gandhi, committed by GitHub
Browse files

[Whisper] Move decoder id method to tokenizer (#20589)

parent 9ffbed26
......@@ -42,37 +42,7 @@ class WhisperProcessor(ProcessorMixin):
self._in_target_context_manager = False
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
    """Build the forced decoder ids for Whisper generation.

    This method now simply forwards to the tokenizer, which owns the
    language/task validation and prefix-token construction (moved there in
    PR #20589). Keeping this thin wrapper preserves the public processor API.

    Args:
        task (`str`, *optional*): Task token, e.g. `"transcribe"` or `"translate"`.
        language (`str`, *optional*): Language token, e.g. `"en"` or `"fr"`.
        no_timestamps (`bool`, defaults to `True`): If `True`, force the
            `<|notimestamps|>` token so the model does not predict timestamps.

    Returns:
        A list of `(position, token_id)` tuples suitable for the
        `forced_decoder_ids` generation argument.
    """
    return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
def __call__(self, *args, **kwargs):
"""
......
......@@ -399,9 +399,13 @@ class WhisperTokenizer(PreTrainedTokenizer):
self.language = self.language.lower()
if self.language in TO_LANGUAGE_CODE:
language_id = TO_LANGUAGE_CODE[self.language]
elif self.language in TO_LANGUAGE_CODE.values():
language_id = self.language
else:
is_language_code = len(self.language) == 2
raise ValueError(
f"Unsupported language: {self.language}. Language should be in: {TO_LANGUAGE_CODE.keys()}"
f"Unsupported language: {self.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
)
if self.task is not None:
......@@ -577,3 +581,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
if len(input_ids) > self.model_max_length:
input_ids = input_ids[-self.model_max_length :]
return input_ids
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
    """Return forced decoder ids of the form `[(position, token_id), ...]`.

    Args:
        task (`str`, *optional*): Task token, e.g. `"transcribe"` or `"translate"`.
        language (`str`, *optional*): Language token, e.g. `"en"` or `"fr"`.
        no_timestamps (`bool`, defaults to `True`): If `True`, the
            `<|notimestamps|>` token is forced and timestamps are NOT predicted.

    Returns:
        `List[Tuple[int, int]]`: `(position, token_id)` pairs starting at
        position 1, matching the format previously produced by
        `WhisperProcessor.get_decoder_prompt_ids`.
    """
    # `no_timestamps=True` means timestamps must NOT be predicted, so the
    # flag has to be inverted before being passed to `set_prefix_tokens`
    # (passing it straight through would force timestamp prediction on).
    self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
    # prefix_tokens is assumed to start with <|startoftranscript|> — TODO confirm.
    # That token is the decoder start token and must not be forced at
    # position 1, so it is sliced off before building the pairs.
    forced_tokens = self.prefix_tokens[1:]
    # Positions are 1-indexed: position 0 is the decoder start token, matching
    # the `(rank + 1, token)` format of the original processor implementation.
    forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
    return forced_decoder_ids
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment