Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2ae7388e
Unverified
Commit
2ae7388e
authored
Dec 07, 2020
by
Lysandre Debut
Committed by
GitHub
Dec 07, 2020
Browse files
Check table as independent script (#8976)
parent
00aa9dbc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
188 additions
and
127 deletions
+188
-127
.circleci/config.yml
.circleci/config.yml
+1
-0
Makefile
Makefile
+2
-0
utils/check_copies.py
utils/check_copies.py
+0
-127
utils/check_table.py
utils/check_table.py
+185
-0
No files found.
.circleci/config.yml
View file @
2ae7388e
...
@@ -381,6 +381,7 @@ jobs:
...
@@ -381,6 +381,7 @@ jobs:
-
run
:
flake8 examples tests src utils
-
run
:
flake8 examples tests src utils
-
run
:
python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
-
run
:
python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
-
run
:
python utils/check_copies.py
-
run
:
python utils/check_copies.py
-
run
:
python utils/check_table.py
-
run
:
python utils/check_dummies.py
-
run
:
python utils/check_dummies.py
-
run
:
python utils/check_repo.py
-
run
:
python utils/check_repo.py
...
...
Makefile
View file @
2ae7388e
...
@@ -23,6 +23,7 @@ deps_table_update:
...
@@ -23,6 +23,7 @@ deps_table_update:
extra_quality_checks
:
deps_table_update
extra_quality_checks
:
deps_table_update
python utils/check_copies.py
python utils/check_copies.py
python utils/check_table.py
python utils/check_dummies.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_repo.py
python utils/style_doc.py src/transformers docs/source
--max_len
119
python utils/style_doc.py src/transformers docs/source
--max_len
119
...
@@ -50,6 +51,7 @@ fixup: modified_only_fixup extra_quality_checks
...
@@ -50,6 +51,7 @@ fixup: modified_only_fixup extra_quality_checks
fix-copies
:
fix-copies
:
python utils/check_copies.py
--fix_and_overwrite
python utils/check_copies.py
--fix_and_overwrite
python utils/check_table.py
--fix_and_overwrite
python utils/check_dummies.py
--fix_and_overwrite
python utils/check_dummies.py
--fix_and_overwrite
# Run tests for the library
# Run tests for the library
...
...
utils/check_copies.py
View file @
2ae7388e
...
@@ -14,9 +14,7 @@
...
@@ -14,9 +14,7 @@
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
collections
import
glob
import
glob
import
importlib.util
import
os
import
os
import
re
import
re
import
tempfile
import
tempfile
...
@@ -299,134 +297,9 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
...
@@ -299,134 +297,9 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
)
)
# Add here suffixes that are used to identify models, separated by |
ALLOWED_MODEL_SUFFIXES
=
"Model|Encoder|Decoder|ForConditionalGeneration"
# Regexes that match TF/Flax/PT model names.
_re_tf_models
=
re
.
compile
(
r
"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)"
)
_re_flax_models
=
re
.
compile
(
r
"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)"
)
# Will match any TF or Flax model too so need to be in an else branch after the two previous regexes.
_re_pt_models
=
re
.
compile
(
r
"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)"
)
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
def
camel_case_split
(
identifier
):
"Split a camelcased `identifier` into words."
matches
=
re
.
finditer
(
".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)"
,
identifier
)
return
[
m
.
group
(
0
)
for
m
in
matches
]
def
_center_text
(
text
,
width
):
text_length
=
2
if
text
==
"✅"
or
text
==
"❌"
else
len
(
text
)
left_indent
=
(
width
-
text_length
)
//
2
right_indent
=
width
-
text_length
-
left_indent
return
" "
*
left_indent
+
text
+
" "
*
right_indent
def
get_model_table_from_auto_modules
():
"""Generates an up-to-date model table from the content of the auto modules."""
# This is to make sure the transformers module imported is the one in the repo.
spec
=
importlib
.
util
.
spec_from_file_location
(
"transformers"
,
os
.
path
.
join
(
TRANSFORMERS_PATH
,
"__init__.py"
),
submodule_search_locations
=
[
TRANSFORMERS_PATH
],
)
transformers
=
spec
.
loader
.
load_module
()
# Dictionary model names to config.
model_name_to_config
=
{
name
:
transformers
.
CONFIG_MAPPING
[
code
]
for
code
,
name
in
transformers
.
MODEL_NAMES_MAPPING
.
items
()
}
model_name_to_prefix
=
{
name
:
config
.
__name__
.
replace
(
"Config"
,
""
)
for
name
,
config
in
model_name_to_config
.
items
()
}
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
slow_tokenizers
=
collections
.
defaultdict
(
bool
)
fast_tokenizers
=
collections
.
defaultdict
(
bool
)
pt_models
=
collections
.
defaultdict
(
bool
)
tf_models
=
collections
.
defaultdict
(
bool
)
flax_models
=
collections
.
defaultdict
(
bool
)
# Let's lookup through all transformers object (once).
for
attr_name
in
dir
(
transformers
):
lookup_dict
=
None
if
attr_name
.
endswith
(
"Tokenizer"
):
lookup_dict
=
slow_tokenizers
attr_name
=
attr_name
[:
-
9
]
elif
attr_name
.
endswith
(
"TokenizerFast"
):
lookup_dict
=
fast_tokenizers
attr_name
=
attr_name
[:
-
13
]
elif
_re_tf_models
.
match
(
attr_name
)
is
not
None
:
lookup_dict
=
tf_models
attr_name
=
_re_tf_models
.
match
(
attr_name
).
groups
()[
0
]
elif
_re_flax_models
.
match
(
attr_name
)
is
not
None
:
lookup_dict
=
flax_models
attr_name
=
_re_flax_models
.
match
(
attr_name
).
groups
()[
0
]
elif
_re_pt_models
.
match
(
attr_name
)
is
not
None
:
lookup_dict
=
pt_models
attr_name
=
_re_pt_models
.
match
(
attr_name
).
groups
()[
0
]
if
lookup_dict
is
not
None
:
while
len
(
attr_name
)
>
0
:
if
attr_name
in
model_name_to_prefix
.
values
():
lookup_dict
[
attr_name
]
=
True
break
# Try again after removing the last word in the name
attr_name
=
""
.
join
(
camel_case_split
(
attr_name
)[:
-
1
])
# Let's build that table!
model_names
=
list
(
model_name_to_config
.
keys
())
model_names
.
sort
()
columns
=
[
"Model"
,
"Tokenizer slow"
,
"Tokenizer fast"
,
"PyTorch support"
,
"TensorFlow support"
,
"Flax Support"
]
# We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
widths
=
[
len
(
c
)
+
2
for
c
in
columns
]
widths
[
0
]
=
max
([
len
(
name
)
for
name
in
model_names
])
+
2
# Rst table per se
table
=
".. rst-class:: center-aligned-table
\n\n
"
table
+=
"+"
+
"+"
.
join
([
"-"
*
w
for
w
in
widths
])
+
"+
\n
"
table
+=
"|"
+
"|"
.
join
([
_center_text
(
c
,
w
)
for
c
,
w
in
zip
(
columns
,
widths
)])
+
"|
\n
"
table
+=
"+"
+
"+"
.
join
([
"="
*
w
for
w
in
widths
])
+
"+
\n
"
check
=
{
True
:
"✅"
,
False
:
"❌"
}
for
name
in
model_names
:
prefix
=
model_name_to_prefix
[
name
]
line
=
[
name
,
check
[
slow_tokenizers
[
prefix
]],
check
[
fast_tokenizers
[
prefix
]],
check
[
pt_models
[
prefix
]],
check
[
tf_models
[
prefix
]],
check
[
flax_models
[
prefix
]],
]
table
+=
"|"
+
"|"
.
join
([
_center_text
(
l
,
w
)
for
l
,
w
in
zip
(
line
,
widths
)])
+
"|
\n
"
table
+=
"+"
+
"+"
.
join
([
"-"
*
w
for
w
in
widths
])
+
"+
\n
"
return
table
def
check_model_table
(
overwrite
=
False
):
""" Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`. """
current_table
,
start_index
,
end_index
,
lines
=
_find_text_in_file
(
filename
=
os
.
path
.
join
(
PATH_TO_DOCS
,
"index.rst"
),
start_prompt
=
" This table is updated automatically from the auto module"
,
end_prompt
=
".. toctree::"
,
)
new_table
=
get_model_table_from_auto_modules
()
if
current_table
!=
new_table
:
if
overwrite
:
with
open
(
os
.
path
.
join
(
PATH_TO_DOCS
,
"index.rst"
),
"w"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
f
.
writelines
(
lines
[:
start_index
]
+
[
new_table
]
+
lines
[
end_index
:])
else
:
raise
ValueError
(
"The model table in the `index.rst` has not been updated. Run `make fix-copies` to fix this."
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--fix_and_overwrite"
,
action
=
"store_true"
,
help
=
"Whether to fix inconsistencies."
)
parser
.
add_argument
(
"--fix_and_overwrite"
,
action
=
"store_true"
,
help
=
"Whether to fix inconsistencies."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
check_copies
(
args
.
fix_and_overwrite
)
check_copies
(
args
.
fix_and_overwrite
)
check_model_table
(
args
.
fix_and_overwrite
)
utils/check_table.py
0 → 100644
View file @
2ae7388e
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
collections
import
importlib.util
import
os
import
re
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_table.py
# Folder holding the library sources; its `__init__.py` is the module loaded to build the model table.
TRANSFORMERS_PATH = "src/transformers"
# Folder holding the documentation sources; the `index.rst` checked by this script lives directly inside it.
PATH_TO_DOCS = "docs/source"
# Root of the repository (relative to the expected working directory).
# NOTE(review): not referenced by the code visible in this file - confirm before removing.
REPO_PATH = "."
def _find_text_in_file(filename, start_prompt, end_prompt):
    """
    Find the text in `filename` between a line beginning with `start_prompt` and the next line beginning with
    `end_prompt`, ignoring the blank lines at the beginning and at the end of that region (blank lines in the middle
    are kept).

    Args:
        filename: Path of the file to read (decoded as UTF-8, newlines left untouched).
        start_prompt: Prefix of the line that precedes the region of interest.
        end_prompt: Prefix of the first line, after the start line, that follows the region of interest.

    Returns:
        A tuple `(text, start_index, end_index, lines)` where `lines` is the list of all the lines of the file and
        `text` is `"".join(lines[start_index:end_index])`. If the region only contains blank lines, `text` is empty
        and `start_index == end_index` is the index of the `end_prompt` line.

    Raises:
        ValueError: If no line starts with `start_prompt`, or no later line starts with `end_prompt`.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Find the start prompt; the region begins on the line right after it.
    start_index = next((i for i, line in enumerate(lines) if line.startswith(start_prompt)), None)
    if start_index is None:
        raise ValueError(f"Could not find a line starting with {start_prompt!r} in {filename}.")
    start_index += 1

    # Find the end prompt, only looking after the start prompt. `end_index` is exclusive.
    end_index = next((i for i in range(start_index, len(lines)) if lines[i].startswith(end_prompt)), None)
    if end_index is None:
        raise ValueError(
            f"Could not find a line starting with {end_prompt!r} after the line starting with {start_prompt!r} in "
            f"{filename}."
        )

    # Trim blank lines (empty or a lone newline) on both sides, never letting the two indices cross.
    while start_index < end_index and len(lines[start_index]) <= 1:
        start_index += 1
    while end_index > start_index and len(lines[end_index - 1]) <= 1:
        end_index -= 1

    return "".join(lines[start_index:end_index]), start_index, end_index, lines
# Add here suffixes that are used to identify models, separated by |
ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
# Regexes that match TF/Flax/PT model names. They are built from ALLOWED_MODEL_SUFFIXES so that a suffix added above
# is taken into account by all three of them. The captured group is the model name without framework prefix/suffix.
_re_tf_models = re.compile(r"TF(.*)(?:" + ALLOWED_MODEL_SUFFIXES + r")")
_re_flax_models = re.compile(r"Flax(.*)(?:" + ALLOWED_MODEL_SUFFIXES + r")")
# Will match any TF or Flax model too so need to be in an else branch after the two previous regexes.
_re_pt_models = re.compile(r"(.*)(?:" + ALLOWED_MODEL_SUFFIXES + r")")
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
def camel_case_split(identifier):
    """
    Split a camel-cased `identifier` into the list of its words.

    A new word starts where a lowercase letter is followed by an uppercase one, or where an uppercase letter is
    followed by an uppercase letter itself followed by a lowercase one (so a run of capitals stays in one word).
    An empty `identifier` gives an empty list.
    """
    word_pattern = ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)"
    # The pattern has no capturing group, so `findall` yields the full text of each match.
    return re.findall(word_pattern, identifier)
def _center_text(text, width):
    """
    Pad `text` with spaces on both sides so that it is centered in a cell of `width` columns.

    When the padding cannot be split evenly, the extra space goes to the right. If `text` is wider than `width`, it
    is returned without any padding.
    """
    # The two check marks are counted as 2 columns instead of their `len` (presumably their displayed width).
    if text in ("✅", "❌"):
        columns_used = 2
    else:
        columns_used = len(text)

    padding = width - columns_used
    before = padding // 2
    after = padding - before
    return "".join([" " * before, text, " " * after])
def _load_repo_transformers():
    """
    Import the `transformers` package located at `TRANSFORMERS_PATH` and return it.

    The module is built from an explicit file location to make sure the package inspected is the one in the repo.
    `Loader.load_module()` is deprecated, so the module is created with `module_from_spec` and run with
    `exec_module` instead.
    """
    import sys

    spec = importlib.util.spec_from_file_location(
        "transformers",
        os.path.join(TRANSFORMERS_PATH, "__init__.py"),
        submodule_search_locations=[TRANSFORMERS_PATH],
    )
    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(spec.name)
    # The module has to be registered before it is executed, otherwise the imports the package does on itself
    # would not find it.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialized module behind.
        if previous is None:
            sys.modules.pop(spec.name, None)
        else:
            sys.modules[spec.name] = previous
        raise
    # The package may have replaced its own entry in `sys.modules` while executing.
    return sys.modules.get(spec.name, module)


def get_model_table_from_auto_modules():
    """
    Generates an up-to-date model table from the content of the auto modules.

    Returns:
        The rst source of the table, as a single string: one row per model name of `MODEL_NAMES_MAPPING` (sorted),
        with a check mark or a cross for the slow tokenizer, the fast tokenizer and the PyTorch, TensorFlow and Flax
        models.
    """
    # This is to make sure the transformers module imported is the one in the repo.
    transformers = _load_repo_transformers()

    # Dictionary model names to config.
    model_name_to_config = {
        name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
    }
    model_name_to_prefix = {
        name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items()
    }
    # Built once: membership is tested for every candidate name of every object of the library.
    known_prefixes = set(model_name_to_prefix.values())

    # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
    slow_tokenizers = collections.defaultdict(bool)
    fast_tokenizers = collections.defaultdict(bool)
    pt_models = collections.defaultdict(bool)
    tf_models = collections.defaultdict(bool)
    flax_models = collections.defaultdict(bool)

    # The PT regex also matches TF and Flax names, so it has to be tried last.
    model_lookups = (
        (_re_tf_models, tf_models),
        (_re_flax_models, flax_models),
        (_re_pt_models, pt_models),
    )

    # Let's lookup through all transformers object (once).
    for attr_name in dir(transformers):
        lookup_dict = None
        if attr_name.endswith("Tokenizer"):
            lookup_dict = slow_tokenizers
            attr_name = attr_name[:-9]
        elif attr_name.endswith("TokenizerFast"):
            lookup_dict = fast_tokenizers
            attr_name = attr_name[:-13]
        else:
            for regex, candidate_dict in model_lookups:
                match = regex.match(attr_name)
                if match is not None:
                    lookup_dict = candidate_dict
                    attr_name = match.groups()[0]
                    break

        if lookup_dict is None:
            continue

        while len(attr_name) > 0:
            if attr_name in known_prefixes:
                lookup_dict[attr_name] = True
                break
            # Try again after removing the last word in the name
            attr_name = "".join(camel_case_split(attr_name)[:-1])

    # Let's build that table!
    model_names = sorted(model_name_to_config)
    columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
    # We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
    widths = [len(c) + 2 for c in columns]
    # The first column has to fit the longest model name, but also its own header (which also covers the case where
    # there is no model at all).
    widths[0] = max([len(columns[0])] + [len(name) for name in model_names]) + 2

    row_separator = "+" + "+".join(["-" * w for w in widths]) + "+\n"
    header_separator = "+" + "+".join(["=" * w for w in widths]) + "+\n"

    # Rst table per se
    pieces = [
        ".. rst-class:: center-aligned-table\n\n",
        row_separator,
        "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n",
        header_separator,
    ]

    check = {True: "✅", False: "❌"}
    for name in model_names:
        prefix = model_name_to_prefix[name]
        line = [
            name,
            check[slow_tokenizers[prefix]],
            check[fast_tokenizers[prefix]],
            check[pt_models[prefix]],
            check[tf_models[prefix]],
            check[flax_models[prefix]],
        ]
        pieces.append("|" + "|".join([_center_text(cell, w) for cell, w in zip(line, widths)]) + "|\n")
        pieces.append(row_separator)
    return "".join(pieces)
def check_model_table(overwrite=False):
    """
    Compare the model table stored in `index.rst` with the one generated from the library.

    Args:
        overwrite: When the two tables differ, rewrite `index.rst` with the generated table instead of failing.

    Raises:
        ValueError: If the two tables differ and `overwrite` is `False`.
    """
    index_file = os.path.join(PATH_TO_DOCS, "index.rst")
    current_table, start_index, end_index, lines = _find_text_in_file(
        filename=index_file,
        start_prompt=" This table is updated automatically from the auto module",
        end_prompt=".. toctree::",
    )
    new_table = get_model_table_from_auto_modules()

    # Nothing to do when the documentation is already up to date.
    if current_table == new_table:
        return

    if not overwrite:
        raise ValueError(
            "The model table in the `index.rst` has not been updated. Run `make fix-copies` to fix this."
        )

    # Keep everything around the old table and put the generated one in its place.
    updated_lines = lines[:start_index] + [new_table] + lines[end_index:]
    with open(index_file, "w", encoding="utf-8", newline="\n") as f:
        f.writelines(updated_lines)
# Command-line entry point: check the model table, or rewrite it when `--fix_and_overwrite` is passed.
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    options = cli.parse_args()

    check_model_table(overwrite=options.fix_and_overwrite)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment