Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
03046285
Commit
03046285
authored
Jan 13, 2020
by
Julien Chaumond
Browse files
Map configs to models and tokenizers
parent
1fc855e4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
322 additions
and
458 deletions
+322
-458
src/transformers/configuration_auto.py
src/transformers/configuration_auto.py
+3
-3
src/transformers/configuration_utils.py
src/transformers/configuration_utils.py
+3
-3
src/transformers/modeling_auto.py
src/transformers/modeling_auto.py
+132
-209
src/transformers/modeling_tf_auto.py
src/transformers/modeling_tf_auto.py
+154
-212
src/transformers/tokenization_auto.py
src/transformers/tokenization_auto.py
+29
-30
tests/test_configuration_auto.py
tests/test_configuration_auto.py
+1
-1
No files found.
src/transformers/configuration_auto.py
View file @
03046285
...
...
@@ -202,7 +202,7 @@ class AutoConfig:
return
config_class
.
from_dict
(
config_dict
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model i
dentifier in {}. Should have a `model_type` key in its config.json, or contain one of {}"
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
CONFIG_MAPPING
.
keys
())
)
"Unrecognized model i
n {}. "
"Should have a `model_type` key in its config.json, or contain one of the following strings "
"in its name: {}"
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
CONFIG_MAPPING
.
keys
())
)
)
src/transformers/configuration_utils.py
View file @
03046285
...
...
@@ -47,8 +47,8 @@ class PretrainedConfig(object):
``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
``torchscript``: string, default `False`. Is the model used with Torchscript.
"""
pretrained_config_archive_map
=
{}
# type: Dict[str, str]
model_type
=
""
# type: str
pretrained_config_archive_map
=
{}
# type: Dict[str, str]
model_type
=
""
# type: str
def
__init__
(
self
,
**
kwargs
):
# Attributes with defaults
...
...
@@ -273,7 +273,7 @@ class PretrainedConfig(object):
return
self
.
__dict__
==
other
.
__dict__
def
__repr__
(
self
):
return
str
(
self
.
to_json_string
())
return
"{} {}"
.
format
(
self
.
__class__
.
__name__
,
self
.
to_json_string
())
def
to_dict
(
self
):
"""Serializes this instance to a Python dictionary."""
...
...
src/transformers/modeling_auto.py
View file @
03046285
...
...
@@ -17,7 +17,7 @@
import
logging
from
collections
import
OrderedDict
from
typing
import
Type
from
typing
import
Dict
,
Type
from
.configuration_auto
import
(
AlbertConfig
,
...
...
@@ -126,14 +126,14 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
for
key
,
value
,
in
pretrained_map
.
items
()
)
MODEL_MAPPING
:
Ordered
Dict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
MODEL_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
[
(
T5Config
,
T5Model
),
(
DistilBertConfig
,
DistilBertModel
),
(
AlbertConfig
,
AlbertModel
),
(
CamembertConfig
,
CamembertModel
),
(
RobertaConfig
,
XLM
RobertaModel
),
(
XLMRobertaConfig
,
RobertaModel
),
(
RobertaConfig
,
RobertaModel
),
(
XLMRobertaConfig
,
XLM
RobertaModel
),
(
BertConfig
,
BertModel
),
(
OpenAIGPTConfig
,
OpenAIGPTModel
),
(
GPT2Config
,
GPT2Model
),
...
...
@@ -144,12 +144,53 @@ MODEL_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = Orde
]
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
:
OrderedDict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
MODEL_WITH_LM_HEAD_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
[
(
T5Config
,
T5WithLMHeadModel
),
(
DistilBertConfig
,
DistilBertForMaskedLM
),
(
AlbertConfig
,
AlbertForMaskedLM
),
(
CamembertConfig
,
CamembertForMaskedLM
),
(
RobertaConfig
,
RobertaForMaskedLM
),
(
XLMRobertaConfig
,
XLMRobertaForMaskedLM
),
(
BertConfig
,
BertForMaskedLM
),
(
OpenAIGPTConfig
,
OpenAIGPTLMHeadModel
),
(
GPT2Config
,
GPT2LMHeadModel
),
(
TransfoXLConfig
,
TransfoXLLMHeadModel
),
(
XLNetConfig
,
XLNetLMHeadModel
),
(
XLMConfig
,
XLMWithLMHeadModel
),
(
CTRLConfig
,
CTRLLMHeadModel
),
]
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
[
(
DistilBertConfig
,
DistilBertForSequenceClassification
),
(
AlbertConfig
,
AlbertForSequenceClassification
),
(
CamembertConfig
,
CamembertForSequenceClassification
),
(
RobertaConfig
,
RobertaForSequenceClassification
),
(
XLMRobertaConfig
,
XLMRobertaForSequenceClassification
),
(
BertConfig
,
BertForSequenceClassification
),
(
XLNetConfig
,
XLNetForSequenceClassification
),
(
XLMConfig
,
XLMForSequenceClassification
),
]
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
[
(
DistilBertConfig
,
DistilBertForQuestionAnswering
),
(
AlbertConfig
,
AlbertForQuestionAnswering
),
(
BertConfig
,
BertForQuestionAnswering
),
(
XLNetConfig
,
XLNetForQuestionAnswering
),
(
XLMConfig
,
XLMForQuestionAnswering
),
]
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedModel
]]
=
OrderedDict
(
[
(
DistilBertConfig
,
DistilBertForTokenClassification
),
(
CamembertConfig
,
CamembertForTokenClassification
),
(
RobertaConfig
,
XLM
RobertaForTokenClassification
),
(
XLMRobertaConfig
,
RobertaForTokenClassification
),
(
RobertaConfig
,
RobertaForTokenClassification
),
(
XLMRobertaConfig
,
XLM
RobertaForTokenClassification
),
(
BertConfig
,
BertForTokenClassification
),
(
XLNetConfig
,
XLNetForTokenClassification
),
]
...
...
@@ -218,7 +259,12 @@ class AutoModel(object):
for
config_class
,
model_class
in
MODEL_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
raise
ValueError
(
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_MAPPING
.
keys
())
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -309,10 +355,9 @@ class AutoModel(object):
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_MAPPING
.
keys
())
)
)
...
...
@@ -376,27 +421,15 @@ class AutoModelWithLMHead(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForMaskedLM
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForMaskedLM
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForMaskedLM
(
config
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
OpenAIGPTLMHeadModel
(
config
)
elif
isinstance
(
config
,
GPT2Config
):
return
GPT2LMHeadModel
(
config
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TransfoXLLMHeadModel
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetLMHeadModel
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMWithLMHeadModel
(
config
)
elif
isinstance
(
config
,
CTRLConfig
):
return
CTRLLMHeadModel
(
config
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForMaskedLM
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
MODEL_WITH_LM_HEAD_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_WITH_LM_HEAD_MAPPING
.
keys
())
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -486,57 +519,13 @@ class AutoModelWithLMHead(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
T5Config
):
return
T5WithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
AlbertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
CamembertConfig
):
return
CamembertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
OpenAIGPTLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
GPT2Config
):
return
GPT2LMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TransfoXLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMWithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
CTRLConfig
):
return
CTRLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
MODEL_WITH_LM_HEAD_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_WITH_LM_HEAD_MAPPING
.
keys
())
)
)
...
...
@@ -591,23 +580,17 @@ class AutoModelForSequenceClassification(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
AlbertConfig
):
return
AlbertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
CamembertConfig
):
return
CamembertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForSequenceClassification
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForSequenceClassification
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMForSequenceClassification
(
config
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForSequenceClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -693,43 +676,15 @@ class AutoModelForSequenceClassification(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
AlbertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
CamembertConfig
):
return
CamembertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
keys
()),
)
)
...
...
@@ -780,17 +735,18 @@ class AutoModelForQuestionAnswering(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
AlbertConfig
):
return
AlbertForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMForQuestionAnswering
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -870,30 +826,17 @@ class AutoModelForQuestionAnswering(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
AlbertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
keys
()),
)
)
...
...
@@ -923,19 +866,18 @@ class AutoModelForTokenClassification:
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
CamembertConfig
):
return
CamembertForTokenClassification
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForTokenClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForTokenClassification
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForTokenClassification
(
config
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForTokenClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -1014,34 +956,15 @@ class AutoModelForTokenClassification:
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
CamembertConfig
):
return
CamembertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} for this kind of AutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
keys
()),
)
)
src/transformers/modeling_tf_auto.py
View file @
03046285
...
...
@@ -16,6 +16,8 @@
import
logging
from
collections
import
OrderedDict
from
typing
import
Dict
,
Type
from
.configuration_auto
import
(
AlbertConfig
,
...
...
@@ -70,6 +72,7 @@ from .modeling_tf_transfo_xl import (
TFTransfoXLLMHeadModel
,
TFTransfoXLModel
,
)
from
.modeling_tf_utils
import
TFPreTrainedModel
from
.modeling_tf_xlm
import
(
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
TFXLMForQuestionAnsweringSimple
,
...
...
@@ -108,6 +111,65 @@ TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
for
key
,
value
,
in
pretrained_map
.
items
()
)
TF_MODEL_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
TFPreTrainedModel
]]
=
OrderedDict
(
[
(
DistilBertConfig
,
TFDistilBertModel
),
(
AlbertConfig
,
TFAlbertModel
),
(
RobertaConfig
,
TFRobertaModel
),
(
BertConfig
,
TFBertModel
),
(
OpenAIGPTConfig
,
TFOpenAIGPTModel
),
(
GPT2Config
,
TFGPT2Model
),
(
TransfoXLConfig
,
TFTransfoXLModel
),
(
XLNetConfig
,
TFXLNetModel
),
(
XLMConfig
,
TFXLMModel
),
(
CTRLConfig
,
TFCTRLModel
),
]
)
TF_MODEL_WITH_LM_HEAD_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
TFPreTrainedModel
]]
=
OrderedDict
(
[
(
DistilBertConfig
,
TFDistilBertForMaskedLM
),
(
AlbertConfig
,
TFAlbertForMaskedLM
),
(
RobertaConfig
,
TFRobertaForMaskedLM
),
(
BertConfig
,
TFBertForMaskedLM
),
(
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
),
(
GPT2Config
,
TFGPT2LMHeadModel
),
(
TransfoXLConfig
,
TFTransfoXLLMHeadModel
),
(
XLNetConfig
,
TFXLNetLMHeadModel
),
(
XLMConfig
,
TFXLMWithLMHeadModel
),
(
CTRLConfig
,
TFCTRLLMHeadModel
),
]
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
TFPreTrainedModel
]]
=
OrderedDict
(
[
(
DistilBertConfig
,
TFDistilBertForSequenceClassification
),
(
AlbertConfig
,
TFAlbertForSequenceClassification
),
(
RobertaConfig
,
TFRobertaForSequenceClassification
),
(
BertConfig
,
TFBertForSequenceClassification
),
(
XLNetConfig
,
TFXLNetForSequenceClassification
),
(
XLMConfig
,
TFXLMForSequenceClassification
),
]
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
TFPreTrainedModel
]]
=
OrderedDict
(
[
(
DistilBertConfig
,
TFDistilBertForQuestionAnswering
),
(
BertConfig
,
TFBertForQuestionAnswering
),
(
XLNetConfig
,
TFXLNetForQuestionAnsweringSimple
),
(
XLMConfig
,
TFXLMForQuestionAnsweringSimple
),
]
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
TFPreTrainedModel
]]
=
OrderedDict
(
[
(
DistilBertConfig
,
TFDistilBertForTokenClassification
),
(
RobertaConfig
,
TFRobertaForTokenClassification
),
(
BertConfig
,
TFBertForTokenClassification
),
(
XLNetConfig
,
TFXLNetForTokenClassification
),
]
)
class
TFAutoModel
(
object
):
r
"""
...
...
@@ -165,25 +227,15 @@ class TFAutoModel(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = TFAutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertModel
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaModel
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertModel
(
config
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
TFOpenAIGPTModel
(
config
)
elif
isinstance
(
config
,
GPT2Config
):
return
TFGPT2Model
(
config
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TFTransfoXLModel
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetModel
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMModel
(
config
)
elif
isinstance
(
config
,
CTRLConfig
):
return
TFCTRLModel
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
TF_MODEL_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_MAPPING
.
keys
())
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -266,39 +318,14 @@ class TFAutoModel(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
T5Config
):
return
TFT5Model
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
TFAlbertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
TFOpenAIGPTModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
GPT2Config
):
return
TFGPT2Model
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TFTransfoXLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
CTRLConfig
):
return
TFCTRLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
TF_MODEL_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'ctrl'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_MAPPING
.
keys
())
)
)
...
...
@@ -358,25 +385,15 @@ class TFAutoModelWithLMHead(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForMaskedLM
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForMaskedLM
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForMaskedLM
(
config
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
TFOpenAIGPTLMHeadModel
(
config
)
elif
isinstance
(
config
,
GPT2Config
):
return
TFGPT2LMHeadModel
(
config
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TFTransfoXLLMHeadModel
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetLMHeadModel
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMWithLMHeadModel
(
config
)
elif
isinstance
(
config
,
CTRLConfig
):
return
TFCTRLLMHeadModel
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
TF_MODEL_WITH_LM_HEAD_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_WITH_LM_HEAD_MAPPING
.
keys
())
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -464,55 +481,14 @@ class TFAutoModelWithLMHead(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
T5Config
):
return
TFT5WithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
TFAlbertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
TFOpenAIGPTLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
GPT2Config
):
return
TFGPT2LMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TFTransfoXLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMWithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
CTRLConfig
):
return
TFCTRLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
TF_MODEL_WITH_LM_HEAD_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'ctrl'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_WITH_LM_HEAD_MAPPING
.
keys
())
)
)
...
...
@@ -563,17 +539,17 @@ class TFAutoModelForSequenceClassification(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
RobertaC
onfig
)
:
return
TFRobertaForSequenceClassification
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForSequenceClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForSequenceClassification
(
config
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMForSequenceClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
)
)
for
config_class
,
model_class
in
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
items
(
):
if
isinstance
(
config
,
config_class
):
return
model_class
(
c
onfig
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -659,34 +635,16 @@ class TFAutoModelForSequenceClassification(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
TFAlbertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'xlnet', 'xlm', 'roberta'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
.
keys
()),
)
)
...
...
@@ -735,15 +693,17 @@ class TFAutoModelForQuestionAnswering(object):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForQuestionAnswering
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
raise
NotImplementedError
(
"TFXLNetForQuestionAnswering isn't implemented"
)
elif
isinstance
(
config
,
XLMConfig
):
raise
NotImplementedError
(
"TFXLMForQuestionAnswering isn't implemented"
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -828,26 +788,16 @@ class TFAutoModelForQuestionAnswering(object):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
TFBertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForQuestionAnsweringSimple
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
TFXLMForQuestionAnsweringSimple
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'xlnet', 'xlm'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
.
keys
()),
)
)
...
...
@@ -876,15 +826,17 @@ class TFAutoModelForTokenClassification:
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = TFAutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if
isinstance
(
config
,
BertConfig
):
return
TFBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForTokenClassification
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForTokenClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
for
config_class
,
model_class
in
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
(
config
)
raise
ValueError
(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
keys
()),
)
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
@@ -962,24 +914,14 @@ class TFAutoModelForTokenClassification:
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
if
isinstance
(
config
,
BertConfig
):
return
TFBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
for
config_class
,
model_class
in
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
model_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
config
=
config
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'distilbert', 'roberta'"
.
format
(
pretrained_model_name_or_path
)
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
cls
.
__name__
,
", "
.
join
(
c
.
__name__
for
c
in
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
.
keys
()),
)
)
src/transformers/tokenization_auto.py
View file @
03046285
...
...
@@ -16,6 +16,8 @@
import
logging
from
collections
import
OrderedDict
from
typing
import
Dict
,
Type
from
.configuration_auto
import
(
AlbertConfig
,
...
...
@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
from
.tokenization_roberta
import
RobertaTokenizer
from
.tokenization_t5
import
T5Tokenizer
from
.tokenization_transfo_xl
import
TransfoXLTokenizer
from
.tokenization_utils
import
PreTrainedTokenizer
from
.tokenization_xlm
import
XLMTokenizer
from
.tokenization_xlm_roberta
import
XLMRobertaTokenizer
from
.tokenization_xlnet
import
XLNetTokenizer
...
...
@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer
logger
=
logging
.
getLogger
(
__name__
)
# Maps each config class to its canonical tokenizer class.
# Order matters: AutoTokenizer walks this OrderedDict with isinstance(), so a
# subclass config (e.g. XLMRobertaConfig, CamembertConfig, which derive from
# RobertaConfig) must appear before its base config or it would be shadowed.
# Fixes two defects in the original literal: RobertaConfig/XLMRobertaConfig
# were cross-wired to each other's tokenizer, and RobertaConfig preceded its
# subclass XLMRobertaConfig.
TOKENIZER_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedTokenizer]] = OrderedDict(
    [
        (T5Config, T5Tokenizer),
        (DistilBertConfig, DistilBertTokenizer),
        (AlbertConfig, AlbertTokenizer),
        (CamembertConfig, CamembertTokenizer),
        (XLMRobertaConfig, XLMRobertaTokenizer),  # before RobertaConfig (subclass)
        (RobertaConfig, RobertaTokenizer),
        (BertConfig, BertTokenizer),
        (OpenAIGPTConfig, OpenAIGPTTokenizer),
        (GPT2Config, GPT2Tokenizer),
        (TransfoXLConfig, TransfoXLTokenizer),
        (XLNetConfig, XLNetTokenizer),
        (XLMConfig, XLMTokenizer),
        (CTRLConfig, CTRLTokenizer),
    ]
)
class
AutoTokenizer
(
object
):
r
""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
that will be instantiated as one of the tokenizer classes of the library
...
...
@@ -154,36 +176,13 @@ class AutoTokenizer(object):
if
"bert-base-japanese"
in
pretrained_model_name_or_path
:
return
BertJapaneseTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
if
isinstance
(
config
,
T5Config
):
return
T5Tokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
AlbertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
CamembertConfig
):
return
CamembertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
OpenAIGPTTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
GPT2Config
):
return
GPT2Tokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TransfoXLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
CTRLConfig
):
return
CTRLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
for
config_class
,
tokenizer_class
in
TOKENIZER_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
tokenizer_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} to build an AutoTokenizer.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_MAPPING
.
keys
())
)
)
tests/test_configuration_auto.py
View file @
03046285
...
...
@@ -51,4 +51,4 @@ class AutoConfigTest(unittest.TestCase):
# no key string should be included in a later key string (typical failure case)
# No key string may be a substring of a later key, otherwise the earlier
# pattern would shadow the later one during substring-based model-type lookup.
keys = list(CONFIG_MAPPING.keys())
for i, key in enumerate(keys):
    # Single assertion per key; the source showed it duplicated (diff artifact).
    self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment