chenpangpang / transformers · Commits

Commit 5e8c8eb5 (unverified)
Authored Feb 22, 2023 by Aaron Gokaslan; committed by GitHub on Feb 22, 2023
Apply ruff flake8-comprehensions (#21694)
Parent: df06fb1f
Changes: 230
Showing 10 changed files with 23 additions and 25 deletions (+23, -25)
tests/trainer/test_trainer_utils.py (+4, -4)
tests/utils/test_modeling_tf_core.py (+2, -2)
utils/check_copies.py (+1, -1)
utils/check_doc_toc.py (+1, -1)
utils/check_repo.py (+2, -2)
utils/create_dummy_models.py (+3, -3)
utils/extract_warnings.py (+1, -1)
utils/get_ci_error_statistics.py (+1, -1)
utils/tests_fetcher.py (+7, -9)
utils/update_metadata.py (+1, -1)
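The changes below are mechanical applications of ruff's flake8-comprehensions rules (the C4xx family), which flag comprehension and constructor patterns that can be simplified without changing behavior. As a rough, standalone illustration (not code from the repository) of the two rewrites that dominate this commit:

# Illustrative only: the kinds of rewrites flake8-comprehensions suggests.
names = ["bert", "gpt2", "bert", "t5"]

# Before: building a set from a list comprehension allocates an intermediate list.
unique_before = set([n.upper() for n in names])
# After: a set comprehension builds the same set directly.
unique_after = {n.upper() for n in names}
assert unique_before == unique_after

# Before: sorted() already returns a new list, so the extra list() call is redundant.
ordered_before = sorted(list(unique_after))
# After:
ordered_after = sorted(unique_after)
assert ordered_before == ordered_after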
tests/trainer/test_trainer_utils.py

@@ -189,7 +189,7 @@ class TrainerUtilsTest(unittest.TestCase):
         # The biggest element should be first
         self.assertEqual(lengths[indices[0]], 50)
         # The indices should be a permutation of range(100)
-        self.assertEqual(list(sorted(indices)), list(range(100)))
+        self.assertEqual(sorted(indices), list(range(100)))

     def test_group_by_length_with_dict(self):
         # Get some inputs of random lengths
@@ -204,7 +204,7 @@ class TrainerUtilsTest(unittest.TestCase):
         # The biggest element should be first
         self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
         # The indices should be a permutation of range(6)
-        self.assertEqual(list(sorted(indices)), list(range(6)))
+        self.assertEqual(sorted(indices), list(range(6)))

     def test_group_by_length_with_batch_encoding(self):
         # Get some inputs of random lengths
@@ -219,7 +219,7 @@ class TrainerUtilsTest(unittest.TestCase):
         # The biggest element should be first
         self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
         # The indices should be a permutation of range(6)
-        self.assertEqual(list(sorted(indices)), list(range(6)))
+        self.assertEqual(sorted(indices), list(range(6)))

     def test_distributed_length_grouped(self):
         # Get some inputs of random lengths
@@ -232,7 +232,7 @@ class TrainerUtilsTest(unittest.TestCase):
         # The biggest element should be first
         self.assertEqual(lengths[indices_process_0[0]], 50)
         # The indices should be a permutation of range(100)
-        self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
+        self.assertEqual(sorted(indices_process_0 + indices_process_1), list(range(100)))

     def test_get_parameter_names(self):
         model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)]))
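The four edits above all drop a redundant list() wrapped around sorted(): sorted() returns a new list for any iterable input, so list(sorted(...)) and sorted(...) are interchangeable in these assertions. A quick standalone check with made-up values (not the test suite's data):

indices = [3, 0, 2, 1]
# sorted() already returns a list, so the outer list() adds nothing.
assert sorted(indices) == list(sorted(indices)) == [0, 1, 2, 3]
# It also accepts non-list iterables (sets, generators) and still returns a list.
assert sorted({"b", "a"}) == ["a", "b"]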
tests/utils/test_modeling_tf_core.py

@@ -285,7 +285,7 @@ class TFCoreModelTesterMixin:
                 del inputs_dict["decoder_head_mask"]
             if "cross_attn_head_mask" in inputs_dict:
                 del inputs_dict["cross_attn_head_mask"]
-        tf_main_layer_classes = set(
+        tf_main_layer_classes = {
             module_member
             for model_class in self.all_model_classes
             for module in (import_module(model_class.__module__),)
@@ -295,7 +295,7 @@ class TFCoreModelTesterMixin:
             if isinstance(module_member, type)
             and tf.keras.layers.Layer in module_member.__bases__
             and getattr(module_member, "_keras_serializable", False)
-        )
+        }
         for main_layer_class in tf_main_layer_classes:
             # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
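Here the rewrite turns a multi-line generator expression passed to set() into a set comprehension: only the opening set( becomes { and the closing ) becomes }, while the nested for/if clauses stay as they are. A minimal sketch of the same shape, using standard-library modules instead of the transformers model classes (purely hypothetical stand-ins):

import inspect
import json

modules = [inspect, json]

# Set comprehension with chained for/if clauses, mirroring the shape of the
# tf_main_layer_classes collection above.
public_callables = {
    member
    for module in modules
    for name in dir(module)
    for member in (getattr(module, name),)
    if callable(member) and not name.startswith("_")
}
assert all(callable(m) for m in public_callables)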
utils/check_copies.py

@@ -385,7 +385,7 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
     sorted_index = sorted(localized_model_index.items(), key=lambda x: x[0].lower())

-    return readmes_match, "\n".join(map(lambda x: x[1], sorted_index)) + "\n"
+    return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n"


 def convert_readme_to_index(model_list):
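This change replaces map() with a trivial indexing lambda by a generator expression (ruff's C417), which reads more directly and avoids a lambda call per item. A standalone sketch with made-up entries rather than the real localized model index:

sorted_index = [("albert", "| ALBERT | row |"), ("bert", "| BERT | row |")]

# Before: map() with a lambda that only indexes into the tuple.
joined_before = "\n".join(map(lambda x: x[1], sorted_index)) + "\n"
# After: a generator expression says the same thing without the lambda.
joined_after = "\n".join(x[1] for x in sorted_index) + "\n"
assert joined_before == joined_after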
utils/check_doc_toc.py

@@ -33,7 +33,7 @@ def clean_model_doc_toc(model_doc):
     new_doc = []
     for duplicate_key in duplicates:
-        titles = list(set(doc["title"] for doc in model_doc if doc["local"] == duplicate_key))
+        titles = list({doc["title"] for doc in model_doc if doc["local"] == duplicate_key})
         if len(titles) > 1:
             raise ValueError(
                 f"{duplicate_key} is present several times in the documentation table of content at "
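The rewritten line builds the deduplicated title list with a set comprehension instead of set(generator); the two forms are equivalent, and in both the order of the resulting list is unspecified because it comes from a set. As an aside (not part of this commit), dict.fromkeys() is the usual alternative when first-seen order has to be preserved:

titles = ["BERT", "RoBERTa", "BERT", "GPT-2"]

# Equivalent ways to deduplicate; element order of the result is arbitrary.
dedup_a = list(set(t for t in titles))
dedup_b = list({t for t in titles})
assert set(dedup_a) == set(dedup_b) == {"BERT", "RoBERTa", "GPT-2"}

# Order-preserving deduplication, if order ever mattered (it does not here,
# since clean_model_doc_toc only checks len(titles) > 1).
assert list(dict.fromkeys(titles)) == ["BERT", "RoBERTa", "GPT-2"]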
utils/check_repo.py

@@ -335,7 +335,7 @@ def check_model_list():
     # Get the models from the directory structure of `src/transformers/models/`
     models = [model for model in dir(transformers.models) if not model.startswith("__")]

-    missing_models = sorted(list(set(_models).difference(models)))
+    missing_models = sorted(set(_models).difference(models))
     if missing_models:
         raise Exception(
             f"The following models should be included in {models_dir}/__init__.py: {','.join(missing_models)}."
@@ -547,7 +547,7 @@ def get_all_auto_configured_models():
     for attr_name in dir(transformers.models.auto.modeling_flax_auto):
         if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
             result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
-    return [cls for cls in result]
+    return list(result)


 def ignore_unautoclassed(model_name):
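The second hunk removes an identity comprehension: [cls for cls in result] copies the iterable element by element, so list(result) (ruff's C416) expresses the same set-to-list conversion. A small check with a stand-in result set of class names rather than the real model classes:

result = {"FlaxBertModel", "FlaxT5Model"}

copy_before = [cls for cls in result]  # identity comprehension
copy_after = list(result)
# Same elements either way; both produce a plain list built from the set.
assert sorted(copy_before) == sorted(copy_after)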
utils/create_dummy_models.py

@@ -413,10 +413,10 @@ def convert_processors(processors, tiny_config, output_folder, result):
             feature_extractors.append(processor.feature_extractor)

     # check the built processors have the unique type
-    num_types = len(set([x.__class__.__name__ for x in feature_extractors]))
+    num_types = len({x.__class__.__name__ for x in feature_extractors})
     if num_types >= 2:
         raise ValueError(f"`feature_extractors` should contain at most 1 type, but it contains {num_types} types!")
-    num_types = len(set([x.__class__.__name__.replace("Fast", "") for x in tokenizers]))
+    num_types = len({x.__class__.__name__.replace("Fast", "") for x in tokenizers})
     if num_types >= 2:
         raise ValueError(f"`tokenizers` should contain at most 1 tokenizer type, but it contains {num_types} types!")
@@ -712,7 +712,7 @@ def build_composite_models(config_class, output_dir):
             shutil.copytree(decoder_processor_path, model_path, dirs_exist_ok=True)

         # fill `result`
-        result["processor"] = tuple(set([x.__name__ for x in encoder_processor + decoder_processor]))
+        result["processor"] = tuple({x.__name__ for x in encoder_processor + decoder_processor})

         result["pytorch"] = {model_class.__name__: {"model": model_class.__name__, "checkpoint": model_path}}
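In build_composite_models the processor class names are deduplicated with a set comprehension and then frozen into a tuple; the rewrite only drops the intermediate list, and in both the old and new form the tuple's element order is whatever the set happens to yield. A sketch with placeholder classes (hypothetical stand-ins, not the real processor classes):

class BertTokenizerFast:  # placeholder stand-ins for processor classes
    pass


class ViTImageProcessor:
    pass


encoder_processor = [ViTImageProcessor]
decoder_processor = [BertTokenizerFast, BertTokenizerFast]

# Deduplicate class names, then freeze into a tuple (order unspecified).
processor_names = tuple({x.__name__ for x in encoder_processor + decoder_processor})
assert sorted(processor_names) == ["BertTokenizerFast", "ViTImageProcessor"]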
utils/extract_warnings.py

@@ -134,6 +134,6 @@ if __name__ == "__main__":
     # extract warnings from artifacts
     selected_warnings = extract_warnings(args.output_dir, args.targets)
-    selected_warnings = sorted(list(selected_warnings))
+    selected_warnings = sorted(selected_warnings)

     with open(os.path.join(args.output_dir, "selected_warnings.json"), "w", encoding="UTF-8") as fp:
         json.dump(selected_warnings, fp, ensure_ascii=False, indent=4)
utils/get_ci_error_statistics.py

@@ -166,7 +166,7 @@ def reduce_by_model(logs, error_filter=None):
     logs = [(x[0], x[1], get_model(x[2])) for x in logs]
     logs = [x for x in logs if x[2] is not None]

-    tests = set([x[2] for x in logs])
+    tests = {x[2] for x in logs}

     r = {}
     for test in tests:
utils/tests_fetcher.py

@@ -78,13 +78,11 @@ def get_all_tests():
     # test folders/files directly under `tests` folder
     tests = os.listdir(test_root_dir)
-    tests = sorted(
-        list(filter(lambda x: os.path.isdir(x) or x.startswith("tests/test_"), [f"tests/{x}" for x in tests]))
-    )
+    tests = sorted(filter(lambda x: os.path.isdir(x) or x.startswith("tests/test_"), [f"tests/{x}" for x in tests]))

     # model specific test folders
     model_tests_folders = os.listdir(os.path.join(test_root_dir, "models"))
-    model_test_folders = sorted(list(filter(os.path.isdir, [f"tests/models/{x}" for x in model_tests_folders])))
+    model_test_folders = sorted(filter(os.path.isdir, [f"tests/models/{x}" for x in model_tests_folders]))

     tests.remove("tests/models")
     tests = model_test_folders + tests
@@ -265,7 +263,7 @@ def get_tree_starting_at(module, edges):
     tree = [module]
     while len(new_edges) > 0:
         tree.append(new_edges)
-        final_vertices = list(set(edge[1] for edge in new_edges))
+        final_vertices = list({edge[1] for edge in new_edges})
         vertices_seen.extend(final_vertices)
         new_edges = [edge for edge in edges if edge[0] in final_vertices and edge[1] not in vertices_seen]
@@ -285,10 +283,10 @@ def print_tree_deps_of(module, all_edges=None):
     lines = [(tree[0], tree[0])]
     for index in range(1, len(tree)):
         edges = tree[index]
-        start_edges = set([edge[0] for edge in edges])
+        start_edges = {edge[0] for edge in edges}

         for start in start_edges:
-            end_edges = set([edge[1] for edge in edges if edge[0] == start])
+            end_edges = {edge[1] for edge in edges if edge[0] == start}
             # We will insert all those edges just after the line showing start.
             pos = 0
             while lines[pos][1] != start:
@@ -547,7 +545,7 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, j
             impacted_files.extend(impacted_modules_map[f])

     # Remove duplicates
-    impacted_files = sorted(list(set(impacted_files)))
+    impacted_files = sorted(set(impacted_files))
     print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")

     # Grab the corresponding test files:
@@ -578,7 +576,7 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, j
             test_files_to_run.extend(new_tests)

     # Remove duplicates
-    test_files_to_run = sorted(list(set(test_files_to_run)))
+    test_files_to_run = sorted(set(test_files_to_run))
     # Make sure we did not end up with a test file that was removed
     test_files_to_run = [f for f in test_files_to_run if os.path.isfile(f) or os.path.isdir(f)]

     if filters is not None:
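The first tests_fetcher.py hunk relies on sorted() accepting any iterable, including the lazy iterator returned by filter(), so the list() wrapper (and the extra wrapping lines) add nothing. A standalone sketch with hypothetical paths:

entries = ["models", "test_trainer.py", "test_modeling_bert.py", "fixtures"]
candidates = [f"tests/{x}" for x in entries]

# filter() returns a lazy iterator; sorted() consumes it and returns a list.
test_files = sorted(filter(lambda x: x.startswith("tests/test_"), candidates))
assert test_files == ["tests/test_modeling_bert.py", "tests/test_trainer.py"]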
utils/update_metadata.py

@@ -223,7 +223,7 @@ def update_metadata(token, commit_sha):
     table = update_pipeline_and_auto_class_table(table)

     # Sort the model classes to avoid some nondeterministic updates to create false update commits.
-    model_classes = sorted(list(table.keys()))
+    model_classes = sorted(table.keys())
     tags_table = pd.DataFrame(
         {
             "model_class": model_classes,
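One last note on the update_metadata.py change: table.keys() is a view that sorted() can consume directly, and since iterating a dict yields its keys, sorted(table) would be equivalent as well; the commit keeps the explicit .keys() call. A tiny check with a made-up table:

table = {"BertModel": "bert", "AlbertModel": "albert"}
assert sorted(list(table.keys())) == sorted(table.keys()) == sorted(table) == ["AlbertModel", "BertModel"]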