"docs/vscode:/vscode.git/clone" did not exist on "039d8d65fc19ac74a8c7917233eb2828c46c0fa7"
Commit 03046285 authored by Julien Chaumond

Map configs to models and tokenizers

parent 1fc855e4
@@ -202,7 +202,7 @@ class AutoConfig:
                 return config_class.from_dict(config_dict, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should have a `model_type` key in its config.json, or contain one of {}".format(
-                pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys())
-            )
+            "Unrecognized model in {}. "
+            "Should have a `model_type` key in its config.json, or contain one of the following strings "
+            "in its name: {}".format(pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys()))
         )
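Note: the reworded error message reflects AutoConfig's two-step resolution — an explicit `model_type` key in config.json wins, otherwise the model name is matched by substring against the keys of CONFIG_MAPPING. A minimal sketch of that lookup; the mapping contents and the resolve_config_class helper below are illustrative stand-ins, not the library's actual code:

from collections import OrderedDict

# Toy stand-in for CONFIG_MAPPING: string patterns mapped to config class names.
CONFIG_MAPPING_SKETCH = OrderedDict(
    [("xlm-roberta", "XLMRobertaConfig"), ("roberta", "RobertaConfig"), ("bert", "BertConfig")]
)

def resolve_config_class(config_dict, name):
    # An explicit `model_type` key in config.json takes precedence.
    if "model_type" in config_dict:
        return CONFIG_MAPPING_SKETCH[config_dict["model_type"]]
    # Otherwise fall back to substring matching on the model name.
    for pattern, config_class in CONFIG_MAPPING_SKETCH.items():
        if pattern in name:
            return config_class
    raise ValueError("Unrecognized model in {}".format(name))

assert resolve_config_class({}, "xlm-roberta-base") == "XLMRobertaConfig"
assert resolve_config_class({"model_type": "bert"}, "my-finetuned-model") == "BertConfig"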
@@ -273,7 +273,7 @@ class PretrainedConfig(object):
         return self.__dict__ == other.__dict__

     def __repr__(self):
-        return str(self.to_json_string())
+        return "{} {}".format(self.__class__.__name__, self.to_json_string())

     def to_dict(self):
         """Serializes this instance to a Python dictionary."""
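Note: the __repr__ change only prefixes the class name to the serialized config, which makes interactive inspection clearer. A small usage sketch, output abridged to one line:

from transformers import BertConfig

config = BertConfig()
print(repr(config))
# Before this change: the JSON string alone, e.g. {"attention_probs_dropout_prob": 0.1, ...}
# After this change:  the class name is prefixed, e.g. BertConfig {"attention_probs_dropout_prob": 0.1, ...}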
(Two additional file diffs in this commit are collapsed and not shown here.)
@@ -16,6 +16,8 @@
 import logging
+from collections import OrderedDict
+from typing import Dict, Type

 from .configuration_auto import (
     AlbertConfig,
@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
 from .tokenization_roberta import RobertaTokenizer
 from .tokenization_t5 import T5Tokenizer
 from .tokenization_transfo_xl import TransfoXLTokenizer
+from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
 from .tokenization_xlnet import XLNetTokenizer
@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer

 logger = logging.getLogger(__name__)

+TOKENIZER_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedTokenizer]] = OrderedDict(
+    [
+        (T5Config, T5Tokenizer),
+        (DistilBertConfig, DistilBertTokenizer),
+        (AlbertConfig, AlbertTokenizer),
+        (CamembertConfig, CamembertTokenizer),
+        (XLMRobertaConfig, XLMRobertaTokenizer),
+        (RobertaConfig, RobertaTokenizer),
+        (BertConfig, BertTokenizer),
+        (OpenAIGPTConfig, OpenAIGPTTokenizer),
+        (GPT2Config, GPT2Tokenizer),
+        (TransfoXLConfig, TransfoXLTokenizer),
+        (XLNetConfig, XLNetTokenizer),
+        (XLMConfig, XLMTokenizer),
+        (CTRLConfig, CTRLTokenizer),
+    ]
+)
+
+
 class AutoTokenizer(object):
     r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
         that will be instantiated as one of the tokenizer classes of the library
@@ -154,36 +176,13 @@ class AutoTokenizer(object):
         if "bert-base-japanese" in pretrained_model_name_or_path:
             return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

-        if isinstance(config, T5Config):
-            return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, DistilBertConfig):
-            return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, AlbertConfig):
-            return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, CamembertConfig):
-            return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, XLMRobertaConfig):
-            return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, RobertaConfig):
-            return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, BertConfig):
-            return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, OpenAIGPTConfig):
-            return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, GPT2Config):
-            return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, TransfoXLConfig):
-            return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, XLNetConfig):
-            return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, XLMConfig):
-            return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, CTRLConfig):
-            return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        for config_class, tokenizer_class in TOKENIZER_MAPPING.items():
+            if isinstance(config, config_class):
+                return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-            "'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(
-                pretrained_model_name_or_path
+            "Unrecognized configuration class {} to build an AutoTokenizer.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, ", ".join(c.__name__ for c in TOKENIZER_MAPPING.keys())
             )
         )
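Note: with the mapping in place, from_pretrained simply walks TOKENIZER_MAPPING and returns the first entry whose config class matches via isinstance. Because isinstance also matches subclasses (in this version of the library CamembertConfig and XLMRobertaConfig derive from RobertaConfig), the more specific config classes must be listed before the general ones. A self-contained sketch of the pattern; the *Like classes and string values are stand-ins, not the real transformers classes:

from collections import OrderedDict

class RobertaLikeConfig: pass
class XLMRobertaLikeConfig(RobertaLikeConfig): pass   # subclass of the Roberta-like config

# More specific entry listed first, otherwise isinstance() would match the base class.
MAPPING = OrderedDict([
    (XLMRobertaLikeConfig, "XLMRobertaTokenizer"),
    (RobertaLikeConfig, "RobertaTokenizer"),
])

def pick_tokenizer(config):
    for config_class, tokenizer_name in MAPPING.items():
        if isinstance(config, config_class):
            return tokenizer_name
    raise ValueError("Unrecognized configuration class {}".format(config.__class__))

assert pick_tokenizer(XLMRobertaLikeConfig()) == "XLMRobertaTokenizer"
assert pick_tokenizer(RobertaLikeConfig()) == "RobertaTokenizer"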
@@ -51,4 +51,4 @@ class AutoConfigTest(unittest.TestCase):
         # no key string should be included in a later key string (typical failure case)
         keys = list(CONFIG_MAPPING.keys())
         for i, key in enumerate(keys):
-            self.assertFalse(any(key in later_key for later_key in keys[i+1:]))
+            self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
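Note: this assertion guards the name-based fallback in AutoConfig — the lookup returns the first CONFIG_MAPPING key that is a substring of the model name, so a general key ordered before a more specific one would shadow it. A toy illustration; the key lists and helper below are illustrative only:

keys_bad  = ["roberta", "xlm-roberta"]   # "xlm-roberta-base" would wrongly match "roberta" first
keys_good = ["xlm-roberta", "roberta"]   # specific key checked first

def first_match(keys, name):
    for key in keys:
        if key in name:
            return key
    return None

assert first_match(keys_bad, "xlm-roberta-base") == "roberta"       # wrong config selected
assert first_match(keys_good, "xlm-roberta-base") == "xlm-roberta"  # correct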