Unverified commit 06d48806, authored by Sanchit Gandhi, committed by GitHub

[Whisper Tokenizer] Make more user-friendly (#19921)



* [Whisper Tokenizer] Make more user-friendly

* use property

* make indexing rigorous

* small clean-up

* tests

* skip seq2seq tests

* remove multilingual arg

* reorder args

* collapse to one function
Co-authored-by: ArthurZucker <arthur@huggingface.co>

* option to override attributes
Co-authored-by: ArthurZucker <arthur@huggingface.co>

* add to docs

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make comment more clear
Co-authored-by: sgugger <sylvain@huggingface.co>

* don't add special tokens in get_decoder_prompt_ids

* add test for set_prefix_tokens
Co-authored-by: ArthurZucker <arthur@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: sgugger <sylvain@huggingface.co>
parent 790ff254
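For orientation, here is a minimal sketch of the workflow this commit enables, pieced together from the docstring example and the tests in the diff below (the checkpoint name and argument values are the ones used in those tests):

```python
from transformers import WhisperTokenizer

# load a multilingual tokenizer with prefix tokens configured for Spanish transcription
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish", task="transcribe")

# encoding now prepends <|startoftranscript|><|es|><|transcribe|><|notimestamps|>
# and appends <|endoftext|> to the label sequence
labels = tokenizer("hola güey").input_ids

# the prefix tokens can be switched later without reloading the tokenizer
tokenizer.set_prefix_tokens(language="french", task="translate")
```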
@@ -39,6 +39,7 @@ The original code can be found [here](https://github.com/openai/whisper).
## WhisperTokenizer
[[autodoc]] WhisperTokenizer
- set_prefix_tokens
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences
...
@@ -70,7 +70,7 @@ class WhisperProcessor(ProcessorMixin):
forced_decoder_tokens += f"<|{task}|>"
forced_decoder_tokens += "<|notimestamps|>" if no_timestamps else ""
-ids = self.tokenizer.encode(forced_decoder_tokens)
+ids = self.tokenizer.encode(forced_decoder_tokens, add_special_tokens=False)
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(ids)]
return forced_decoder_ids
...
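For context, a small sketch of the `forced_decoder_ids` format that `get_decoder_prompt_ids` (the method named in the commit message) builds from these ids; the token ids below are the multilingual constants defined in the test file further down, and the offset comes from the `rank + 1` enumeration above:

```python
# hypothetical values for a Spanish transcription prompt without timestamps
prompt = "<|es|><|transcribe|><|notimestamps|>"
ids = [50262, 50359, 50363]  # what encode(prompt, add_special_tokens=False) is expected to return

# pair each prompt token with the decoding position it should be forced at (positions start at 1)
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(ids)]
assert forced_decoder_ids == [(1, 50262), (2, 50359), (3, 50363)]
```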
@@ -89,9 +89,130 @@ def get_pairs(word):
return pairs
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"iw": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
TASK_IDS = ["translate", "transcribe"]
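As a quick illustration of how these new module-level mappings are meant to be used (values taken directly from the dictionaries above; the alias entries let alternative language names resolve to the same code):

```python
assert LANGUAGES["es"] == "spanish"
assert TO_LANGUAGE_CODE["spanish"] == "es"
assert TO_LANGUAGE_CODE["castilian"] == "es"  # alias resolves to the same language code
assert TASK_IDS == ["translate", "transcribe"]
```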
class WhisperTokenizer(PreTrainedTokenizer):
"""
-Construct an Whisper tokenizer.
+Construct a Whisper tokenizer.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
the superclass for more information regarding such methods.
@@ -109,16 +230,22 @@ class WhisperTokenizer(PreTrainedTokenizer):
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
-bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+bos_token (`str`, *optional*, defaults to `"<|startoftranscript|>"`):
The beginning of sequence token.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
add_prefix_space (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
other word.
-add_bos_token (`bool`, *optional*, defaults to `False`):
-Whether or not to add an initial <|endoftext|> to the input. This allows to treat the leading word just as
-any other word.
language (`str`, *optional*):
The language of the transcription text. The corresponding language id token is appended to the start of the
sequence for multilingual speech recognition and speech translation tasks, e.g. for Spanish the token
`"<|es|>"` is appended to the start of sequence. This should be used for multilingual fine-tuning only.
task (`str`, *optional*):
Task identifier to append at the start of sequence (if any). This should be used for multilingual
fine-tuning, with `"transcribe"` for speech recognition and `"translate"` for speech translation.
predict_timestamps (`bool`, *optional*, defaults to `False`):
Whether to omit the `<|notimestamps|>` token at the start of the sequence.
"""
vocab_files_names = VOCAB_FILES_NAMES
@@ -133,11 +260,13 @@ class WhisperTokenizer(PreTrainedTokenizer):
normalizer_file=None,
errors="replace",
unk_token="<|endoftext|>",
-bos_token="<|endoftext|>",
+bos_token="<|startoftranscript|>",
eos_token="<|endoftext|>",
pad_token=None,
add_prefix_space=False,
-add_bos_token=False,
+language=None,
task=None,
predict_timestamps=False,
**kwargs
):
@@ -152,10 +281,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
-add_bos_token=add_bos_token,
**kwargs,
)
-self.add_bos_token = add_bos_token
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
@@ -179,6 +306,10 @@ class WhisperTokenizer(PreTrainedTokenizer):
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.language = language
self.task = task
self.predict_timestamps = predict_timestamps
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
@@ -231,27 +362,76 @@ class WhisperTokenizer(PreTrainedTokenizer):
self.cache[token] = word
return word
-# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.build_inputs_with_special_tokens with GPT2 -> Whisper
-def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
-if self.add_bos_token:
-bos_token_ids = [self.bos_token_id]
-else:
-bos_token_ids = []
-output = bos_token_ids + token_ids_0
-if token_ids_1 is None:
-return output
-return output + bos_token_ids + token_ids_1
def set_prefix_tokens(self, language: str = None, task: str = None, predict_timestamps: bool = None):
"""
Override the prefix tokens appended to the start of the label sequence. This method can be used standalone to
update the prefix tokens as required when fine-tuning. Example:

```python
>>> # instantiate the tokenizer and set the prefix token to Spanish
>>> tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
>>> # now switch the prefix token from Spanish to French
>>> tokenizer.set_prefix_tokens(language="french")
```

Args:
language (`str`, *optional*, defaults to `None`):
The language of the transcription text.
task (`str`, *optional*, defaults to `None`):
Task identifier to append at the start of sequence (if any).
predict_timestamps (`bool`, *optional*, defaults to `None`):
Whether to omit the `<|notimestamps|>` token at the start of the sequence.
"""
self.language = language if language is not None else self.language
self.task = task if task is not None else self.task
self.predict_timestamps = predict_timestamps if predict_timestamps is not None else self.predict_timestamps
@property
def prefix_tokens(self) -> List[int]:
all_special_ids = self.all_special_ids
bos_token_id = all_special_ids[-106]
translate_token_id = all_special_ids[-6]
transcribe_token_id = all_special_ids[-5]
notimestamps_token_id = all_special_ids[-1]
langs = tuple(LANGUAGES.keys())
if self.language is not None:
self.language = self.language.lower()
if self.language in TO_LANGUAGE_CODE:
language_id = TO_LANGUAGE_CODE[self.language]
else:
raise ValueError(
f"Unsupported language: {self.language}. Language should be in: {TO_LANGUAGE_CODE.keys()}"
)
if self.task is not None:
if self.task not in TASK_IDS:
raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}")
bos_sequence = [bos_token_id]
if self.language is not None:
bos_sequence.append(bos_token_id + 1 + langs.index(language_id))
if self.task is not None:
bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id)
if not self.predict_timestamps:
bos_sequence.append(notimestamps_token_id)
return bos_sequence
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + [self.eos_token_id]
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id]
-# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_special_tokens_mask with GPT2 -> Whisper
+# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.get_special_tokens_mask
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
-Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
-special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
@@ -264,19 +444,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
-if not self.add_bos_token:
-return super().get_special_tokens_mask(
-token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
-)
+prefix_ones = [1] * len(self.prefix_tokens)
+suffix_ones = [1]
if token_ids_1 is None:
-return [1] + ([0] * len(token_ids_0))
+return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
-return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize with GPT2 -> Whisper
def _tokenize(self, text):
...
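Putting the new `prefix_tokens` property and `build_inputs_with_special_tokens` together, a rough sketch of what encoding a sample now yields; the expected prefix ids are the multilingual constants from the test file below, and the exact values depend on the checkpoint's special tokens:

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish", task="transcribe")

# prefix_tokens is recomputed from the current language/task/predict_timestamps attributes
print(tokenizer.prefix_tokens)  # expected: [50258, 50262, 50359, 50363]

# build_inputs_with_special_tokens wraps the text ids as prefix_tokens + ids + [eos_token_id]
ids = tokenizer("hola güey").input_ids
assert ids[: len(tokenizer.prefix_tokens)] == tokenizer.prefix_tokens
assert ids[-1] == tokenizer.eos_token_id
```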
@@ -20,14 +20,20 @@ from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
-EN_CODE = 50258
-ES_CODE = 50256
+ES_CODE = 50262
+EN_CODE = 50259
END_OF_TRANSCRIPT = 50257
START_OF_TRANSCRIPT = 50258
TRANSLATE = 50358
TRANSCRIBE = 50359
NOTIMESTAMPS = 50363
class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = WhisperTokenizer
test_rust_tokenizer = False
test_sentencepiece = False
test_seq2seq = False
def setUp(self):
super().setUp()
@@ -101,13 +107,6 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"
-transcript = (
-"'<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> Nor is Mr. Quilters manner less interesting"
-" than his matter.<|endoftext|>'"
-)
-clean_transcript = " Nor is Mr. Quilters manner less interesting than his matter."
-french_text = "Bonjour! Il me semble que Mrs Quilters n'était pas présente"
@classmethod
def setUpClass(cls):
cls.tokenizer: WhisperTokenizer = WhisperTokenizer.from_pretrained(cls.checkpoint_name)
@@ -115,15 +114,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
def test_tokenizer_equivalence(self):
text = "다람쥐 헌 쳇바퀴에 타고파"
-multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="ko")
-gpt2_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
-gpt2_tokens = gpt2_tokenizer.encode(text)
-multilingual_tokens = multilingual_tokenizer.encode(text)
+multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="korean")
+monolingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
+monolingual_tokens = monolingual_tokenizer.encode(text, add_special_tokens=False)
+multilingual_tokens = multilingual_tokenizer.encode(text, add_special_tokens=False)
-assert gpt2_tokenizer.decode(gpt2_tokens) == text
+assert monolingual_tokenizer.decode(monolingual_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text
-assert len(gpt2_tokens) > len(multilingual_tokens)
+assert len(monolingual_tokens) > len(multilingual_tokens)
# fmt: off
EXPECTED_ENG = [
@@ -138,35 +137,42 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
]
# fmt: on
-self.assertListEqual(gpt2_tokens, EXPECTED_ENG)
+self.assertListEqual(monolingual_tokens, EXPECTED_ENG)
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
def test_tokenizer_special(self):
-multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
-text = "<|startoftranscript|>Hey! How are you feeling? J'ai l'impression que 郷さん est prêt<|endoftext|>"
+multilingual_tokenizer = WhisperTokenizer.from_pretrained(
+"openai/whisper-tiny", language="english", task="transcribe"
+)
+text = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
multilingual_tokens = multilingual_tokenizer.encode(text)
# fmt: off
# format: <|startoftranscript|> <|lang-id|> <|task|> <|notimestamps|> ... transcription ids ... <|endoftext|>
EXPECTED_MULTI = [
-50257, 10814, 0, 1374, 389, 345, 4203, 30, 449, 6,
-1872, 300, 6, 11011, 2234, 8358, 16268, 225, 115, 43357,
-22174, 1556, 778, 25792, 83, 50256
+START_OF_TRANSCRIPT, EN_CODE, TRANSCRIBE, NOTIMESTAMPS, 7057, 0, 1012, 366, 291,
+2633, 30, 508, 6, 1301, 287, 6, 36107, 631, 220, 11178,
+115, 15567, 871, 44393, END_OF_TRANSCRIPT
]
EXPECTED_SPECIAL_TEXT = (
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>Hey! How are you feeling? "
"J'ai l'impression que 郷さん est prêt<|endoftext|>"
)
# fmt: on
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
-self.assertEqual(text, multilingual_tokenizer.decode(multilingual_tokens))
+special_transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=False)
+self.assertEqual(special_transcript, EXPECTED_SPECIAL_TEXT)
transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=True)
-EXPECTED_JAP = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
-self.assertEqual(transcript, EXPECTED_JAP)
+self.assertEqual(transcript, text)
def test_vocab_size(self):
self.assertEqual(self.tokenizer.vocab_size, 50257)
# Copied from transformers.tests.speech_to_test.test_tokenization_speech_to_text.py
def test_tokenizer_decode_ignores_language_codes(self):
self.assertIn(ES_CODE, self.tokenizer.all_special_ids)
generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2]
@@ -176,15 +182,48 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
self.assertNotIn(self.tokenizer.eos_token, result)
def test_batch_encoding(self):
-multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
-batch = ["<|en|><|notimestamps|>", "<|en|><|notimestamps|>I am sure that"]
+multilingual_tokenizer = WhisperTokenizer.from_pretrained(
+"openai/whisper-tiny", language="spanish", task="translate"
+)
+batch = ["El gato ", "El gato se sentó"]
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
# fmt: off
EXPECTED_MULTI = [
-[50258, 50362, 50256, 50256, 50256, 50256],
-[50258, 50362, 40, 716, 1654, 326]
+[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 220,
+END_OF_TRANSCRIPT, END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
+[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 369,
+2279, 812, END_OF_TRANSCRIPT]
]
# fmt: on
self.assertListEqual(batch_output, EXPECTED_MULTI)
def test_set_prefix_tokens(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
"openai/whisper-tiny", language="spanish", task="translate"
)
# change the language prefix token from Spanish to English
multilingual_tokenizer.set_prefix_tokens(language="english")
batch = ["the cat", "the cat sat"]
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
# fmt: off
EXPECTED_MULTI = [
[START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857,
END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
[START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857,
3227, END_OF_TRANSCRIPT]
]
# fmt: on
self.assertListEqual(batch_output, EXPECTED_MULTI)
def test_batch_encoding_decoding(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
batch = ["hola güey", "que onda"]
batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
self.assertListEqual(batch, transcription)