Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
32dbb2d9
Unverified
Commit
32dbb2d9
authored
Apr 26, 2021
by
Patrick von Platen
Committed by
GitHub
Apr 26, 2021
Browse files
make style (#11442)
parent
04ab2ca6
Changes
105
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
25 deletions
+25
-25
utils/check_copies.py
utils/check_copies.py
+5
-5
utils/check_dummies.py
utils/check_dummies.py
+4
-4
utils/check_repo.py
utils/check_repo.py
+14
-14
utils/check_table.py
utils/check_table.py
+1
-1
utils/style_doc.py
utils/style_doc.py
+1
-1
No files found.
utils/check_copies.py
View file @
32dbb2d9
...
@@ -33,7 +33,7 @@ def _should_continue(line, indent):
...
@@ -33,7 +33,7 @@ def _should_continue(line, indent):
def
find_code_in_transformers
(
object_name
):
def
find_code_in_transformers
(
object_name
):
"""
Find and return the code source code of `object_name`."""
"""Find and return the code source code of `object_name`."""
parts
=
object_name
.
split
(
"."
)
parts
=
object_name
.
split
(
"."
)
i
=
0
i
=
0
...
@@ -193,7 +193,7 @@ def check_copies(overwrite: bool = False):
...
@@ -193,7 +193,7 @@ def check_copies(overwrite: bool = False):
def
get_model_list
():
def
get_model_list
():
"""
Extracts the model list from the README.
"""
"""Extracts the model list from the README."""
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
_start_prompt
=
"🤗 Transformers currently provides the following architectures"
_start_prompt
=
"🤗 Transformers currently provides the following architectures"
_end_prompt
=
"1. Want to contribute a new model?"
_end_prompt
=
"1. Want to contribute a new model?"
...
@@ -224,7 +224,7 @@ def get_model_list():
...
@@ -224,7 +224,7 @@ def get_model_list():
def
split_long_line_with_indent
(
line
,
max_per_line
,
indent
):
def
split_long_line_with_indent
(
line
,
max_per_line
,
indent
):
"""
Split the `line` so that it doesn't go over `max_per_line` and adds `indent` to new lines.
"""
"""Split the `line` so that it doesn't go over `max_per_line` and adds `indent` to new lines."""
words
=
line
.
split
(
" "
)
words
=
line
.
split
(
" "
)
lines
=
[]
lines
=
[]
current_line
=
words
[
0
]
current_line
=
words
[
0
]
...
@@ -239,7 +239,7 @@ def split_long_line_with_indent(line, max_per_line, indent):
...
@@ -239,7 +239,7 @@ def split_long_line_with_indent(line, max_per_line, indent):
def
convert_to_rst
(
model_list
,
max_per_line
=
None
):
def
convert_to_rst
(
model_list
,
max_per_line
=
None
):
"""
Convert `model_list` to rst format.
"""
"""Convert `model_list` to rst format."""
# Convert **[description](link)** to `description <link>`__
# Convert **[description](link)** to `description <link>`__
def
_rep_link
(
match
):
def
_rep_link
(
match
):
title
,
link
=
match
.
groups
()
title
,
link
=
match
.
groups
()
...
@@ -298,7 +298,7 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
...
@@ -298,7 +298,7 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
def
check_model_list_copy
(
overwrite
=
False
,
max_per_line
=
119
):
def
check_model_list_copy
(
overwrite
=
False
,
max_per_line
=
119
):
"""
Check the model lists in the README and index.rst are consistent and maybe `overwrite`.
"""
"""Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
rst_list
,
start_index
,
end_index
,
lines
=
_find_text_in_file
(
rst_list
,
start_index
,
end_index
,
lines
=
_find_text_in_file
(
filename
=
os
.
path
.
join
(
PATH_TO_DOCS
,
"index.rst"
),
filename
=
os
.
path
.
join
(
PATH_TO_DOCS
,
"index.rst"
),
start_prompt
=
" This list is updated automatically from the README"
,
start_prompt
=
" This list is updated automatically from the README"
,
...
...
utils/check_dummies.py
View file @
32dbb2d9
...
@@ -65,7 +65,7 @@ def find_backend(line):
...
@@ -65,7 +65,7 @@ def find_backend(line):
def
read_init
():
def
read_init
():
"""
Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects.
"""
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
with
open
(
os
.
path
.
join
(
PATH_TO_TRANSFORMERS
,
"__init__.py"
),
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
with
open
(
os
.
path
.
join
(
PATH_TO_TRANSFORMERS
,
"__init__.py"
),
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
...
@@ -101,7 +101,7 @@ def read_init():
...
@@ -101,7 +101,7 @@ def read_init():
def
create_dummy_object
(
name
,
backend_name
):
def
create_dummy_object
(
name
,
backend_name
):
"""
Create the code for the dummy object corresponding to `name`."""
"""Create the code for the dummy object corresponding to `name`."""
_pretrained
=
[
_pretrained
=
[
"Config"
"ForCausalLM"
,
"Config"
"ForCausalLM"
,
"ForConditionalGeneration"
,
"ForConditionalGeneration"
,
...
@@ -130,7 +130,7 @@ def create_dummy_object(name, backend_name):
...
@@ -130,7 +130,7 @@ def create_dummy_object(name, backend_name):
def
create_dummy_files
():
def
create_dummy_files
():
"""
Create the content of the dummy files.
"""
"""Create the content of the dummy files."""
backend_specific_objects
=
read_init
()
backend_specific_objects
=
read_init
()
# For special correspondence backend to module name as used in the function requires_modulename
# For special correspondence backend to module name as used in the function requires_modulename
dummy_files
=
{}
dummy_files
=
{}
...
@@ -146,7 +146,7 @@ def create_dummy_files():
...
@@ -146,7 +146,7 @@ def create_dummy_files():
def
check_dummies
(
overwrite
=
False
):
def
check_dummies
(
overwrite
=
False
):
"""
Check if the dummy files are up to date and maybe `overwrite` with the right content.
"""
"""Check if the dummy files are up to date and maybe `overwrite` with the right content."""
dummy_files
=
create_dummy_files
()
dummy_files
=
create_dummy_files
()
# For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
# For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
short_names
=
{
"torch"
:
"pt"
}
short_names
=
{
"torch"
:
"pt"
}
...
...
utils/check_repo.py
View file @
32dbb2d9
...
@@ -119,7 +119,7 @@ transformers = spec.loader.load_module()
...
@@ -119,7 +119,7 @@ transformers = spec.loader.load_module()
# If some modeling modules should be ignored for all checks, they should be added in the nested list
# If some modeling modules should be ignored for all checks, they should be added in the nested list
# _ignore_modules of this function.
# _ignore_modules of this function.
def
get_model_modules
():
def
get_model_modules
():
"""
Get the model modules inside the transformers library.
"""
"""Get the model modules inside the transformers library."""
_ignore_modules
=
[
_ignore_modules
=
[
"modeling_auto"
,
"modeling_auto"
,
"modeling_encoder_decoder"
,
"modeling_encoder_decoder"
,
...
@@ -151,7 +151,7 @@ def get_model_modules():
...
@@ -151,7 +151,7 @@ def get_model_modules():
def
get_models
(
module
):
def
get_models
(
module
):
"""
Get the objects in module that are models."""
"""Get the objects in module that are models."""
models
=
[]
models
=
[]
model_classes
=
(
transformers
.
PreTrainedModel
,
transformers
.
TFPreTrainedModel
,
transformers
.
FlaxPreTrainedModel
)
model_classes
=
(
transformers
.
PreTrainedModel
,
transformers
.
TFPreTrainedModel
,
transformers
.
FlaxPreTrainedModel
)
for
attr_name
in
dir
(
module
):
for
attr_name
in
dir
(
module
):
...
@@ -166,7 +166,7 @@ def get_models(module):
...
@@ -166,7 +166,7 @@ def get_models(module):
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
# nested list _ignore_files of this function.
def
get_model_test_files
():
def
get_model_test_files
():
"""
Get the model test files."""
"""Get the model test files."""
_ignore_files
=
[
_ignore_files
=
[
"test_modeling_common"
,
"test_modeling_common"
,
"test_modeling_encoder_decoder"
,
"test_modeling_encoder_decoder"
,
...
@@ -187,7 +187,7 @@ def get_model_test_files():
...
@@ -187,7 +187,7 @@ def get_model_test_files():
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# for the all_model_classes variable.
# for the all_model_classes variable.
def
find_tested_models
(
test_file
):
def
find_tested_models
(
test_file
):
"""
Parse the content of test_file to detect what's in all_model_classes"""
"""Parse the content of test_file to detect what's in all_model_classes"""
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
with
open
(
os
.
path
.
join
(
PATH_TO_TESTS
,
test_file
),
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
with
open
(
os
.
path
.
join
(
PATH_TO_TESTS
,
test_file
),
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
content
=
f
.
read
()
content
=
f
.
read
()
...
@@ -205,7 +205,7 @@ def find_tested_models(test_file):
...
@@ -205,7 +205,7 @@ def find_tested_models(test_file):
def
check_models_are_tested
(
module
,
test_file
):
def
check_models_are_tested
(
module
,
test_file
):
"""
Check models defined in module are tested in test_file."""
"""Check models defined in module are tested in test_file."""
defined_models
=
get_models
(
module
)
defined_models
=
get_models
(
module
)
tested_models
=
find_tested_models
(
test_file
)
tested_models
=
find_tested_models
(
test_file
)
if
tested_models
is
None
:
if
tested_models
is
None
:
...
@@ -229,7 +229,7 @@ def check_models_are_tested(module, test_file):
...
@@ -229,7 +229,7 @@ def check_models_are_tested(module, test_file):
def
check_all_models_are_tested
():
def
check_all_models_are_tested
():
"""
Check all models are properly tested."""
"""Check all models are properly tested."""
modules
=
get_model_modules
()
modules
=
get_model_modules
()
test_files
=
get_model_test_files
()
test_files
=
get_model_test_files
()
failures
=
[]
failures
=
[]
...
@@ -245,7 +245,7 @@ def check_all_models_are_tested():
...
@@ -245,7 +245,7 @@ def check_all_models_are_tested():
def
get_all_auto_configured_models
():
def
get_all_auto_configured_models
():
"""
Return the list of all models in at least one auto class."""
"""Return the list of all models in at least one auto class."""
result
=
set
()
# To avoid duplicates we concatenate all model classes in a set.
result
=
set
()
# To avoid duplicates we concatenate all model classes in a set.
for
attr_name
in
dir
(
transformers
.
models
.
auto
.
modeling_auto
):
for
attr_name
in
dir
(
transformers
.
models
.
auto
.
modeling_auto
):
if
attr_name
.
startswith
(
"MODEL_"
)
and
attr_name
.
endswith
(
"MAPPING"
):
if
attr_name
.
startswith
(
"MODEL_"
)
and
attr_name
.
endswith
(
"MAPPING"
):
...
@@ -271,7 +271,7 @@ def ignore_unautoclassed(model_name):
...
@@ -271,7 +271,7 @@ def ignore_unautoclassed(model_name):
def
check_models_are_auto_configured
(
module
,
all_auto_models
):
def
check_models_are_auto_configured
(
module
,
all_auto_models
):
"""
Check models defined in module are each in an auto class."""
"""Check models defined in module are each in an auto class."""
defined_models
=
get_models
(
module
)
defined_models
=
get_models
(
module
)
failures
=
[]
failures
=
[]
for
model_name
,
_
in
defined_models
:
for
model_name
,
_
in
defined_models
:
...
@@ -285,7 +285,7 @@ def check_models_are_auto_configured(module, all_auto_models):
...
@@ -285,7 +285,7 @@ def check_models_are_auto_configured(module, all_auto_models):
def
check_all_models_are_auto_configured
():
def
check_all_models_are_auto_configured
():
"""
Check all models are each in an auto class."""
"""Check all models are each in an auto class."""
modules
=
get_model_modules
()
modules
=
get_model_modules
()
all_auto_models
=
get_all_auto_configured_models
()
all_auto_models
=
get_all_auto_configured_models
()
failures
=
[]
failures
=
[]
...
@@ -301,7 +301,7 @@ _re_decorator = re.compile(r"^\s*@(\S+)\s+$")
...
@@ -301,7 +301,7 @@ _re_decorator = re.compile(r"^\s*@(\S+)\s+$")
def
check_decorator_order
(
filename
):
def
check_decorator_order
(
filename
):
"""
Check that in the test file `filename` the slow decorator is always last."""
"""Check that in the test file `filename` the slow decorator is always last."""
with
open
(
filename
,
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
with
open
(
filename
,
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
decorator_before
=
None
decorator_before
=
None
...
@@ -319,7 +319,7 @@ def check_decorator_order(filename):
...
@@ -319,7 +319,7 @@ def check_decorator_order(filename):
def
check_all_decorator_order
():
def
check_all_decorator_order
():
"""
Check that in all test files, the slow decorator is always last."""
"""Check that in all test files, the slow decorator is always last."""
errors
=
[]
errors
=
[]
for
fname
in
os
.
listdir
(
PATH_TO_TESTS
):
for
fname
in
os
.
listdir
(
PATH_TO_TESTS
):
if
fname
.
endswith
(
".py"
):
if
fname
.
endswith
(
".py"
):
...
@@ -334,7 +334,7 @@ def check_all_decorator_order():
...
@@ -334,7 +334,7 @@ def check_all_decorator_order():
def
find_all_documented_objects
():
def
find_all_documented_objects
():
"""
Parse the content of all doc files to detect which classes and functions it documents"""
"""Parse the content of all doc files to detect which classes and functions it documents"""
documented_obj
=
[]
documented_obj
=
[]
for
doc_file
in
Path
(
PATH_TO_DOC
).
glob
(
"**/*.rst"
):
for
doc_file
in
Path
(
PATH_TO_DOC
).
glob
(
"**/*.rst"
):
with
open
(
doc_file
,
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
with
open
(
doc_file
,
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
...
@@ -454,7 +454,7 @@ def ignore_undocumented(name):
...
@@ -454,7 +454,7 @@ def ignore_undocumented(name):
def
check_all_objects_are_documented
():
def
check_all_objects_are_documented
():
"""
Check all models are properly documented."""
"""Check all models are properly documented."""
documented_objs
=
find_all_documented_objects
()
documented_objs
=
find_all_documented_objects
()
modules
=
transformers
.
_modules
modules
=
transformers
.
_modules
objects
=
[
c
for
c
in
dir
(
transformers
)
if
c
not
in
modules
and
not
c
.
startswith
(
"_"
)]
objects
=
[
c
for
c
in
dir
(
transformers
)
if
c
not
in
modules
and
not
c
.
startswith
(
"_"
)]
...
@@ -467,7 +467,7 @@ def check_all_objects_are_documented():
...
@@ -467,7 +467,7 @@ def check_all_objects_are_documented():
def
check_repo_quality
():
def
check_repo_quality
():
"""
Check all models are properly tested and documented."""
"""Check all models are properly tested and documented."""
print
(
"Checking all models are properly tested."
)
print
(
"Checking all models are properly tested."
)
check_all_decorator_order
()
check_all_decorator_order
()
check_all_models_are_tested
()
check_all_models_are_tested
()
...
...
utils/check_table.py
View file @
32dbb2d9
...
@@ -159,7 +159,7 @@ def get_model_table_from_auto_modules():
...
@@ -159,7 +159,7 @@ def get_model_table_from_auto_modules():
def
check_model_table
(
overwrite
=
False
):
def
check_model_table
(
overwrite
=
False
):
"""
Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`.
"""
"""Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`."""
current_table
,
start_index
,
end_index
,
lines
=
_find_text_in_file
(
current_table
,
start_index
,
end_index
,
lines
=
_find_text_in_file
(
filename
=
os
.
path
.
join
(
PATH_TO_DOCS
,
"index.rst"
),
filename
=
os
.
path
.
join
(
PATH_TO_DOCS
,
"index.rst"
),
start_prompt
=
" This table is updated automatically from the auto module"
,
start_prompt
=
" This table is updated automatically from the auto module"
,
...
...
utils/style_doc.py
View file @
32dbb2d9
...
@@ -431,7 +431,7 @@ def _add_new_lines_before_doc_special_words(text):
...
@@ -431,7 +431,7 @@ def _add_new_lines_before_doc_special_words(text):
def
style_rst_file
(
doc_file
,
max_len
=
119
,
check_only
=
False
):
def
style_rst_file
(
doc_file
,
max_len
=
119
,
check_only
=
False
):
"""
Style one rst file `doc_file` to `max_len`."""
"""Style one rst file `doc_file` to `max_len`."""
with
open
(
doc_file
,
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
with
open
(
doc_file
,
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
doc
=
f
.
read
()
doc
=
f
.
read
()
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment