Commit 03046285 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Map configs to models and tokenizers

parent 1fc855e4
...@@ -202,7 +202,7 @@ class AutoConfig: ...@@ -202,7 +202,7 @@ class AutoConfig:
return config_class.from_dict(config_dict, **kwargs) return config_class.from_dict(config_dict, **kwargs)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should have a `model_type` key in its config.json, or contain one of {}".format( "Unrecognized model in {}. "
pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys()) "Should have a `model_type` key in its config.json, or contain one of the following strings "
) "in its name: {}".format(pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys()))
) )
...@@ -47,8 +47,8 @@ class PretrainedConfig(object): ...@@ -47,8 +47,8 @@ class PretrainedConfig(object):
``output_hidden_states``: string, default `False`. Should the model returns all hidden-states. ``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
``torchscript``: string, default `False`. Is the model used with Torchscript. ``torchscript``: string, default `False`. Is the model used with Torchscript.
""" """
pretrained_config_archive_map = {} # type: Dict[str, str] pretrained_config_archive_map = {} # type: Dict[str, str]
model_type = "" # type: str model_type = "" # type: str
def __init__(self, **kwargs): def __init__(self, **kwargs):
# Attributes with defaults # Attributes with defaults
...@@ -273,7 +273,7 @@ class PretrainedConfig(object): ...@@ -273,7 +273,7 @@ class PretrainedConfig(object):
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
def __repr__(self): def __repr__(self):
return str(self.to_json_string()) return "{} {}".format(self.__class__.__name__, self.to_json_string())
def to_dict(self): def to_dict(self):
"""Serializes this instance to a Python dictionary.""" """Serializes this instance to a Python dictionary."""
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Type from typing import Dict, Type
from .configuration_auto import ( from .configuration_auto import (
AlbertConfig, AlbertConfig,
...@@ -126,14 +126,14 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict( ...@@ -126,14 +126,14 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
for key, value, in pretrained_map.items() for key, value, in pretrained_map.items()
) )
MODEL_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict( MODEL_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
[ [
(T5Config, T5Model), (T5Config, T5Model),
(DistilBertConfig, DistilBertModel), (DistilBertConfig, DistilBertModel),
(AlbertConfig, AlbertModel), (AlbertConfig, AlbertModel),
(CamembertConfig, CamembertModel), (CamembertConfig, CamembertModel),
(RobertaConfig, XLMRobertaModel), (RobertaConfig, RobertaModel),
(XLMRobertaConfig, RobertaModel), (XLMRobertaConfig, XLMRobertaModel),
(BertConfig, BertModel), (BertConfig, BertModel),
(OpenAIGPTConfig, OpenAIGPTModel), (OpenAIGPTConfig, OpenAIGPTModel),
(GPT2Config, GPT2Model), (GPT2Config, GPT2Model),
...@@ -144,12 +144,53 @@ MODEL_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = Orde ...@@ -144,12 +144,53 @@ MODEL_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = Orde
] ]
) )
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict( MODEL_WITH_LM_HEAD_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
[
(T5Config, T5WithLMHeadModel),
(DistilBertConfig, DistilBertForMaskedLM),
(AlbertConfig, AlbertForMaskedLM),
(CamembertConfig, CamembertForMaskedLM),
(RobertaConfig, RobertaForMaskedLM),
(XLMRobertaConfig, XLMRobertaForMaskedLM),
(BertConfig, BertForMaskedLM),
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
(GPT2Config, GPT2LMHeadModel),
(TransfoXLConfig, TransfoXLLMHeadModel),
(XLNetConfig, XLNetLMHeadModel),
(XLMConfig, XLMWithLMHeadModel),
(CTRLConfig, CTRLLMHeadModel),
]
)
# Config-class -> sequence-classification model-class lookup table.
# AutoModelForSequenceClassification scans this mapping in insertion order and
# picks the first entry whose config class passes an isinstance() check, so
# insertion order is load-bearing: a config class that subclasses another must
# appear BEFORE its base class (presumably why RobertaConfig precedes
# BertConfig here — confirm against the configuration class hierarchy).
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
    [
        (DistilBertConfig, DistilBertForSequenceClassification),
        (AlbertConfig, AlbertForSequenceClassification),
        (CamembertConfig, CamembertForSequenceClassification),
        (RobertaConfig, RobertaForSequenceClassification),
        (XLMRobertaConfig, XLMRobertaForSequenceClassification),
        (BertConfig, BertForSequenceClassification),
        (XLNetConfig, XLNetForSequenceClassification),
        (XLMConfig, XLMForSequenceClassification),
    ]
)
# Config-class -> question-answering model-class lookup table.
# AutoModelForQuestionAnswering scans this mapping in insertion order with
# isinstance() checks, so order matters when one config class subclasses
# another. Only architectures with a QA head are listed; configs absent here
# fall through to the "Unrecognized configuration class" error.
MODEL_FOR_QUESTION_ANSWERING_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
    [
        (DistilBertConfig, DistilBertForQuestionAnswering),
        (AlbertConfig, AlbertForQuestionAnswering),
        (BertConfig, BertForQuestionAnswering),
        (XLNetConfig, XLNetForQuestionAnswering),
        (XLMConfig, XLMForQuestionAnswering),
    ]
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
[ [
(DistilBertConfig, DistilBertForTokenClassification), (DistilBertConfig, DistilBertForTokenClassification),
(CamembertConfig, CamembertForTokenClassification), (CamembertConfig, CamembertForTokenClassification),
(RobertaConfig, XLMRobertaForTokenClassification), (RobertaConfig, RobertaForTokenClassification),
(XLMRobertaConfig, RobertaForTokenClassification), (XLMRobertaConfig, XLMRobertaForTokenClassification),
(BertConfig, BertForTokenClassification), (BertConfig, BertForTokenClassification),
(XLNetConfig, XLNetForTokenClassification), (XLNetConfig, XLNetForTokenClassification),
] ]
...@@ -218,7 +259,12 @@ class AutoModel(object): ...@@ -218,7 +259,12 @@ class AutoModel(object):
for config_class, model_class in MODEL_MAPPING.items(): for config_class, model_class in MODEL_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
return model_class(config) return model_class(config)
raise ValueError("Unrecognized configuration class {}".format(config)) raise ValueError(
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"Model type should be one of {}.".format(
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
)
)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
...@@ -309,10 +355,9 @@ class AutoModel(object): ...@@ -309,10 +355,9 @@ class AutoModel(object):
if isinstance(config, config_class): if isinstance(config, config_class):
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " "Model type should be one of {}.".format(
"'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'".format( config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
pretrained_model_name_or_path
) )
) )
...@@ -376,27 +421,15 @@ class AutoModelWithLMHead(object): ...@@ -376,27 +421,15 @@ class AutoModelWithLMHead(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
""" """
if isinstance(config, DistilBertConfig): for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
return DistilBertForMaskedLM(config) if isinstance(config, config_class):
elif isinstance(config, RobertaConfig): return model_class(config)
return RobertaForMaskedLM(config) raise ValueError(
elif isinstance(config, BertConfig): "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
return BertForMaskedLM(config) "Model type should be one of {}.".format(
elif isinstance(config, OpenAIGPTConfig): config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
return OpenAIGPTLMHeadModel(config) )
elif isinstance(config, GPT2Config): )
return GPT2LMHeadModel(config)
elif isinstance(config, TransfoXLConfig):
return TransfoXLLMHeadModel(config)
elif isinstance(config, XLNetConfig):
return XLNetLMHeadModel(config)
elif isinstance(config, XLMConfig):
return XLMWithLMHeadModel(config)
elif isinstance(config, CTRLConfig):
return CTRLLMHeadModel(config)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaForMaskedLM(config)
raise ValueError("Unrecognized configuration class {}".format(config))
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
...@@ -486,57 +519,13 @@ class AutoModelWithLMHead(object): ...@@ -486,57 +519,13 @@ class AutoModelWithLMHead(object):
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, T5Config): for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
return T5WithLMHeadModel.from_pretrained( if isinstance(config, config_class):
pretrained_model_name_or_path, *model_args, config=config, **kwargs return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
)
elif isinstance(config, DistilBertConfig):
return DistilBertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, AlbertConfig):
return AlbertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, CamembertConfig):
return CamembertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, RobertaConfig):
return RobertaForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, GPT2Config):
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, TransfoXLConfig):
return TransfoXLLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return XLNetLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return XLMWithLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, CTRLConfig):
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " "Model type should be one of {}.".format(
"'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format( config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
pretrained_model_name_or_path
) )
) )
...@@ -591,23 +580,17 @@ class AutoModelForSequenceClassification(object): ...@@ -591,23 +580,17 @@ class AutoModelForSequenceClassification(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
""" """
if isinstance(config, AlbertConfig): for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
return AlbertForSequenceClassification(config) if isinstance(config, config_class):
elif isinstance(config, CamembertConfig): return model_class(config)
return CamembertForSequenceClassification(config) raise ValueError(
elif isinstance(config, DistilBertConfig): "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
return DistilBertForSequenceClassification(config) "Model type should be one of {}.".format(
elif isinstance(config, RobertaConfig): config.__class__,
return RobertaForSequenceClassification(config) cls.__name__,
elif isinstance(config, BertConfig): ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
return BertForSequenceClassification(config) )
elif isinstance(config, XLNetConfig): )
return XLNetForSequenceClassification(config)
elif isinstance(config, XLMConfig):
return XLMForSequenceClassification(config)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaForSequenceClassification(config)
raise ValueError("Unrecognized configuration class {}".format(config))
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
...@@ -693,43 +676,15 @@ class AutoModelForSequenceClassification(object): ...@@ -693,43 +676,15 @@ class AutoModelForSequenceClassification(object):
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, DistilBertConfig): for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
return DistilBertForSequenceClassification.from_pretrained( if isinstance(config, config_class):
pretrained_model_name_or_path, *model_args, config=config, **kwargs return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
)
elif isinstance(config, AlbertConfig):
return AlbertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, CamembertConfig):
return CamembertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, RobertaConfig):
return RobertaForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return BertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return XLNetForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return XLMForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format( "Model type should be one of {}.".format(
pretrained_model_name_or_path config.__class__,
cls.__name__,
", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
) )
) )
...@@ -780,17 +735,18 @@ class AutoModelForQuestionAnswering(object): ...@@ -780,17 +735,18 @@ class AutoModelForQuestionAnswering(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
""" """
if isinstance(config, AlbertConfig): for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
return AlbertForQuestionAnswering(config) if isinstance(config, config_class):
elif isinstance(config, DistilBertConfig): return model_class(config)
return DistilBertForQuestionAnswering(config)
elif isinstance(config, BertConfig): raise ValueError(
return BertForQuestionAnswering(config) "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
elif isinstance(config, XLNetConfig): "Model type should be one of {}.".format(
return XLNetForQuestionAnswering(config) config.__class__,
elif isinstance(config, XLMConfig): cls.__name__,
return XLMForQuestionAnswering(config) ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
raise ValueError("Unrecognized configuration class {}".format(config)) )
)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
...@@ -870,30 +826,17 @@ class AutoModelForQuestionAnswering(object): ...@@ -870,30 +826,17 @@ class AutoModelForQuestionAnswering(object):
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, DistilBertConfig): for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
return DistilBertForQuestionAnswering.from_pretrained( if isinstance(config, config_class):
pretrained_model_name_or_path, *model_args, config=config, **kwargs return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
)
elif isinstance(config, AlbertConfig):
return AlbertForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return BertForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return XLNetForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return XLMForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path) "Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
)
) )
...@@ -923,19 +866,18 @@ class AutoModelForTokenClassification: ...@@ -923,19 +866,18 @@ class AutoModelForTokenClassification:
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` model = AutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
""" """
if isinstance(config, CamembertConfig): for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
return CamembertForTokenClassification(config) if isinstance(config, config_class):
elif isinstance(config, DistilBertConfig): return model_class(config)
return DistilBertForTokenClassification(config)
elif isinstance(config, BertConfig): raise ValueError(
return BertForTokenClassification(config) "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
elif isinstance(config, XLNetConfig): "Model type should be one of {}.".format(
return XLNetForTokenClassification(config) config.__class__,
elif isinstance(config, RobertaConfig): cls.__name__,
return RobertaForTokenClassification(config) ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
elif isinstance(config, XLMRobertaConfig): )
return XLMRobertaForTokenClassification(config) )
raise ValueError("Unrecognized configuration class {}".format(config))
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
...@@ -1014,34 +956,15 @@ class AutoModelForTokenClassification: ...@@ -1014,34 +956,15 @@ class AutoModelForTokenClassification:
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, CamembertConfig): for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
return CamembertForTokenClassification.from_pretrained( if isinstance(config, config_class):
pretrained_model_name_or_path, *model_args, config=config, **kwargs return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
)
elif isinstance(config, DistilBertConfig):
return DistilBertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, RobertaConfig):
return RobertaForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return BertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return XLNetForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format( "Model type should be one of {}.".format(
pretrained_model_name_or_path config.__class__,
cls.__name__,
", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
) )
) )
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
import logging import logging
from collections import OrderedDict
from typing import Dict, Type
from .configuration_auto import ( from .configuration_auto import (
AlbertConfig, AlbertConfig,
...@@ -70,6 +72,7 @@ from .modeling_tf_transfo_xl import ( ...@@ -70,6 +72,7 @@ from .modeling_tf_transfo_xl import (
TFTransfoXLLMHeadModel, TFTransfoXLLMHeadModel,
TFTransfoXLModel, TFTransfoXLModel,
) )
from .modeling_tf_utils import TFPreTrainedModel
from .modeling_tf_xlm import ( from .modeling_tf_xlm import (
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP, TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
TFXLMForQuestionAnsweringSimple, TFXLMForQuestionAnsweringSimple,
...@@ -108,6 +111,65 @@ TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict( ...@@ -108,6 +111,65 @@ TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
for key, value, in pretrained_map.items() for key, value, in pretrained_map.items()
) )
# Config-class -> TensorFlow base-model lookup table used by TFAutoModel.
# Scanned in insertion order with isinstance() checks, so a config class that
# subclasses another must come before its base class (presumably why
# RobertaConfig precedes BertConfig — confirm against the config hierarchy).
# NOTE(review): smaller than the PyTorch MODEL_MAPPING (no T5/Camembert/
# XLM-Roberta entries) — presumably those TF ports did not exist yet; verify.
TF_MODEL_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
    [
        (DistilBertConfig, TFDistilBertModel),
        (AlbertConfig, TFAlbertModel),
        (RobertaConfig, TFRobertaModel),
        (BertConfig, TFBertModel),
        (OpenAIGPTConfig, TFOpenAIGPTModel),
        (GPT2Config, TFGPT2Model),
        (TransfoXLConfig, TFTransfoXLModel),
        (XLNetConfig, TFXLNetModel),
        (XLMConfig, TFXLMModel),
        (CTRLConfig, TFCTRLModel),
    ]
)
# Config-class -> TensorFlow LM-head model lookup table used by
# TFAutoModelWithLMHead. Scanned in insertion order with isinstance() checks,
# so subclass configs must precede their base classes. Masked-LM heads
# (BERT-style) and causal LM heads (GPT-style) are mixed in one table; the
# head type is determined solely by the matched config class.
TF_MODEL_WITH_LM_HEAD_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
    [
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (AlbertConfig, TFAlbertForMaskedLM),
        (RobertaConfig, TFRobertaForMaskedLM),
        (BertConfig, TFBertForMaskedLM),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
        (GPT2Config, TFGPT2LMHeadModel),
        (TransfoXLConfig, TFTransfoXLLMHeadModel),
        (XLNetConfig, TFXLNetLMHeadModel),
        (XLMConfig, TFXLMWithLMHeadModel),
        (CTRLConfig, TFCTRLLMHeadModel),
    ]
)
# Config-class -> TensorFlow sequence-classification model lookup table.
# Scanned in insertion order with isinstance() checks, so insertion order is
# significant where config classes subclass one another (RobertaConfig is
# listed before BertConfig). Only architectures with a TF classification
# head appear here.
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
    [
        (DistilBertConfig, TFDistilBertForSequenceClassification),
        (AlbertConfig, TFAlbertForSequenceClassification),
        (RobertaConfig, TFRobertaForSequenceClassification),
        (BertConfig, TFBertForSequenceClassification),
        (XLNetConfig, TFXLNetForSequenceClassification),
        (XLMConfig, TFXLMForSequenceClassification),
    ]
)
# Config-class -> TensorFlow question-answering model lookup table.
# Scanned in insertion order with isinstance() checks.
# NOTE(review): XLNet/XLM map to the *Simple* QA variants here, unlike the
# PyTorch table which uses XLNetForQuestionAnswering/XLMForQuestionAnswering —
# presumably the full QA heads were not ported to TF; confirm.
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
    [
        (DistilBertConfig, TFDistilBertForQuestionAnswering),
        (BertConfig, TFBertForQuestionAnswering),
        (XLNetConfig, TFXLNetForQuestionAnsweringSimple),
        (XLMConfig, TFXLMForQuestionAnsweringSimple),
    ]
)
# Config-class -> TensorFlow token-classification model lookup table.
# Scanned in insertion order with isinstance() checks, so subclass configs
# must precede their base classes (RobertaConfig before BertConfig). Configs
# without a TF token-classification head are intentionally absent and will
# raise the "Unrecognized configuration class" error at dispatch time.
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
    [
        (DistilBertConfig, TFDistilBertForTokenClassification),
        (RobertaConfig, TFRobertaForTokenClassification),
        (BertConfig, TFBertForTokenClassification),
        (XLNetConfig, TFXLNetForTokenClassification),
    ]
)
class TFAutoModel(object): class TFAutoModel(object):
r""" r"""
...@@ -165,25 +227,15 @@ class TFAutoModel(object): ...@@ -165,25 +227,15 @@ class TFAutoModel(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = TFAutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` model = TFAutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
""" """
if isinstance(config, DistilBertConfig): for config_class, model_class in TF_MODEL_MAPPING.items():
return TFDistilBertModel(config) if isinstance(config, config_class):
elif isinstance(config, RobertaConfig): return model_class(config)
return TFRobertaModel(config) raise ValueError(
elif isinstance(config, BertConfig): "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
return TFBertModel(config) "Model type should be one of {}.".format(
elif isinstance(config, OpenAIGPTConfig): config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_MAPPING.keys())
return TFOpenAIGPTModel(config) )
elif isinstance(config, GPT2Config): )
return TFGPT2Model(config)
elif isinstance(config, TransfoXLConfig):
return TFTransfoXLModel(config)
elif isinstance(config, XLNetConfig):
return TFXLNetModel(config)
elif isinstance(config, XLMConfig):
return TFXLMModel(config)
elif isinstance(config, CTRLConfig):
return TFCTRLModel(config)
raise ValueError("Unrecognized configuration class {}".format(config))
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
...@@ -266,39 +318,14 @@ class TFAutoModel(object): ...@@ -266,39 +318,14 @@ class TFAutoModel(object):
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, T5Config): for config_class, model_class in TF_MODEL_MAPPING.items():
return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) if isinstance(config, config_class):
elif isinstance(config, DistilBertConfig): return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
return TFDistilBertModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, AlbertConfig):
return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, RobertaConfig):
return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, BertConfig):
return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, OpenAIGPTConfig):
return TFOpenAIGPTModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, GPT2Config):
return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, TransfoXLConfig):
return TFTransfoXLModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, XLMConfig):
return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, CTRLConfig):
return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " "Model type should be one of {}.".format(
"'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path) config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_MAPPING.keys())
)
) )
...@@ -358,25 +385,15 @@ class TFAutoModelWithLMHead(object): ...@@ -358,25 +385,15 @@ class TFAutoModelWithLMHead(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
""" """
if isinstance(config, DistilBertConfig): for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
return TFDistilBertForMaskedLM(config) if isinstance(config, config_class):
elif isinstance(config, RobertaConfig): return model_class(config)
return TFRobertaForMaskedLM(config) raise ValueError(
elif isinstance(config, BertConfig): "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
return TFBertForMaskedLM(config) "Model type should be one of {}.".format(
elif isinstance(config, OpenAIGPTConfig): config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_WITH_LM_HEAD_MAPPING.keys())
return TFOpenAIGPTLMHeadModel(config) )
elif isinstance(config, GPT2Config): )
return TFGPT2LMHeadModel(config)
elif isinstance(config, TransfoXLConfig):
return TFTransfoXLLMHeadModel(config)
elif isinstance(config, XLNetConfig):
return TFXLNetLMHeadModel(config)
elif isinstance(config, XLMConfig):
return TFXLMWithLMHeadModel(config)
elif isinstance(config, CTRLConfig):
return TFCTRLLMHeadModel(config)
raise ValueError("Unrecognized configuration class {}".format(config))
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
...@@ -464,55 +481,14 @@ class TFAutoModelWithLMHead(object): ...@@ -464,55 +481,14 @@ class TFAutoModelWithLMHead(object):
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, T5Config): for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
return TFT5WithLMHeadModel.from_pretrained( if isinstance(config, config_class):
pretrained_model_name_or_path, *model_args, config=config, **kwargs return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
)
elif isinstance(config, DistilBertConfig):
return TFDistilBertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, AlbertConfig):
return TFAlbertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, RobertaConfig):
return TFRobertaForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return TFBertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, OpenAIGPTConfig):
return TFOpenAIGPTLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, GPT2Config):
return TFGPT2LMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, TransfoXLConfig):
return TFTransfoXLLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return TFXLNetLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return TFXLMWithLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, CTRLConfig):
return TFCTRLLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " "Model type should be one of {}.".format(
"'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path) config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_WITH_LM_HEAD_MAPPING.keys())
)
) )
...@@ -563,17 +539,17 @@ class TFAutoModelForSequenceClassification(object): ...@@ -563,17 +539,17 @@ class TFAutoModelForSequenceClassification(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
""" """
if isinstance(config, DistilBertConfig): for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
return TFDistilBertForSequenceClassification(config) if isinstance(config, config_class):
elif isinstance(config, RobertaConfig): return model_class(config)
return TFRobertaForSequenceClassification(config) raise ValueError(
elif isinstance(config, BertConfig): "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
return TFBertForSequenceClassification(config) "Model type should be one of {}.".format(
elif isinstance(config, XLNetConfig): config.__class__,
return TFXLNetForSequenceClassification(config) cls.__name__,
elif isinstance(config, XLMConfig): ", ".join(c.__name__ for c in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
return TFXLMForSequenceClassification(config) )
raise ValueError("Unrecognized configuration class {}".format(config)) )
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
...@@ -659,34 +635,16 @@ class TFAutoModelForSequenceClassification(object): ...@@ -659,34 +635,16 @@ class TFAutoModelForSequenceClassification(object):
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, DistilBertConfig): for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
return TFDistilBertForSequenceClassification.from_pretrained( if isinstance(config, config_class):
pretrained_model_name_or_path, *model_args, config=config, **kwargs return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
)
elif isinstance(config, AlbertConfig):
return TFAlbertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, RobertaConfig):
return TFRobertaForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return TFBertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return TFXLNetForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return TFXLMForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
"'distilbert', 'bert', 'xlnet', 'xlm', 'roberta'".format(pretrained_model_name_or_path) "Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
)
) )
...@@ -735,15 +693,17 @@ class TFAutoModelForQuestionAnswering(object): ...@@ -735,15 +693,17 @@ class TFAutoModelForQuestionAnswering(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
""" """
if isinstance(config, DistilBertConfig): for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
return TFDistilBertForQuestionAnswering(config) if isinstance(config, config_class):
elif isinstance(config, BertConfig): return model_class(config)
return TFBertForQuestionAnswering(config) raise ValueError(
elif isinstance(config, XLNetConfig): "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
raise NotImplementedError("TFXLNetForQuestionAnswering isn't implemented") "Model type should be one of {}.".format(
elif isinstance(config, XLMConfig): config.__class__,
raise NotImplementedError("TFXLMForQuestionAnswering isn't implemented") cls.__name__,
raise ValueError("Unrecognized configuration class {}".format(config)) ", ".join(c.__name__ for c in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
)
)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
...@@ -828,26 +788,16 @@ class TFAutoModelForQuestionAnswering(object): ...@@ -828,26 +788,16 @@ class TFAutoModelForQuestionAnswering(object):
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, DistilBertConfig): for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
return TFDistilBertForQuestionAnswering.from_pretrained( if isinstance(config, config_class):
pretrained_model_name_or_path, *model_args, config=config, **kwargs return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
)
elif isinstance(config, BertConfig):
return TFBertForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return TFXLNetForQuestionAnsweringSimple.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return TFXLMForQuestionAnsweringSimple.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
"'distilbert', 'bert', 'xlnet', 'xlm'".format(pretrained_model_name_or_path) "Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
)
) )
...@@ -876,15 +826,17 @@ class TFAutoModelForTokenClassification: ...@@ -876,15 +826,17 @@ class TFAutoModelForTokenClassification:
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = TFAutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` model = TFAutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
""" """
if isinstance(config, BertConfig): for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
return TFBertForTokenClassification(config) if isinstance(config, config_class):
elif isinstance(config, XLNetConfig): return model_class(config)
return TFXLNetForTokenClassification(config) raise ValueError(
elif isinstance(config, DistilBertConfig): "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
return TFDistilBertForTokenClassification(config) "Model type should be one of {}.".format(
elif isinstance(config, RobertaConfig): config.__class__,
return TFRobertaForTokenClassification(config) cls.__name__,
raise ValueError("Unrecognized configuration class {}".format(config)) ", ".join(c.__name__ for c in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
)
)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
...@@ -962,24 +914,14 @@ class TFAutoModelForTokenClassification: ...@@ -962,24 +914,14 @@ class TFAutoModelForTokenClassification:
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, BertConfig): for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
return TFBertForTokenClassification.from_pretrained( if isinstance(config, config_class):
pretrained_model_name_or_path, *model_args, config=config, **kwargs return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
)
elif isinstance(config, XLNetConfig):
return TFXLNetForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, DistilBertConfig):
return TFDistilBertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, RobertaConfig):
return TFRobertaForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
"'bert', 'xlnet', 'distilbert', 'roberta'".format(pretrained_model_name_or_path) "Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
)
) )
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
import logging import logging
from collections import OrderedDict
from typing import Dict, Type
from .configuration_auto import ( from .configuration_auto import (
AlbertConfig, AlbertConfig,
...@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer ...@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_roberta import RobertaTokenizer from .tokenization_roberta import RobertaTokenizer
from .tokenization_t5 import T5Tokenizer from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_xlm import XLMTokenizer from .tokenization_xlm import XLMTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import XLNetTokenizer from .tokenization_xlnet import XLNetTokenizer
...@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer ...@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TOKENIZER_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedTokenizer]] = OrderedDict(
[
(T5Config, T5Tokenizer),
(DistilBertConfig, DistilBertTokenizer),
(AlbertConfig, AlbertTokenizer),
(CamembertConfig, CamembertTokenizer),
(RobertaConfig, XLMRobertaTokenizer),
(XLMRobertaConfig, RobertaTokenizer),
(BertConfig, BertTokenizer),
(OpenAIGPTConfig, OpenAIGPTTokenizer),
(GPT2Config, GPT2Tokenizer),
(TransfoXLConfig, TransfoXLTokenizer),
(XLNetConfig, XLNetTokenizer),
(XLMConfig, XLMTokenizer),
(CTRLConfig, CTRLTokenizer),
]
)
class AutoTokenizer(object): class AutoTokenizer(object):
r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
that will be instantiated as one of the tokenizer classes of the library that will be instantiated as one of the tokenizer classes of the library
...@@ -154,36 +176,13 @@ class AutoTokenizer(object): ...@@ -154,36 +176,13 @@ class AutoTokenizer(object):
if "bert-base-japanese" in pretrained_model_name_or_path: if "bert-base-japanese" in pretrained_model_name_or_path:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
if isinstance(config, T5Config): for config_class, tokenizer_class in TOKENIZER_MAPPING.items():
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) if isinstance(config, config_class):
elif isinstance(config, DistilBertConfig): return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, AlbertConfig):
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, CamembertConfig):
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, RobertaConfig):
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, BertConfig):
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, GPT2Config):
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, TransfoXLConfig):
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLNetConfig):
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLMConfig):
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, CTRLConfig):
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized configuration class {} to build an AutoTokenizer.\n"
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " "Model type should be one of {}.".format(
"'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format( config.__class__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
pretrained_model_name_or_path
) )
) )
...@@ -51,4 +51,4 @@ class AutoConfigTest(unittest.TestCase): ...@@ -51,4 +51,4 @@ class AutoConfigTest(unittest.TestCase):
# no key string should be included in a later key string (typical failure case) # no key string should be included in a later key string (typical failure case)
keys = list(CONFIG_MAPPING.keys()) keys = list(CONFIG_MAPPING.keys())
for i, key in enumerate(keys): for i, key in enumerate(keys):
self.assertFalse(any(key in later_key for later_key in keys[i+1:])) self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment