chenpangpang / transformers · Commits

Commit 03046285, authored Jan 13, 2020 by Julien Chaumond

Map configs to models and tokenizers

parent 1fc855e4
Showing 6 changed files with 322 additions and 458 deletions.
src/transformers/configuration_auto.py    +3    −3
src/transformers/configuration_utils.py   +3    −3
src/transformers/modeling_auto.py         +132  −209
src/transformers/modeling_tf_auto.py      +154  −212
src/transformers/tokenization_auto.py     +29   −30
tests/test_configuration_auto.py          +1    −1
src/transformers/configuration_auto.py

@@ -202,7 +202,7 @@ class AutoConfig:
             return config_class.from_dict(config_dict, **kwargs)

         raise ValueError(
-            "Unrecognized model identifier in {}. Should have a `model_type` key in its config.json, or contain one of {}".format(
-                pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys())
-            )
+            "Unrecognized model in {}. "
+            "Should have a `model_type` key in its config.json, or contain one of the following strings "
+            "in its name: {}".format(pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys()))
         )
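For orientation, this is roughly the lookup behavior the new error message describes: prefer an explicit `model_type` key from config.json, otherwise substring-match the model name against the CONFIG_MAPPING keys. The sketch below is a simplified illustration, not code from this commit; plain strings stand in for the real config classes so it runs on its own.

    from collections import OrderedDict

    # Toy stand-in for CONFIG_MAPPING: model-type strings -> config classes.
    # Order matters, because the fallback is substring matching on the model name.
    CONFIG_MAPPING = OrderedDict([("distilbert", "DistilBertConfig"), ("bert", "BertConfig")])


    def resolve_config(pretrained_model_name_or_path, model_type=None):
        # Preferred path: an explicit `model_type` key from config.json.
        if model_type in CONFIG_MAPPING:
            return CONFIG_MAPPING[model_type]
        # Fallback: the first mapping key contained in the model name wins.
        for pattern, config_class in CONFIG_MAPPING.items():
            if pattern in pretrained_model_name_or_path:
                return config_class
        raise ValueError(
            "Unrecognized model in {}. "
            "Should have a `model_type` key in its config.json, or contain one of the following strings "
            "in its name: {}".format(pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys()))
        )


    print(resolve_config("distilbert-base-uncased"))  # prints: DistilBertConfig
    print(resolve_config("bert-base-cased"))          # prints: BertConfig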
src/transformers/configuration_utils.py

@@ -273,7 +273,7 @@ class PretrainedConfig(object):
         return self.__dict__ == other.__dict__

     def __repr__(self):
-        return str(self.to_json_string())
+        return "{} {}".format(self.__class__.__name__, self.to_json_string())

     def to_dict(self):
         """Serializes this instance to a Python dictionary."""
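The practical effect of the __repr__ change is a more informative printout: the class name now precedes the JSON dump. A minimal sketch with a toy class (not PretrainedConfig itself):

    import json


    class ToyConfig:
        # Toy stand-in for PretrainedConfig, showing only the repr change.
        def __init__(self, hidden_size=768):
            self.hidden_size = hidden_size

        def to_json_string(self):
            return json.dumps(self.__dict__, sort_keys=True)

        def __repr__(self):
            # Before: return str(self.to_json_string())
            # After: prefix the JSON dump with the class name.
            return "{} {}".format(self.__class__.__name__, self.to_json_string())


    print(repr(ToyConfig()))  # prints: ToyConfig {"hidden_size": 768}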
src/transformers/modeling_auto.py

(This diff is collapsed in this view: +132 / −209, not shown.)
src/transformers/modeling_tf_auto.py

(This diff is collapsed in this view: +154 / −212, not shown.)
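The two collapsed diffs are where the commit title mostly plays out: judging by the title and the tokenization_auto.py change below, modeling_auto.py and modeling_tf_auto.py presumably replace long if/elif chains with an OrderedDict from config classes to model classes. The sketch below only illustrates that general pattern; the class names, mapping contents, and helper function are placeholders, not the actual diff.

    from collections import OrderedDict


    # Placeholder classes standing in for real config/model classes.
    class RobertaConfig: ...
    class XLMRobertaConfig(RobertaConfig): ...  # mirrors the library, where XLM-R's config subclasses RoBERTa's
    class RobertaModel: ...
    class XLMRobertaModel: ...


    # Subclasses must be listed before their base classes, because lookup uses isinstance().
    MODEL_MAPPING = OrderedDict(
        [
            (XLMRobertaConfig, XLMRobertaModel),
            (RobertaConfig, RobertaModel),
        ]
    )


    def model_class_for(config):
        # Table-driven dispatch replacing the old if/elif chain.
        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class
        raise ValueError("Unrecognized configuration class {}".format(config.__class__))


    assert model_class_for(XLMRobertaConfig()) is XLMRobertaModel
    assert model_class_for(RobertaConfig()) is RobertaModel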
src/transformers/tokenization_auto.py

@@ -16,6 +16,8 @@
 import logging
+from collections import OrderedDict
+from typing import Dict, Type

 from .configuration_auto import (
     AlbertConfig,
     ...

@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
 from .tokenization_roberta import RobertaTokenizer
 from .tokenization_t5 import T5Tokenizer
 from .tokenization_transfo_xl import TransfoXLTokenizer
+from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
 from .tokenization_xlnet import XLNetTokenizer

@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer

 logger = logging.getLogger(__name__)

+TOKENIZER_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedTokenizer]] = OrderedDict(
+    [
+        (T5Config, T5Tokenizer),
+        (DistilBertConfig, DistilBertTokenizer),
+        (AlbertConfig, AlbertTokenizer),
+        (CamembertConfig, CamembertTokenizer),
+        (XLMRobertaConfig, XLMRobertaTokenizer),
+        (RobertaConfig, RobertaTokenizer),
+        (BertConfig, BertTokenizer),
+        (OpenAIGPTConfig, OpenAIGPTTokenizer),
+        (GPT2Config, GPT2Tokenizer),
+        (TransfoXLConfig, TransfoXLTokenizer),
+        (XLNetConfig, XLNetTokenizer),
+        (XLMConfig, XLMTokenizer),
+        (CTRLConfig, CTRLTokenizer),
+    ]
+)
+
 class AutoTokenizer(object):
     r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
         that will be instantiated as one of the tokenizer classes of the library
     ...

@@ -154,36 +176,13 @@ class AutoTokenizer(object):
         if "bert-base-japanese" in pretrained_model_name_or_path:
             return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

-        if isinstance(config, T5Config):
-            return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, DistilBertConfig):
-            return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, AlbertConfig):
-            return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, CamembertConfig):
-            return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, XLMRobertaConfig):
-            return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, RobertaConfig):
-            return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, BertConfig):
-            return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, OpenAIGPTConfig):
-            return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, GPT2Config):
-            return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, TransfoXLConfig):
-            return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, XLNetConfig):
-            return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, XLMConfig):
-            return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, CTRLConfig):
-            return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        for config_class, tokenizer_class in TOKENIZER_MAPPING.items():
+            if isinstance(config, config_class):
+                return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-            "'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(
-                pretrained_model_name_or_path
+            "Unrecognized configuration class {} to build an AutoTokenizer.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, ", ".join(c.__name__ for c in TOKENIZER_MAPPING.keys())
             )
         )
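From the caller's side nothing changes; the dispatch simply becomes table-driven. A small usage sketch, assuming the transformers package is installed and the usual public checkpoint names:

    from transformers import AutoTokenizer

    # from_pretrained now walks TOKENIZER_MAPPING and matches the loaded config
    # with isinstance(), instead of going through a hard-coded if/elif chain.
    bert_tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # BertTokenizer
    xlmr_tok = AutoTokenizer.from_pretrained("xlm-roberta-base")   # XLMRobertaTokenizer

Because the lookup is isinstance() based and XLMRobertaConfig subclasses RobertaConfig, the XLM-RoBERTa entry has to precede the RoBERTa entry in TOKENIZER_MAPPING, just as it did in the old elif chain; otherwise XLM-R checkpoints would be handed a RobertaTokenizer.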
tests/test_configuration_auto.py

@@ -51,4 +51,4 @@ class AutoConfigTest(unittest.TestCase):
         # no key string should be included in a later key string (typical failure case)
         keys = list(CONFIG_MAPPING.keys())
         for i, key in enumerate(keys):
-            self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
+            self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))

(The removed and added lines render identically in this view; the one-line change is not visible here.)
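For intuition, this is the failure case the test's comment refers to: with substring-based matching, a short key listed before a longer key that contains it would shadow the longer one. A standalone illustration (the key strings are examples; only the ordering property matters):

    def first_match(name, keys):
        # Mirrors the substring fallback: the first key contained in the name wins.
        for key in keys:
            if key in name:
                return key
        return None


    good_order = ["xlm-roberta", "roberta", "bert"]  # no key appears inside a later key
    bad_order = ["roberta", "xlm-roberta"]           # "roberta" appears inside a later key

    assert first_match("xlm-roberta-base", good_order) == "xlm-roberta"  # intended match
    assert first_match("xlm-roberta-base", bad_order) == "roberta"       # shadowed: wrong match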