Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7d83655d
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "660e0b97bd652bd3a0dfd5f847e5cf62502d0469"
Unverified
Commit
7d83655d
authored
Oct 05, 2021
by
Sylvain Gugger
Committed by
GitHub
Oct 05, 2021
Browse files
Autodocument the list of ONNX-supported models (#13884)
parent
36fc4016
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
66 additions
and
12 deletions
+66
-12
docs/source/serialization.rst
docs/source/serialization.rst
+7
-1
utils/check_table.py
utils/check_table.py
+59
-11
No files found.
docs/source/serialization.rst
View file @
7d83655d
...
@@ -37,12 +37,18 @@ architectures, and are made to be easily extendable to other architectures.

Ready-made configurations include the following models:

..
    This table is automatically generated by make style, do not fill manually!

- ALBERT
- BART
- BERT
- DistilBERT
- GPT Neo
- LayoutLM
- Longformer
- mBART
- OpenAI GPT-2
- RoBERTa
- T5
- XLM-RoBERTa
...
...
utils/check_table.py
View file @
7d83655d
...
@@ -62,6 +62,15 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
...
@@ -62,6 +62,15 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
_re_pt_models
=
re
.
compile
(
r
"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)"
)
_re_pt_models
=
re
.
compile
(
r
"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)"
)
# This is to make sure the transformers module imported is the one in the repo.
import sys  # needed to register the package before executing it

spec = importlib.util.spec_from_file_location(
    "transformers",
    os.path.join(TRANSFORMERS_PATH, "__init__.py"),
    submodule_search_locations=[TRANSFORMERS_PATH],
)
transformers_module = importlib.util.module_from_spec(spec)
# `Loader.load_module()` is deprecated (removed in Python 3.12); replicate its
# behavior with `module_from_spec` + `exec_module`. Registering the module in
# `sys.modules` first is required so that submodule/relative imports performed
# while executing `transformers/__init__.py` can resolve their parent package.
sys.modules[spec.name] = transformers_module
spec.loader.exec_module(transformers_module)
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
def
camel_case_split
(
identifier
):
def
camel_case_split
(
identifier
):
"Split a camelcased `identifier` into words."
"Split a camelcased `identifier` into words."
...
@@ -78,19 +87,11 @@ def _center_text(text, width):
...
@@ -78,19 +87,11 @@ def _center_text(text, width):
def
get_model_table_from_auto_modules
():
def
get_model_table_from_auto_modules
():
"""Generates an up-to-date model table from the content of the auto modules."""
"""Generates an up-to-date model table from the content of the auto modules."""
# This is to make sure the transformers module imported is the one in the repo.
spec
=
importlib
.
util
.
spec_from_file_location
(
"transformers"
,
os
.
path
.
join
(
TRANSFORMERS_PATH
,
"__init__.py"
),
submodule_search_locations
=
[
TRANSFORMERS_PATH
],
)
transformers
=
spec
.
loader
.
load_module
()
# Dictionary model names to config.
# Dictionary model names to config.
config_maping_names
=
transformers
.
models
.
auto
.
configuration_auto
.
CONFIG_MAPPING_NAMES
config_maping_names
=
transformers
_module
.
models
.
auto
.
configuration_auto
.
CONFIG_MAPPING_NAMES
model_name_to_config
=
{
model_name_to_config
=
{
name
:
config_maping_names
[
code
]
name
:
config_maping_names
[
code
]
for
code
,
name
in
transformers
.
MODEL_NAMES_MAPPING
.
items
()
for
code
,
name
in
transformers
_module
.
MODEL_NAMES_MAPPING
.
items
()
if
code
in
config_maping_names
if
code
in
config_maping_names
}
}
model_name_to_prefix
=
{
name
:
config
.
replace
(
"Config"
,
""
)
for
name
,
config
in
model_name_to_config
.
items
()}
model_name_to_prefix
=
{
name
:
config
.
replace
(
"Config"
,
""
)
for
name
,
config
in
model_name_to_config
.
items
()}
...
@@ -103,7 +104,7 @@ def get_model_table_from_auto_modules():
...
@@ -103,7 +104,7 @@ def get_model_table_from_auto_modules():
flax_models
=
collections
.
defaultdict
(
bool
)
flax_models
=
collections
.
defaultdict
(
bool
)
# Let's lookup through all transformers object (once).
# Let's lookup through all transformers object (once).
for
attr_name
in
dir
(
transformers
):
for
attr_name
in
dir
(
transformers
_module
):
lookup_dict
=
None
lookup_dict
=
None
if
attr_name
.
endswith
(
"Tokenizer"
):
if
attr_name
.
endswith
(
"Tokenizer"
):
lookup_dict
=
slow_tokenizers
lookup_dict
=
slow_tokenizers
...
@@ -178,9 +179,56 @@ def check_model_table(overwrite=False):
...
@@ -178,9 +179,56 @@ def check_model_table(overwrite=False):
)
)
def has_onnx(model_type):
    """
    Check whether an ONNX config is declared for `model_type`, i.e. whether the
    corresponding model is supported by the ONNX export.
    """
    mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
    if model_type not in mapping:
        # Unknown model type: no config class at all, hence no ONNX config.
        return False
    config_class = mapping[model_type]
    # Walk from the top-level transformers module down to the submodule that
    # declares the config class (the leading "transformers" part is skipped
    # since we start from the module object itself).
    declaring_module = transformers_module
    for attribute in config_class.__module__.split(".")[1:]:
        declaring_module = getattr(declaring_module, attribute)
    # By convention the ONNX config lives next to the regular config and shares
    # its name, with "Config" replaced by "OnnxConfig".
    onnx_config_name = config_class.__name__.replace("Config", "OnnxConfig")
    return hasattr(declaring_module, onnx_config_name)
def get_onnx_model_list():
    """
    Return the list of models supporting ONNX as a string of ``- name`` doc
    lines, sorted case-insensitively, ending with a newline.
    """
    config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
    # Fix: the original chained assignment (`model_names = config_mapping = ...`)
    # silently rebound `config_mapping` to MODEL_NAMES_MAPPING, so the loop below
    # iterated the wrong mapping and only worked because the two share keys.
    model_names = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING
    onnx_model_types = [model_type for model_type in config_mapping.keys() if has_onnx(model_type)]
    onnx_model_names = [model_names[model_type] for model_type in onnx_model_types]
    # Case-insensitive sort so names like "mBART" are ordered with the Bs.
    onnx_model_names.sort(key=lambda x: x.upper())
    return "\n".join([f"- {name}" for name in onnx_model_names]) + "\n"
def check_onnx_model_list(overwrite=False):
    """Check the model list in the serialization.rst is consistent with the state of the lib and maybe `overwrite`."""
    doc_file = os.path.join(PATH_TO_DOCS, "serialization.rst")
    current_list, start_index, end_index, lines = _find_text_in_file(
        filename=doc_file,
        start_prompt="    This table is automatically generated by make style, do not fill manually!",
        end_prompt="This conversion is handled with the PyTorch version of models ",
    )
    new_list = get_onnx_model_list()
    if current_list == new_list:
        # The doc is already in sync with the library, nothing to do.
        return
    if not overwrite:
        raise ValueError("The list of ONNX-supported models needs an update. Run `make fix-copies` to fix this.")
    # Splice the freshly generated list between the start and end prompts.
    with open(doc_file, "w", encoding="utf-8", newline="\n") as f:
        f.writelines(lines[:start_index] + [new_list] + lines[end_index:])
if __name__ == "__main__":
    # CLI entry point: verify the auto-generated doc tables, optionally fixing them.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies."
    )
    cli_args = arg_parser.parse_args()

    check_model_table(cli_args.fix_and_overwrite)
    check_onnx_model_list(cli_args.fix_and_overwrite)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment