Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
03046285
Commit
03046285
authored
Jan 13, 2020
by
Julien Chaumond
Browse files
Map configs to models and tokenizers
parent
1fc855e4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
322 additions
and
458 deletions
+322
-458
src/transformers/configuration_auto.py
src/transformers/configuration_auto.py
+3
-3
src/transformers/configuration_utils.py
src/transformers/configuration_utils.py
+3
-3
src/transformers/modeling_auto.py
src/transformers/modeling_auto.py
+132
-209
src/transformers/modeling_tf_auto.py
src/transformers/modeling_tf_auto.py
+154
-212
src/transformers/tokenization_auto.py
src/transformers/tokenization_auto.py
+29
-30
tests/test_configuration_auto.py
tests/test_configuration_auto.py
+1
-1
No files found.
src/transformers/configuration_auto.py
View file @
03046285
...
...
@@ -202,7 +202,7 @@ class AutoConfig:
return
config_class
.
from_dict
(
config_dict
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model i
dentifier in {}. Should have a `model_type` key in its config.json, or contain one of {}"
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
CONFIG_MAPPING
.
keys
())
)
"Unrecognized model i
n {}. "
"Should have a `model_type` key in its config.json, or contain one of the following strings "
"in its name: {}"
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
CONFIG_MAPPING
.
keys
())
)
)
src/transformers/configuration_utils.py
View file @
03046285
...
...
@@ -47,8 +47,8 @@ class PretrainedConfig(object):
``output_hidden_states``: boolean, default `False`. Whether the model should return all hidden-states.
``torchscript``: boolean, default `False`. Whether the model is used with Torchscript.
"""
pretrained_config_archive_map
=
{}
# type: Dict[str, str]
model_type
=
""
# type: str
pretrained_config_archive_map
=
{}
# type: Dict[str, str]
model_type
=
""
# type: str
def
__init__
(
self
,
**
kwargs
):
# Attributes with defaults
...
...
@@ -273,7 +273,7 @@ class PretrainedConfig(object):
return
self
.
__dict__
==
other
.
__dict__
def
__repr__
(
self
):
return
str
(
self
.
to_json_string
())
return
"{} {}"
.
format
(
self
.
__class__
.
__name__
,
self
.
to_json_string
())
def
to_dict
(
self
):
"""Serializes this instance to a Python dictionary."""
...
...
src/transformers/modeling_auto.py
View file @
03046285
...
...
@@ -17,7 +17,7 @@
import
logging
from
collections
import
OrderedDict
from
typing
import
Type
from
typing
import
Dict
,
Type
from
.configuration_auto
import
(
AlbertConfig
,
...
...
@@ -126,14 +126,14 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
for
key
,
value
,
in
pretrained_map
.
items
()
)
MODEL_MAPPING
:
Ordered
Dict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
MODEL_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
[
(
T5Config
,
T5Model
),
(
DistilBertConfig
,
DistilBertModel
),
(
AlbertConfig
,
AlbertModel
),
(
CamembertConfig
,
CamembertModel
),
(
RobertaConfig
,
XLM
RobertaModel
),
(
XLMRobertaConfig
,
RobertaModel
),
(
RobertaConfig
,
RobertaModel
),
(
XLMRobertaConfig
,
XLM
RobertaModel
),
(
BertConfig
,
BertModel
),
(
OpenAIGPTConfig
,
OpenAIGPTModel
),
(
GPT2Config
,
GPT2Model
),
...
...
@@ -144,12 +144,53 @@ MODEL_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = Orde
]
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
:
OrderedDict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
# Maps each configuration class to the language-modeling-head model built from it.
#
# Order is significant: AutoModelWithLMHead walks this mapping in insertion order
# and returns the model of the FIRST entry for which
# ``isinstance(config, config_class)`` holds. A derived configuration class must
# therefore appear before the class it derives from. XLMRobertaConfig is listed
# ahead of RobertaConfig to keep the dispatch order of the if/elif chain this
# table replaces (presumably XLMRobertaConfig derives from RobertaConfig, as
# upstream -- confirm), so XLM-RoBERTa configs do not resolve to RobertaForMaskedLM.
MODEL_WITH_LM_HEAD_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
    [
        (T5Config, T5WithLMHeadModel),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForMaskedLM),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
    ]
)
# Maps each configuration class to its sequence-classification model.
#
# Order is significant: AutoModelForSequenceClassification iterates this mapping
# in insertion order and uses the FIRST entry whose config class matches via
# ``isinstance``. XLMRobertaConfig is placed before RobertaConfig, matching the
# order of the if/elif chain this table replaces (presumably XLMRobertaConfig
# derives from RobertaConfig, as upstream -- confirm), so an XLM-RoBERTa config
# is not dispatched to RobertaForSequenceClassification.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
    [
        (DistilBertConfig, DistilBertForSequenceClassification),
        (AlbertConfig, AlbertForSequenceClassification),
        (CamembertConfig, CamembertForSequenceClassification),
        (XLMRobertaConfig, XLMRobertaForSequenceClassification),
        (RobertaConfig, RobertaForSequenceClassification),
        (BertConfig, BertForSequenceClassification),
        (XLNetConfig, XLNetForSequenceClassification),
        (XLMConfig, XLMForSequenceClassification),
    ]
)
# Configuration class -> question-answering model class.
# AutoModelForQuestionAnswering scans the entries in insertion order and picks
# the first one whose configuration class matches via ``isinstance``, so the
# order of the pairs below is part of the behavior.
MODEL_FOR_QUESTION_ANSWERING_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
    (
        (DistilBertConfig, DistilBertForQuestionAnswering),
        (AlbertConfig, AlbertForQuestionAnswering),
        (BertConfig, BertForQuestionAnswering),
        (XLNetConfig, XLNetForQuestionAnswering),
        (XLMConfig, XLMForQuestionAnswering),
    )
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
[
(
DistilBertConfig
,
DistilBertForTokenClassification
),
(
CamembertConfig
,
CamembertForTokenClassification
),
(
RobertaConfig
,
XLM
RobertaForTokenClassification
),
(
XLMRobertaConfig
,
RobertaForTokenClassification
),
(
RobertaConfig
,
RobertaForTokenClassification
),
(
XLMRobertaConfig
,
XLM
RobertaForTokenClassification
),
(
BertConfig
,
BertForTokenClassification
),
(
XLNetConfig
,
XLNetForTokenClassification
),
]
...
...
@@ -218,7 +259,12 @@ class AutoModel(object):
for
config_class
,
model_class
in
MODEL_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
raise
ValueError
(
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_MAPPING
.
keys
())
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -309,10 +355,9 @@ class AutoModel(object):
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_MAPPING
.
keys
())
)
)
...
...
@@ -376,27 +421,15 @@ class AutoModelWithLMHead(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForMaskedLM
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForMaskedLM
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForMaskedLM
(
config
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
OpenAIGPTLMHeadModel
(
config
)
elif
isinstance
(
config
,
GPT2Config
):
return
GPT2LMHeadModel
(
config
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TransfoXLLMHeadModel
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetLMHeadModel
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMWithLMHeadModel
(
config
)
elif
isinstance
(
config
,
CTRLConfig
):
return
CTRLLMHeadModel
(
config
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForMaskedLM
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
MODEL_WITH_LM_HEAD_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_WITH_LM_HEAD_MAPPING
.
keys
())
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -486,57 +519,13 @@ class AutoModelWithLMHead(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
T5Config
):
return
T5WithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
AlbertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
CamembertConfig
):
return
CamembertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
OpenAIGPTLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
GPT2Config
):
return
GPT2LMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TransfoXLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMWithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
CTRLConfig
):
return
CTRLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
MODEL_WITH_LM_HEAD_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_WITH_LM_HEAD_MAPPING
.
keys
())
)
)
...
...
@@ -591,23 +580,17 @@ class AutoModelForSequenceClassification(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
AlbertConfig
):
return
AlbertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
CamembertConfig
):
return
CamembertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForSequenceClassification
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForSequenceClassification
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMForSequenceClassification
(
config
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForSequenceClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -693,43 +676,15 @@ class AutoModelForSequenceClassification(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
AlbertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
CamembertConfig
):
return
CamembertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
keys
()),
)
)
...
...
@@ -780,17 +735,18 @@ class AutoModelForQuestionAnswering(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForQuestionAnswering.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
AlbertConfig
):
return
AlbertForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMForQuestionAnswering
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -870,30 +826,17 @@ class AutoModelForQuestionAnswering(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
AlbertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
keys
()),
)
)
...
...
@@ -923,19 +866,18 @@ class AutoModelForTokenClassification:
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
CamembertConfig
):
return
CamembertForTokenClassification
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForTokenClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForTokenClassification
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForTokenClassification
(
config
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForTokenClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -1014,34 +956,15 @@ class AutoModelForTokenClassification:
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
CamembertConfig
):
return
CamembertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
keys
()),
)
)
src/transformers/modeling_tf_auto.py
View file @
03046285
...
...
@@ -16,6 +16,8 @@
import
logging
from
collections
import
OrderedDict
from
typing
import
Dict
,
Type
from
.configuration_auto
import
(
AlbertConfig
,
...
...
@@ -70,6 +72,7 @@ from .modeling_tf_transfo_xl import (
TFTransfoXLLMHeadModel
,
TFTransfoXLModel
,
)
from
.modeling_tf_utils
import
TFPreTrainedModel
from
.modeling_tf_xlm
import
(
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
TFXLMForQuestionAnsweringSimple
,
...
...
@@ -108,6 +111,65 @@ TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
for
key
,
value
,
in
pretrained_map
.
items
()
)
# Maps each configuration class to the bare TensorFlow model built from it.
#
# TFAutoModel iterates this mapping in insertion order and uses the FIRST entry
# whose config class matches via ``isinstance``, so order is significant.
# The T5 entry is included because the if/elif dispatch that this table
# replaces routed T5Config to TFT5Model; without it, T5 configurations would
# fall through to the "Unrecognized configuration class" ValueError.
TF_MODEL_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
    [
        (T5Config, TFT5Model),
        (DistilBertConfig, TFDistilBertModel),
        (AlbertConfig, TFAlbertModel),
        (RobertaConfig, TFRobertaModel),
        (BertConfig, TFBertModel),
        (OpenAIGPTConfig, TFOpenAIGPTModel),
        (GPT2Config, TFGPT2Model),
        (TransfoXLConfig, TFTransfoXLModel),
        (XLNetConfig, TFXLNetModel),
        (XLMConfig, TFXLMModel),
        (CTRLConfig, TFCTRLModel),
    ]
)
# Maps each configuration class to its TensorFlow language-modeling-head model.
#
# TFAutoModelWithLMHead iterates this mapping in insertion order and uses the
# FIRST entry whose config class matches via ``isinstance``, so order is
# significant. The T5 entry is included because the if/elif dispatch that this
# table replaces routed T5Config to TFT5WithLMHeadModel; omitting it would make
# T5 configurations raise the "Unrecognized configuration class" ValueError.
TF_MODEL_WITH_LM_HEAD_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
    [
        (T5Config, TFT5WithLMHeadModel),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (AlbertConfig, TFAlbertForMaskedLM),
        (RobertaConfig, TFRobertaForMaskedLM),
        (BertConfig, TFBertForMaskedLM),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
        (GPT2Config, TFGPT2LMHeadModel),
        (TransfoXLConfig, TFTransfoXLLMHeadModel),
        (XLNetConfig, TFXLNetLMHeadModel),
        (XLMConfig, TFXLMWithLMHeadModel),
        (CTRLConfig, TFCTRLLMHeadModel),
    ]
)
# Configuration class -> TensorFlow sequence-classification model class.
# Entry order is preserved by the OrderedDict; presumably consumers pick the
# first ``isinstance`` match as TFAutoModel does with TF_MODEL_MAPPING, so the
# order below should be treated as significant -- confirm against the caller.
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
    (
        (DistilBertConfig, TFDistilBertForSequenceClassification),
        (AlbertConfig, TFAlbertForSequenceClassification),
        (RobertaConfig, TFRobertaForSequenceClassification),
        (BertConfig, TFBertForSequenceClassification),
        (XLNetConfig, TFXLNetForSequenceClassification),
        (XLMConfig, TFXLMForSequenceClassification),
    )
)
# Configuration class -> TensorFlow question-answering model class.
# Entry order is preserved by the OrderedDict; presumably consumers pick the
# first ``isinstance`` match as TFAutoModel does with TF_MODEL_MAPPING, so the
# order below should be treated as significant -- confirm against the caller.
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
    (
        (DistilBertConfig, TFDistilBertForQuestionAnswering),
        (BertConfig, TFBertForQuestionAnswering),
        (XLNetConfig, TFXLNetForQuestionAnsweringSimple),
        (XLMConfig, TFXLMForQuestionAnsweringSimple),
    )
)
# Configuration class -> TensorFlow token-classification model class.
# Entry order is preserved by the OrderedDict; presumably consumers pick the
# first ``isinstance`` match as TFAutoModel does with TF_MODEL_MAPPING, so the
# order below should be treated as significant -- confirm against the caller.
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
    (
        (DistilBertConfig, TFDistilBertForTokenClassification),
        (RobertaConfig, TFRobertaForTokenClassification),
        (BertConfig, TFBertForTokenClassification),
        (XLNetConfig, TFXLNetForTokenClassification),
    )
)
class
TFAutoModel
(
object
):
r
"""
...
...
@@ -165,25 +227,15 @@ class TFAutoModel(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = TFAutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertModel
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaModel
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertModel
(
config
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
TFOpenAIGPTModel
(
config
)
elif
isinstance
(
config
,
GPT2Config
):
return
TFGPT2Model
(
config
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TFTransfoXLModel
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetModel
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMModel
(
config
)
elif
isinstance
(
config
,
CTRLConfig
):
return
TFCTRLModel
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
TF_MODEL_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_MAPPING
.
keys
())
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -266,39 +318,14 @@ class TFAutoModel(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
T5Config
):
return
TFT5Model
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
TFAlbertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
TFOpenAIGPTModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
GPT2Config
):
return
TFGPT2Model
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TFTransfoXLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
CTRLConfig
):
return
TFCTRLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
TF_MODEL_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'ctrl'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_MAPPING
.
keys
())
)
)
...
...
@@ -358,25 +385,15 @@ class TFAutoModelWithLMHead(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = TFAutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForMaskedLM
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForMaskedLM
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForMaskedLM
(
config
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
TFOpenAIGPTLMHeadModel
(
config
)
elif
isinstance
(
config
,
GPT2Config
):
return
TFGPT2LMHeadModel
(
config
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TFTransfoXLLMHeadModel
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetLMHeadModel
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMWithLMHeadModel
(
config
)
elif
isinstance
(
config
,
CTRLConfig
):
return
TFCTRLLMHeadModel
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
TF_MODEL_WITH_LM_HEAD_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_WITH_LM_HEAD_MAPPING
.
keys
())
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -464,55 +481,14 @@ class TFAutoModelWithLMHead(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
T5Config
):
return
TFT5WithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
TFAlbertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
TFOpenAIGPTLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
GPT2Config
):
return
TFGPT2LMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TFTransfoXLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMWithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
CTRLConfig
):
return
TFCTRLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
TF_MODEL_WITH_LM_HEAD_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'ctrl'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_WITH_LM_HEAD_MAPPING
.
keys
())
)
)
...
...
@@ -563,17 +539,17 @@ class TFAutoModelForSequenceClassification(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = TFAutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
RobertaC
onfig
)
:
return
TFRobertaForSequenceClassification
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForSequenceClassification
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMForSequenceClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
)
)
for
config_class
,
model_class
in
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
items
(
):
if
isinstance
(
config
,
config_class
):
return
model_class
(
c
onfig
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -659,34 +635,16 @@ class TFAutoModelForSequenceClassification(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
TFAlbertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'xlnet', 'xlm', 'roberta'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
keys
()),
)
)
...
...
@@ -735,15 +693,17 @@ class TFAutoModelForQuestionAnswering(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = TFAutoModelForQuestionAnswering.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
raise
NotImplementedError
(
"TFXLNetForQuestionAnswering isn't implemented"
)
elif
isinstance
(
config
,
XLMConfig
):
raise
NotImplementedError
(
"TFXLMForQuestionAnswering isn't implemented"
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -828,26 +788,16 @@ class TFAutoModelForQuestionAnswering(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForQuestionAnsweringSimple
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMForQuestionAnsweringSimple
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'xlnet', 'xlm'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
keys
()),
)
)
...
...
@@ -876,15 +826,17 @@ class TFAutoModelForTokenClassification:
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = TFAutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
BertConfig
):
return
TFBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForTokenClassification
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForTokenClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -962,24 +914,14 @@ class TFAutoModelForTokenClassification:
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
BertConfig
):
return
TFBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'distilbert', 'roberta'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
keys
()),
)
)
src/transformers/tokenization_auto.py
View file @
03046285
...
...
@@ -16,6 +16,8 @@
import
logging
from
collections
import
OrderedDict
from
typing
import
Dict
,
Type
from
.configuration_auto
import
(
AlbertConfig
,
...
...
@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
from
.tokenization_roberta
import
RobertaTokenizer
from
.tokenization_t5
import
T5Tokenizer
from
.tokenization_transfo_xl
import
TransfoXLTokenizer
from
.tokenization_utils
import
PreTrainedTokenizer
from
.tokenization_xlm
import
XLMTokenizer
from
.tokenization_xlm_roberta
import
XLMRobertaTokenizer
from
.tokenization_xlnet
import
XLNetTokenizer
...
...
@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer
logger
=
logging
.
getLogger
(
__name__
)
# Maps each configuration class to the tokenizer class that AutoTokenizer
# instantiates for it.
#
# Order is significant: AutoTokenizer.from_pretrained walks this mapping in
# insertion order and returns on the first `isinstance(config, config_class)`
# match. The entries therefore keep the order of the isinstance chain this
# table replaces (XLMRobertaConfig is tested before RobertaConfig, which is
# tested before BertConfig).
# NOTE(review): presumably the more specific configs subclass the later ones
# (e.g. XLMRobertaConfig -> RobertaConfig) -- confirm before reordering.
TOKENIZER_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedTokenizer]] = OrderedDict(
    [
        (T5Config, T5Tokenizer),
        (DistilBertConfig, DistilBertTokenizer),
        (AlbertConfig, AlbertTokenizer),
        (CamembertConfig, CamembertTokenizer),
        # Each RoBERTa-family config is paired with its own tokenizer; the
        # two entries must not be cross-wired.
        (XLMRobertaConfig, XLMRobertaTokenizer),
        (RobertaConfig, RobertaTokenizer),
        (BertConfig, BertTokenizer),
        (OpenAIGPTConfig, OpenAIGPTTokenizer),
        (GPT2Config, GPT2Tokenizer),
        (TransfoXLConfig, TransfoXLTokenizer),
        (XLNetConfig, XLNetTokenizer),
        (XLMConfig, XLMTokenizer),
        (CTRLConfig, CTRLTokenizer),
    ]
)
class
AutoTokenizer
(
object
):
r
""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
that will be instantiated as one of the tokenizer classes of the library
...
...
@@ -154,36 +176,13 @@ class AutoTokenizer(object):
if
"bert-base-japanese"
in
pretrained_model_name_or_path
:
return
BertJapaneseTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
if
isinstance
(
config
,
T5Config
):
return
T5Tokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
AlbertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
CamembertConfig
):
return
CamembertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
OpenAIGPTTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
GPT2Config
):
return
GPT2Tokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TransfoXLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
CTRLConfig
):
return
CTRLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
for
config_class
,
tokenizer_class
in
TOKENIZER_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
tokenizer_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} to build an AutoTokenizer.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_MAPPING
.
keys
())
)
)
tests/test_configuration_auto.py
View file @
03046285
...
...
@@ -51,4 +51,4 @@ class AutoConfigTest(unittest.TestCase):
# no key string should be included in a later key string (typical failure case)
keys
=
list
(
CONFIG_MAPPING
.
keys
())
for
i
,
key
in
enumerate
(
keys
):
self
.
assertFalse
(
any
(
key
in
later_key
for
later_key
in
keys
[
i
+
1
:]))
self
.
assertFalse
(
any
(
key
in
later_key
for
later_key
in
keys
[
i
+
1
:]))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment