Unverified Commit 40d60e15 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

fix `tokenizer_class_from_name` for models with `-` in the name (#13251)



* fix tokenizer_class_from_name

* Update src/transformers/models/auto/tokenization_auto.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* add test
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent 83bfdbdd
...@@ -229,7 +229,10 @@ def tokenizer_class_from_name(class_name: str): ...@@ -229,7 +229,10 @@ def tokenizer_class_from_name(class_name: str):
if class_name in tokenizers: if class_name in tokenizers:
break break
module = importlib.import_module(f".{module_name}", "transformers.models") if module_name == "openai-gpt":
module_name = "openai"
module = importlib.import_module(f".{module_name.replace('-', '_')}", "transformers.models")
return getattr(module, class_name) return getattr(module, class_name)
......
...@@ -29,7 +29,11 @@ from transformers import ( ...@@ -29,7 +29,11 @@ from transformers import (
RobertaTokenizerFast, RobertaTokenizerFast,
) )
from transformers.models.auto.configuration_auto import AutoConfig from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, get_tokenizer_config from transformers.models.auto.tokenization_auto import (
TOKENIZER_MAPPING,
get_tokenizer_config,
tokenizer_class_from_name,
)
from transformers.models.roberta.configuration_roberta import RobertaConfig from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.testing_utils import ( from transformers.testing_utils import (
DUMMY_DIFF_TOKENIZER_IDENTIFIER, DUMMY_DIFF_TOKENIZER_IDENTIFIER,
...@@ -105,6 +109,24 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -105,6 +109,24 @@ class AutoTokenizerTest(unittest.TestCase):
with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"): with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"):
self.assertFalse(issubclass(child_config, parent_config)) self.assertFalse(issubclass(child_config, parent_config))
def test_model_name_edge_cases_in_mappings(self):
# tests: https://github.com/huggingface/transformers/pull/13251
# 1. models with `-`, e.g. xlm-roberta -> xlm_roberta
# 2. models that don't remap 1-1 from model-name to model file, e.g., openai-gpt -> openai
tokenizers = TOKENIZER_MAPPING.values()
tokenizer_names = []
for slow_tok, fast_tok in tokenizers:
if slow_tok is not None:
tokenizer_names.append(slow_tok.__name__)
if fast_tok is not None:
tokenizer_names.append(fast_tok.__name__)
for tokenizer_name in tokenizer_names:
# must find the right class
tokenizer_class_from_name(tokenizer_name)
@require_tokenizers @require_tokenizers
def test_from_pretrained_use_fast_toggle(self): def test_from_pretrained_use_fast_toggle(self):
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer) self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment