Unverified Commit 0a0a279e authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

🚨🚨[Whisper Tok] Update integration test (#29368)

* [Whisper Tok] Update integration test

* make style
parent e7b98370
...@@ -16,7 +16,7 @@ import unittest ...@@ -16,7 +16,7 @@ import unittest
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import require_jinja, slow from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin from ...test_tokenization_common import TokenizerTesterMixin
...@@ -67,26 +67,26 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -67,26 +67,26 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname) tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname)
tokens = tokenizer.tokenize("This is a test") tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["This", "Ġis", "Ġa", ", "test"]) self.assertListEqual(tokens, ["This", "Ġis", "Ġa", "Ġtest"])
self.assertListEqual( self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), tokenizer.convert_tokens_to_ids(tokens),
[5723, 307, 257, 220, 31636], [5723, 307, 257, 1500],
) )
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual( self.assertListEqual(
tokens, tokens,
["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", ", "this", "Ġis", "Ġfals", "é", "."], # fmt: skip ["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġthis", "Ġis", "Ġfals", "é", "."], # fmt: skip
) # fmt: skip )
ids = tokenizer.convert_tokens_to_ids(tokens) ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 220, 11176, 307, 16720, 526, 13]) self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 341, 307, 16720, 526, 13])
back_tokens = tokenizer.convert_ids_to_tokens(ids) back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual( self.assertListEqual(
back_tokens, back_tokens,
["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", ", "this", "Ġis", "Ġfals", "é", "."], # fmt: skip ["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġthis", "Ġis", "Ġfals", "é", "."], # fmt: skip
) # fmt: skip )
def test_tokenizer_slow_store_full_signature(self): def test_tokenizer_slow_store_full_signature(self):
pass pass
...@@ -499,25 +499,3 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): ...@@ -499,25 +499,3 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"] output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, []) self.assertEqual(output, [])
@require_jinja
def test_tokenization_for_chat(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# This is in English, but it's just here to make sure the chat control tokens are being added properly
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [multilingual_tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
expected_tokens = [
[3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257],
[3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257],
[37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257],
]
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment