Commit 03046285 authored by Julien Chaumond

Map configs to models and tokenizers

parent 1fc855e4
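
At a glance, this commit replaces hand-written if/elif dispatch chains with ordered mappings from configuration classes to model and tokenizer classes. A minimal sketch of the pattern with toy classes (not the library's real ones), assuming only the standard library:

from collections import OrderedDict

class BaseConfig:
    pass

class DerivedConfig(BaseConfig):  # stands in for e.g. RobertaConfig subclassing BertConfig
    pass

# Most specific entries must come first: dispatch uses isinstance() and
# returns on the first match, so a base class would shadow its subclasses.
MAPPING = OrderedDict([(DerivedConfig, "derived"), (BaseConfig, "base")])

def resolve(config):
    for config_class, target in MAPPING.items():
        if isinstance(config, config_class):
            return target
    raise ValueError("Unrecognized configuration class {}".format(config.__class__))

assert resolve(DerivedConfig()) == "derived"  # would wrongly be "base" with flipped order
assert resolve(BaseConfig()) == "base"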
......@@ -202,7 +202,7 @@ class AutoConfig:
                 return config_class.from_dict(config_dict, **kwargs)

         raise ValueError(
-            "Unrecognized model identifier in {}. Should have a `model_type` key in its config.json, or contain one of {}".format(
-                pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys())
-            )
+            "Unrecognized model in {}. "
+            "Should have a `model_type` key in its config.json, or contain one of the following strings "
+            "in its name: {}".format(pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys()))
         )
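
For context, a hedged sketch of the name-based fallback this message describes; the surrounding from_pretrained body is assumed, not shown in this hunk. When config.json carries no `model_type` key, AutoConfig checks whether any CONFIG_MAPPING key occurs as a substring of the identifier:

for pattern, config_class in CONFIG_MAPPING.items():
    # First key found in the identifier wins, hence the key-ordering test at the end of this commit.
    if pattern in pretrained_model_name_or_path:
        return config_class.from_dict(config_dict, **kwargs)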
......@@ -47,8 +47,8 @@ class PretrainedConfig(object):
         ``output_hidden_states``: string, default `False`. Should the model return all hidden-states.
         ``torchscript``: string, default `False`. Is the model used with Torchscript.
     """

     pretrained_config_archive_map = {}  # type: Dict[str, str]
     model_type = ""  # type: str

     def __init__(self, **kwargs):
         # Attributes with defaults
......@@ -273,7 +273,7 @@ class PretrainedConfig(object):
         return self.__dict__ == other.__dict__

     def __repr__(self):
-        return str(self.to_json_string())
+        return "{} {}".format(self.__class__.__name__, self.to_json_string())

     def to_dict(self):
         """Serializes this instance to a Python dictionary."""
......@@ -16,6 +16,8 @@

 import logging
+from collections import OrderedDict
+from typing import Dict, Type

 from .configuration_auto import (
     AlbertConfig,
......@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
 from .tokenization_roberta import RobertaTokenizer
 from .tokenization_t5 import T5Tokenizer
 from .tokenization_transfo_xl import TransfoXLTokenizer
+from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
 from .tokenization_xlnet import XLNetTokenizer
......@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer

 logger = logging.getLogger(__name__)

+TOKENIZER_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedTokenizer]] = OrderedDict(
+    [
+        (T5Config, T5Tokenizer),
+        (DistilBertConfig, DistilBertTokenizer),
+        (AlbertConfig, AlbertTokenizer),
+        (CamembertConfig, CamembertTokenizer),
+        (XLMRobertaConfig, XLMRobertaTokenizer),
+        (RobertaConfig, RobertaTokenizer),
+        (BertConfig, BertTokenizer),
+        (OpenAIGPTConfig, OpenAIGPTTokenizer),
+        (GPT2Config, GPT2Tokenizer),
+        (TransfoXLConfig, TransfoXLTokenizer),
+        (XLNetConfig, XLNetTokenizer),
+        (XLMConfig, XLMTokenizer),
+        (CTRLConfig, CTRLTokenizer),
+    ]
+)
+
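Note the ordering above, for the same reason as in the toy sketch after the commit message: at the time, CamembertConfig and XLMRobertaConfig subclassed RobertaConfig, which itself subclassed BertConfig, and from_pretrained (below) returns on the first isinstance() match. A hedged sanity check one could run against the mapping:

classes = list(TOKENIZER_MAPPING.keys())
for i, cls in enumerate(classes):
    # No later entry may be a subclass of an earlier one, or it would be unreachable.
    assert not any(issubclass(later, cls) for later in classes[i + 1 :]), cls
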
class AutoTokenizer(object):
r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
that will be instantiated as one of the tokenizer classes of the library
......@@ -154,36 +176,13 @@ class AutoTokenizer(object):
if "bert-base-japanese" in pretrained_model_name_or_path:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
if isinstance(config, T5Config):
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, DistilBertConfig):
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, AlbertConfig):
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, CamembertConfig):
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, RobertaConfig):
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, BertConfig):
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, GPT2Config):
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, TransfoXLConfig):
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLNetConfig):
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLMConfig):
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, CTRLConfig):
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
for config_class, tokenizer_class in TOKENIZER_MAPPING.items():
if isinstance(config, config_class):
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(
pretrained_model_name_or_path
"Unrecognized configuration class {} to build an AutoTokenizer.\n"
"Model type should be one of {}.".format(
config.__class__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
)
)
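
A minimal usage sketch of the resulting dispatch (checkpoint name as published at the time):

from transformers import AutoTokenizer

# Loads the checkpoint's config, walks TOKENIZER_MAPPING until a config class
# matches (here BertConfig), and instantiates the paired tokenizer class.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
assert type(tokenizer).__name__ == "BertTokenizer"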
......@@ -51,4 +51,4 @@ class AutoConfigTest(unittest.TestCase):
         # no key string should be included in a later key string (typical failure case)
         keys = list(CONFIG_MAPPING.keys())
         for i, key in enumerate(keys):
-            self.assertFalse(any(key in later_key for later_key in keys[i+1:]))
+            self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
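
This test pins down the substring-matching fallback sketched earlier: keys are tried in order, so no key may be a substring of a later key. Concretely, "xlm-roberta" must precede "roberta" in CONFIG_MAPPING, or an identifier such as "xlm-roberta-base" would match the generic key first:

assert "roberta" in "xlm-roberta-base"  # why the more specific key must come first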