"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "dc8e0019b7feacd546236dc3361efd05f28b9137"
Unverified Commit ed71c21d authored by Julien Chaumond's avatar Julien Chaumond Committed by GitHub
Browse files

[from_pretrained] Allow tokenizer_type ≠ model_type (#6995)

parent 03e363f9
@@ -190,6 +190,7 @@ class PretrainedConfig(object):
self.num_labels = kwargs.pop("num_labels", 2)
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
self.prefix = kwargs.pop("prefix", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
......
@@ -15,6 +15,7 @@ from .file_utils import _tf_available, _torch_available, _torch_tpu_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
......
@@ -222,6 +222,17 @@ class AutoTokenizer:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
use_fast = kwargs.pop("use_fast", False)
if config.tokenizer_class is not None:
if use_fast and not config.tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
else:
tokenizer_class_candidate = config.tokenizer_class
tokenizer_class = globals().get(tokenizer_class_candidate)
if tokenizer_class is None:
raise ValueError("Tokenizer class {} does not exist or is not currently imported.")
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
    if isinstance(config, config_class):
        if tokenizer_class_fast and use_fast:
......
@@ -27,7 +27,13 @@ from transformers import (
    RobertaTokenizer,
    RobertaTokenizerFast,
)
from transformers.configuration_auto import AutoConfig
from transformers.configuration_roberta import RobertaConfig
from transformers.testing_utils import (
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
DUMMY_UNKWOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
)
from transformers.tokenization_auto import TOKENIZER_MAPPING
@@ -56,6 +62,14 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertIsInstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
self.assertEqual(tokenizer.vocab_size, 20)
def test_tokenizer_from_tokenizer_class(self):
    """AutoTokenizer honors config.tokenizer_class when it differs from model_type."""
    cfg = AutoConfig.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER)
    self.assertIsInstance(cfg, RobertaConfig)
    # The checkpoint is a Roberta model whose config points at a BERT tokenizer,
    # i.e. tokenizer_type != model_type.
    tok = AutoTokenizer.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER, config=cfg)
    self.assertIsInstance(tok, (BertTokenizer, BertTokenizerFast))
    self.assertEqual(tok.vocab_size, 12)
def test_tokenizer_identifier_with_correct_config(self):
    for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
        tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment