Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7d83655d
Unverified
Commit
7d83655d
authored
Oct 05, 2021
by
Sylvain Gugger
Committed by
GitHub
Oct 05, 2021
Browse files
Autodocument the list of ONNX-supported models (#13884)
parent
36fc4016
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
66 additions
and
12 deletions
+66
-12
docs/source/serialization.rst
docs/source/serialization.rst
+7
-1
utils/check_table.py
utils/check_table.py
+59
-11
No files found.
docs/source/serialization.rst
View file @
7d83655d
...
@@ -37,12 +37,18 @@ architectures, and are made to be easily extendable to other architectures.
...
@@ -37,12 +37,18 @@ architectures, and are made to be easily extendable to other architectures.
Ready-made configurations include the following models:
..
This
table
is
automatically
generated
by
make
style
,
do
not
fill
manually
!
-
ALBERT
-
ALBERT
-
BART
-
BART
-
BERT
-
BERT
-
DistilBERT
-
DistilBERT
-
GPT
-
2
-
GPT
Neo
-
LayoutLM
-
LayoutLM
-
Longformer
-
mBART
-
OpenAI
GPT
-
2
-
RoBERTa
-
RoBERTa
-
T5
-
T5
-
XLM
-
RoBERTa
-
XLM
-
RoBERTa
...
...
utils/check_table.py
View file @
7d83655d
...
@@ -62,6 +62,15 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
...
@@ -62,6 +62,15 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
_re_pt_models
=
re
.
compile
(
r
"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)"
)
_re_pt_models
=
re
.
compile
(
r
"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)"
)
# This is to make sure the transformers module imported is the one in the repo.
spec
=
importlib
.
util
.
spec_from_file_location
(
"transformers"
,
os
.
path
.
join
(
TRANSFORMERS_PATH
,
"__init__.py"
),
submodule_search_locations
=
[
TRANSFORMERS_PATH
],
)
transformers_module
=
spec
.
loader
.
load_module
()
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
def
camel_case_split
(
identifier
):
def
camel_case_split
(
identifier
):
"Split a camelcased `identifier` into words."
"Split a camelcased `identifier` into words."
...
@@ -78,19 +87,11 @@ def _center_text(text, width):
...
@@ -78,19 +87,11 @@ def _center_text(text, width):
def
get_model_table_from_auto_modules
():
def
get_model_table_from_auto_modules
():
"""Generates an up-to-date model table from the content of the auto modules."""
"""Generates an up-to-date model table from the content of the auto modules."""
# This is to make sure the transformers module imported is the one in the repo.
spec
=
importlib
.
util
.
spec_from_file_location
(
"transformers"
,
os
.
path
.
join
(
TRANSFORMERS_PATH
,
"__init__.py"
),
submodule_search_locations
=
[
TRANSFORMERS_PATH
],
)
transformers
=
spec
.
loader
.
load_module
()
# Dictionary model names to config.
# Dictionary model names to config.
config_maping_names
=
transformers
.
models
.
auto
.
configuration_auto
.
CONFIG_MAPPING_NAMES
config_maping_names
=
transformers
_module
.
models
.
auto
.
configuration_auto
.
CONFIG_MAPPING_NAMES
model_name_to_config
=
{
model_name_to_config
=
{
name
:
config_maping_names
[
code
]
name
:
config_maping_names
[
code
]
for
code
,
name
in
transformers
.
MODEL_NAMES_MAPPING
.
items
()
for
code
,
name
in
transformers
_module
.
MODEL_NAMES_MAPPING
.
items
()
if
code
in
config_maping_names
if
code
in
config_maping_names
}
}
model_name_to_prefix
=
{
name
:
config
.
replace
(
"Config"
,
""
)
for
name
,
config
in
model_name_to_config
.
items
()}
model_name_to_prefix
=
{
name
:
config
.
replace
(
"Config"
,
""
)
for
name
,
config
in
model_name_to_config
.
items
()}
...
@@ -103,7 +104,7 @@ def get_model_table_from_auto_modules():
...
@@ -103,7 +104,7 @@ def get_model_table_from_auto_modules():
flax_models
=
collections
.
defaultdict
(
bool
)
flax_models
=
collections
.
defaultdict
(
bool
)
# Let's lookup through all transformers object (once).
# Let's lookup through all transformers object (once).
for
attr_name
in
dir
(
transformers
):
for
attr_name
in
dir
(
transformers
_module
):
lookup_dict
=
None
lookup_dict
=
None
if
attr_name
.
endswith
(
"Tokenizer"
):
if
attr_name
.
endswith
(
"Tokenizer"
):
lookup_dict
=
slow_tokenizers
lookup_dict
=
slow_tokenizers
...
@@ -178,9 +179,56 @@ def check_model_table(overwrite=False):
...
@@ -178,9 +179,56 @@ def check_model_table(overwrite=False):
)
)
def
has_onnx
(
model_type
):
"""
Returns whether `model_type` is supported by ONNX (by checking if there is an ONNX config) or not.
"""
config_mapping
=
transformers_module
.
models
.
auto
.
configuration_auto
.
CONFIG_MAPPING
if
model_type
not
in
config_mapping
:
return
False
config
=
config_mapping
[
model_type
]
config_module
=
config
.
__module__
module
=
transformers_module
for
part
in
config_module
.
split
(
"."
)[
1
:]:
module
=
getattr
(
module
,
part
)
config_name
=
config
.
__name__
onnx_config_name
=
config_name
.
replace
(
"Config"
,
"OnnxConfig"
)
return
hasattr
(
module
,
onnx_config_name
)
def
get_onnx_model_list
():
"""
Return the list of models supporting ONNX.
"""
config_mapping
=
transformers_module
.
models
.
auto
.
configuration_auto
.
CONFIG_MAPPING
model_names
=
config_mapping
=
transformers_module
.
models
.
auto
.
configuration_auto
.
MODEL_NAMES_MAPPING
onnx_model_types
=
[
model_type
for
model_type
in
config_mapping
.
keys
()
if
has_onnx
(
model_type
)]
onnx_model_names
=
[
model_names
[
model_type
]
for
model_type
in
onnx_model_types
]
onnx_model_names
.
sort
(
key
=
lambda
x
:
x
.
upper
())
return
"
\n
"
.
join
([
f
"-
{
name
}
"
for
name
in
onnx_model_names
])
+
"
\n
"
def
check_onnx_model_list
(
overwrite
=
False
):
"""Check the model list in the serialization.rst is consistent with the state of the lib and maybe `overwrite`."""
current_list
,
start_index
,
end_index
,
lines
=
_find_text_in_file
(
filename
=
os
.
path
.
join
(
PATH_TO_DOCS
,
"serialization.rst"
),
start_prompt
=
" This table is automatically generated by make style, do not fill manually!"
,
end_prompt
=
"This conversion is handled with the PyTorch version of models "
,
)
new_list
=
get_onnx_model_list
()
if
current_list
!=
new_list
:
if
overwrite
:
with
open
(
os
.
path
.
join
(
PATH_TO_DOCS
,
"serialization.rst"
),
"w"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
f
.
writelines
(
lines
[:
start_index
]
+
[
new_list
]
+
lines
[
end_index
:])
else
:
raise
ValueError
(
"The list of ONNX-supported models needs an update. Run `make fix-copies` to fix this."
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--fix_and_overwrite"
,
action
=
"store_true"
,
help
=
"Whether to fix inconsistencies."
)
parser
.
add_argument
(
"--fix_and_overwrite"
,
action
=
"store_true"
,
help
=
"Whether to fix inconsistencies."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
check_model_table
(
args
.
fix_and_overwrite
)
check_model_table
(
args
.
fix_and_overwrite
)
check_onnx_model_list
(
args
.
fix_and_overwrite
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment