Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
03046285
Commit
03046285
authored
Jan 13, 2020
by
Julien Chaumond
Browse files
Map configs to models and tokenizers
parent
1fc855e4
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
322 additions
and
458 deletions
+322
-458
src/transformers/configuration_auto.py
src/transformers/configuration_auto.py
+3
-3
src/transformers/configuration_utils.py
src/transformers/configuration_utils.py
+3
-3
src/transformers/modeling_auto.py
src/transformers/modeling_auto.py
+132
-209
src/transformers/modeling_tf_auto.py
src/transformers/modeling_tf_auto.py
+154
-212
src/transformers/tokenization_auto.py
src/transformers/tokenization_auto.py
+29
-30
tests/test_configuration_auto.py
tests/test_configuration_auto.py
+1
-1
No files found.
src/transformers/configuration_auto.py
View file @
03046285
...
...
@@ -202,7 +202,7 @@ class AutoConfig:
return
config_class
.
from_dict
(
config_dict
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model i
dentifier in {}. Should have a `model_type` key in its config.json, or contain one of {}"
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
CONFIG_MAPPING
.
keys
())
)
"Unrecognized model i
n {}. "
"Should have a `model_type` key in its config.json, or contain one of the following strings "
"in its name: {}"
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
CONFIG_MAPPING
.
keys
())
)
)
src/transformers/configuration_utils.py
View file @
03046285
...
...
@@ -47,8 +47,8 @@ class PretrainedConfig(object):
``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
``torchscript``: string, default `False`. Is the model used with Torchscript.
"""
pretrained_config_archive_map
=
{}
# type: Dict[str, str]
model_type
=
""
# type: str
pretrained_config_archive_map
=
{}
# type: Dict[str, str]
model_type
=
""
# type: str
def
__init__
(
self
,
**
kwargs
):
# Attributes with defaults
...
...
@@ -273,7 +273,7 @@ class PretrainedConfig(object):
return
self
.
__dict__
==
other
.
__dict__
def
__repr__
(
self
):
return
str
(
self
.
to_json_string
())
return
"{} {}"
.
format
(
self
.
__class__
.
__name__
,
self
.
to_json_string
())
def
to_dict
(
self
):
"""Serializes this instance to a Python dictionary."""
...
...
src/transformers/modeling_auto.py
View file @
03046285
This diff is collapsed.
Click to expand it.
src/transformers/modeling_tf_auto.py
View file @
03046285
This diff is collapsed.
Click to expand it.
src/transformers/tokenization_auto.py
View file @
03046285
...
...
@@ -16,6 +16,8 @@
import
logging
from
collections
import
OrderedDict
from
typing
import
Dict
,
Type
from
.configuration_auto
import
(
AlbertConfig
,
...
...
@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
from
.tokenization_roberta
import
RobertaTokenizer
from
.tokenization_t5
import
T5Tokenizer
from
.tokenization_transfo_xl
import
TransfoXLTokenizer
from
.tokenization_utils
import
PreTrainedTokenizer
from
.tokenization_xlm
import
XLMTokenizer
from
.tokenization_xlm_roberta
import
XLMRobertaTokenizer
from
.tokenization_xlnet
import
XLNetTokenizer
...
...
@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer
logger
=
logging
.
getLogger
(
__name__
)
TOKENIZER_MAPPING
:
Dict
[
Type
[
PretrainedConfig
],
Type
[
PreTrainedTokenizer
]]
=
OrderedDict
(
[
(
T5Config
,
T5Tokenizer
),
(
DistilBertConfig
,
DistilBertTokenizer
),
(
AlbertConfig
,
AlbertTokenizer
),
(
CamembertConfig
,
CamembertTokenizer
),
(
RobertaConfig
,
XLMRobertaTokenizer
),
(
XLMRobertaConfig
,
RobertaTokenizer
),
(
BertConfig
,
BertTokenizer
),
(
OpenAIGPTConfig
,
OpenAIGPTTokenizer
),
(
GPT2Config
,
GPT2Tokenizer
),
(
TransfoXLConfig
,
TransfoXLTokenizer
),
(
XLNetConfig
,
XLNetTokenizer
),
(
XLMConfig
,
XLMTokenizer
),
(
CTRLConfig
,
CTRLTokenizer
),
]
)
class
AutoTokenizer
(
object
):
r
""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
that will be instantiated as one of the tokenizer classes of the library
...
...
@@ -154,36 +176,13 @@ class AutoTokenizer(object):
if
"bert-base-japanese"
in
pretrained_model_name_or_path
:
return
BertJapaneseTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
if
isinstance
(
config
,
T5Config
):
return
T5Tokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
AlbertConfig
):
return
AlbertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
CamembertConfig
):
return
CamembertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
XLMRobertaConfig
):
return
XLMRobertaTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
BertConfig
):
return
BertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
OpenAIGPTConfig
):
return
OpenAIGPTTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
GPT2Config
):
return
GPT2Tokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
TransfoXLConfig
):
return
TransfoXLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
XLMConfig
):
return
XLMTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
isinstance
(
config
,
CTRLConfig
):
return
CTRLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
for
config_class
,
tokenizer_class
in
TOKENIZER_MAPPING
.
items
():
if
isinstance
(
config
,
config_class
):
return
tokenizer_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'"
.
format
(
pretrained_model_name_or_path
"Unrecognized configuration class {} to build an AutoTokenizer.
\n
"
"Model type should be one of {}."
.
format
(
config
.
__class__
,
", "
.
join
(
c
.
__name__
for
c
in
MODEL_MAPPING
.
keys
())
)
)
tests/test_configuration_auto.py
View file @
03046285
...
...
@@ -51,4 +51,4 @@ class AutoConfigTest(unittest.TestCase):
# no key string should be included in a later key string (typical failure case)
keys
=
list
(
CONFIG_MAPPING
.
keys
())
for
i
,
key
in
enumerate
(
keys
):
self
.
assertFalse
(
any
(
key
in
later_key
for
later_key
in
keys
[
i
+
1
:]))
self
.
assertFalse
(
any
(
key
in
later_key
for
later_key
in
keys
[
i
+
1
:]))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment