Unverified Commit deafc243 authored by Jonatan Kłosko's avatar Jonatan Kłosko Committed by GitHub
Browse files

Add WhisperTokenizerFast (#21222)



* Add WhisperTokenizerFast

* Fixup

* Up

* Up

* Improve tests

* Update src/transformers/models/whisper/tokenization_whisper_fast.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Keep stride in whisper pipeline test

* Remove unknown token special case

* Reduce vocabulary size in tests

* Fix vocab size assertion

* Sync copied changes from WhisperTokenizer

* Skip pipeline tests

* Update assertion

* Remove Whisper tokenizer dependency on sentencepiece

* Format

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 8b3db33a
...@@ -406,7 +406,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -406,7 +406,7 @@ Flax), PyTorch, and/or TensorFlow.
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ | | Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
| Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ | | Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ | | WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
| Whisper | ✅ | | ✅ | ✅ | ✅ | | Whisper | ✅ | | ✅ | ✅ | ✅ |
| X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ | | X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
| X-MOD | ❌ | ❌ | ✅ | ❌ | ❌ | | X-MOD | ❌ | ❌ | ✅ | ❌ | ❌ |
| XGLM | ✅ | ✅ | ✅ | ✅ | ✅ | | XGLM | ✅ | ✅ | ✅ | ✅ | ✅ |
......
...@@ -45,6 +45,15 @@ The original code can be found [here](https://github.com/openai/whisper). ...@@ -45,6 +45,15 @@ The original code can be found [here](https://github.com/openai/whisper).
- create_token_type_ids_from_sequences - create_token_type_ids_from_sequences
- save_vocabulary - save_vocabulary
## WhisperTokenizerFast
[[autodoc]] WhisperTokenizerFast
- set_prefix_tokens
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences
- save_vocabulary
## WhisperFeatureExtractor ## WhisperFeatureExtractor
[[autodoc]] WhisperFeatureExtractor [[autodoc]] WhisperFeatureExtractor
......
...@@ -739,6 +739,7 @@ else: ...@@ -739,6 +739,7 @@ else:
_import_structure["models.splinter"].append("SplinterTokenizerFast") _import_structure["models.splinter"].append("SplinterTokenizerFast")
_import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast") _import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast")
_import_structure["models.t5"].append("T5TokenizerFast") _import_structure["models.t5"].append("T5TokenizerFast")
_import_structure["models.whisper"].append("WhisperTokenizerFast")
_import_structure["models.xglm"].append("XGLMTokenizerFast") _import_structure["models.xglm"].append("XGLMTokenizerFast")
_import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast") _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast")
_import_structure["models.xlnet"].append("XLNetTokenizerFast") _import_structure["models.xlnet"].append("XLNetTokenizerFast")
...@@ -4278,6 +4279,7 @@ if TYPE_CHECKING: ...@@ -4278,6 +4279,7 @@ if TYPE_CHECKING:
from .models.splinter import SplinterTokenizerFast from .models.splinter import SplinterTokenizerFast
from .models.squeezebert import SqueezeBertTokenizerFast from .models.squeezebert import SqueezeBertTokenizerFast
from .models.t5 import T5TokenizerFast from .models.t5 import T5TokenizerFast
from .models.whisper import WhisperTokenizerFast
from .models.xglm import XGLMTokenizerFast from .models.xglm import XGLMTokenizerFast
from .models.xlm_roberta import XLMRobertaTokenizerFast from .models.xlm_roberta import XLMRobertaTokenizerFast
from .models.xlnet import XLNetTokenizerFast from .models.xlnet import XLNetTokenizerFast
......
...@@ -286,7 +286,7 @@ class GPT2Converter(Converter): ...@@ -286,7 +286,7 @@ class GPT2Converter(Converter):
bos = self.original_tokenizer.bos_token bos = self.original_tokenizer.bos_token
bos_token_id = self.original_tokenizer.bos_token_id bos_token_id = self.original_tokenizer.bos_token_id
tokenizer.post_processor = processors.TemplateProcessing( tokenizer.post_processor = processors.TemplateProcessing(
single=f"{bos}:0 $A:0", # token_type_id is 2 for Funnel transformer single=f"{bos}:0 $A:0",
pair=f"{bos}:0 $A:0 $B:1", pair=f"{bos}:0 $A:0 $B:1",
special_tokens=[ special_tokens=[
(bos, bos_token_id), (bos, bos_token_id),
...@@ -891,6 +891,42 @@ class T5Converter(SpmConverter): ...@@ -891,6 +891,42 @@ class T5Converter(SpmConverter):
) )
class WhisperConverter(Converter):
    """Build a fast (Rust-backed) tokenizer equivalent to a slow ``WhisperTokenizer``."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        # Byte-level BPE model assembled from the slow tokenizer's
        # vocabulary and merge ranks.
        backend = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )
        backend.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        backend.decoder = decoders.ByteLevel()

        # Mirror the slow tokenizer's prefix tokens and EOS token in the
        # post-processing template so encoded sequences match exactly.
        prefix_ids = slow.prefix_tokens
        prefix_tokens = slow.convert_ids_to_tokens(prefix_ids)
        eos = slow.eos_token
        eos_id = slow.eos_token_id
        prefix_part = " ".join(f"{tok}:0" for tok in prefix_tokens)

        backend.post_processor = processors.TemplateProcessing(
            single=f"{prefix_part} $A:0 {eos}:0",
            pair=f"{prefix_part} $A:0 $B:1 {eos}:1",
            special_tokens=[
                (eos, eos_id),
                *zip(prefix_tokens, prefix_ids),
            ],
        )
        return backend
class BigBirdConverter(SpmConverter): class BigBirdConverter(SpmConverter):
def post_processor(self): def post_processor(self):
return processors.TemplateProcessing( return processors.TemplateProcessing(
...@@ -1127,6 +1163,7 @@ SLOW_TO_FAST_CONVERTERS = { ...@@ -1127,6 +1163,7 @@ SLOW_TO_FAST_CONVERTERS = {
"RoFormerTokenizer": RoFormerConverter, "RoFormerTokenizer": RoFormerConverter,
"SqueezeBertTokenizer": BertConverter, "SqueezeBertTokenizer": BertConverter,
"T5Tokenizer": T5Converter, "T5Tokenizer": T5Converter,
"WhisperTokenizer": WhisperConverter,
"XLMRobertaTokenizer": XLMRobertaConverter, "XLMRobertaTokenizer": XLMRobertaConverter,
"XLNetTokenizer": XLNetConverter, "XLNetTokenizer": XLNetConverter,
"SplinterTokenizer": SplinterConverter, "SplinterTokenizer": SplinterConverter,
......
...@@ -302,7 +302,7 @@ else: ...@@ -302,7 +302,7 @@ else:
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)), ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)), ("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)), ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
("whisper", ("WhisperTokenizer" if is_sentencepiece_available() else None, None)), ("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)),
("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
( (
"xglm", "xglm",
......
...@@ -18,6 +18,7 @@ from ...utils import ( ...@@ -18,6 +18,7 @@ from ...utils import (
_LazyModule, _LazyModule,
is_flax_available, is_flax_available,
is_tf_available, is_tf_available,
is_tokenizers_available,
is_torch_available, is_torch_available,
) )
...@@ -29,6 +30,13 @@ _import_structure = { ...@@ -29,6 +30,13 @@ _import_structure = {
"tokenization_whisper": ["WhisperTokenizer"], "tokenization_whisper": ["WhisperTokenizer"],
} }
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_whisper_fast"] = ["WhisperTokenizerFast"]
try: try:
if not is_torch_available(): if not is_torch_available():
...@@ -75,6 +83,14 @@ if TYPE_CHECKING: ...@@ -75,6 +83,14 @@ if TYPE_CHECKING:
from .processing_whisper import WhisperProcessor from .processing_whisper import WhisperProcessor
from .tokenization_whisper import WhisperTokenizer from .tokenization_whisper import WhisperTokenizer
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_whisper_fast import WhisperTokenizerFast
try: try:
if not is_torch_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
......
This diff is collapsed.
...@@ -366,6 +366,13 @@ class T5TokenizerFast(metaclass=DummyObject): ...@@ -366,6 +366,13 @@ class T5TokenizerFast(metaclass=DummyObject):
requires_backends(self, ["tokenizers"]) requires_backends(self, ["tokenizers"])
# Auto-generated placeholder: importing this name always succeeds, but any
# instantiation raises via `requires_backends` when the `tokenizers`
# backend is not installed.
class WhisperTokenizerFast(metaclass=DummyObject):
    _backends = ["tokenizers"]
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tokenizers"])
class XGLMTokenizerFast(metaclass=DummyObject): class XGLMTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"] _backends = ["tokenizers"]
......
...@@ -79,7 +79,7 @@ class TFWhisperModelTester: ...@@ -79,7 +79,7 @@ class TFWhisperModelTester:
seq_length=60, seq_length=60,
is_training=True, is_training=True,
use_labels=False, use_labels=False,
vocab_size=99, vocab_size=200,
hidden_size=16, hidden_size=16,
num_hidden_layers=2, num_hidden_layers=2,
num_attention_heads=4, num_attention_heads=4,
......
...@@ -96,7 +96,7 @@ class WhisperModelTester: ...@@ -96,7 +96,7 @@ class WhisperModelTester:
seq_length=60, seq_length=60,
is_training=True, is_training=True,
use_labels=False, use_labels=False,
vocab_size=99, vocab_size=200,
hidden_size=16, hidden_size=16,
num_hidden_layers=2, num_hidden_layers=2,
num_attention_heads=4, num_attention_heads=4,
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
from transformers.models.whisper import WhisperTokenizer from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.testing_utils import slow from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin from ...test_tokenization_common import TokenizerTesterMixin
...@@ -31,7 +31,8 @@ NOTIMESTAMPS = 50363 ...@@ -31,7 +31,8 @@ NOTIMESTAMPS = 50363
class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = WhisperTokenizer tokenizer_class = WhisperTokenizer
test_rust_tokenizer = False rust_tokenizer_class = WhisperTokenizerFast
test_rust_tokenizer = True
test_sentencepiece = False test_sentencepiece = False
test_seq2seq = False test_seq2seq = False
...@@ -93,6 +94,17 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -93,6 +94,17 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
def test_tokenizer_slow_store_full_signature(self): def test_tokenizer_slow_store_full_signature(self):
pass pass
def test_tokenizer_fast_store_full_signature(self):
pass
def test_special_tokens_initialization(self):
# Whisper relies on specific additional special tokens, so we skip this
# general test. In particular, this test loads fast tokenizer from slow
# tokenizer, and the conversion uses prefix_tokens, where we reference
# additional special tokens by specific indices, hence overriding the
# list with less tokens leads to out of index error
pass
@slow @slow
def test_tokenizer_integration(self): def test_tokenizer_integration(self):
# fmt: off # fmt: off
......
...@@ -123,7 +123,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -123,7 +123,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
outputs = speech_recognizer(audio, return_timestamps=True) outputs = speech_recognizer(audio, return_timestamps=True)
self.assertIsInstance(outputs["chunks"], list) self.assertIsInstance(outputs["chunks"], list)
nb_chunks = len(outputs["chunks"]) nb_chunks = len(outputs["chunks"])
self.assertGreaterThan(nb_chunks, 0) self.assertGreater(nb_chunks, 0)
self.assertEqual( self.assertEqual(
outputs, outputs,
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment