Unverified Commit c4ecd234 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix AutoTokenizer when no fast tokenizer is available (#13336)

* Fix AutoTokenizer when a tokenizer has no fast version

* Add test
parent ffecfea9
...@@ -229,12 +229,12 @@ def tokenizer_class_from_name(class_name: str): ...@@ -229,12 +229,12 @@ def tokenizer_class_from_name(class_name: str):
for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items(): for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
if class_name in tokenizers: if class_name in tokenizers:
break module_name = model_type_to_module_name(module_name)
module_name = model_type_to_module_name(module_name) module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
module = importlib.import_module(f".{module_name}", "transformers.models") return None
return getattr(module, class_name)
def get_tokenizer_config( def get_tokenizer_config(
......
...@@ -22,6 +22,7 @@ from transformers import ( ...@@ -22,6 +22,7 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
BertTokenizer, BertTokenizer,
BertTokenizerFast, BertTokenizerFast,
CTRLTokenizer,
GPT2Tokenizer, GPT2Tokenizer,
GPT2TokenizerFast, GPT2TokenizerFast,
PreTrainedTokenizerFast, PreTrainedTokenizerFast,
...@@ -162,6 +163,11 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -162,6 +163,11 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertIsInstance(tokenizer2, tokenizer.__class__) self.assertIsInstance(tokenizer2, tokenizer.__class__)
self.assertEqual(tokenizer2.vocab_size, 12) self.assertEqual(tokenizer2.vocab_size, 12)
def test_auto_tokenizer_fast_no_slow(self):
tokenizer = AutoTokenizer.from_pretrained("ctrl")
# There is no fast CTRL so this always gives us a slow tokenizer.
self.assertIsInstance(tokenizer, CTRLTokenizer)
def test_get_tokenizer_config(self): def test_get_tokenizer_config(self):
# Check we can load the tokenizer config of an online model. # Check we can load the tokenizer config of an online model.
config = get_tokenizer_config("bert-base-cased") config = get_tokenizer_config("bert-base-cased")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment