chenpangpang / transformers / Commits

Commit 5e8c8eb5 (unverified)
Authored Feb 22, 2023 by Aaron Gokaslan; committed by GitHub, Feb 22, 2023

    Apply ruff flake8-comprehensions (#21694)

Parent: df06fb1f
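The commit message refers to ruff's port of the flake8-comprehensions lint rules (C4xx), which remove redundant wrappers around comprehensions and collection literals. As a rough illustrative sketch (not part of the commit; rule codes assume ruff's standard C4 numbering), these are the rewrite patterns that appear in the hunks below:

# C403: list comprehension inside set() -> set comprehension
before = set([x * 2 for x in range(3)])
after = {x * 2 for x in range(3)}
assert before == after

# C402: generator of (key, value) pairs fed to dict() -> dict comprehension
before = dict((str(k), k) for k in range(3))
after = {str(k): k for k in range(3)}
assert before == after

# C408: bare dict() calls -> literals
label2id, id2label = dict(), dict()   # before
label2id, id2label = {}, {}           # after

# C413: list() wrapped around sorted(), which already returns a list
before = list(sorted([3, 1, 2]))
after = sorted([3, 1, 2])
assert before == after

# C400: generator expression fed to list() -> list comprehension
before = list(str(x) for x in range(3))
after = [str(x) for x in range(3)]
assert before == after

# C416: comprehension that only copies its iterable -> plain list() call
before = [x for x in range(3)]
after = list(range(3))
assert before == after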
Changes: 230 files in the full commit. This page shows 20 changed files with 77 additions and 95 deletions (+77, -95); the original diff view is paginated (page 1 of 12).
examples/flax/image-captioning/run_image_captioning_flax.py        +6  -8
examples/flax/language-modeling/run_bart_dlm_flax.py               +6  -8
examples/flax/language-modeling/run_clm_flax.py                    +6  -8
examples/flax/language-modeling/run_mlm_flax.py                    +6  -8
examples/flax/language-modeling/run_t5_mlm_flax.py                 +6  -8
examples/flax/question-answering/run_qa.py                         +7  -9
examples/flax/summarization/run_summarization_flax.py              +6  -8
examples/flax/text-classification/run_flax_glue.py                 +8  -10
examples/flax/token-classification/run_flax_ner.py                 +6  -8
examples/legacy/pytorch-lightning/run_glue.py                      +1  -1
examples/legacy/pytorch-lightning/run_ner.py                       +1  -1
examples/legacy/question-answering/run_squad.py                    +3  -3
examples/legacy/run_openai_gpt.py                                  +1  -1
examples/legacy/run_swag.py                                        +3  -3
examples/legacy/seq2seq/run_distributed_eval.py                    +2  -2
examples/legacy/seq2seq/run_eval.py                                +1  -1
examples/legacy/seq2seq/run_eval_search.py                         +1  -1
examples/legacy/seq2seq/utils.py                                   +1  -1
examples/pytorch/audio-classification/run_audio_classification.py  +3  -3
examples/pytorch/benchmarking/plot_csv_file.py                     +3  -3
examples/flax/image-captioning/run_image_captioning_flax.py

@@ -892,14 +892,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
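This same hunk recurs in the other flax example scripts below: the C403 fix drops the throwaway inner list, and the resulting set comprehension is behaviorally identical. A quick sanity check with hypothetical parameter keys (the tuple keys mimic what flax's traverse_util.flatten_dict returns; the values are made up):

# Hypothetical flattened parameter tree, not from the actual scripts.
flat_params = {("encoder", "layer_norm", "scale"): 1.0, ("encoder", "dense", "bias"): 2.0}
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]

old = set([layer[-2:] for name in layer_norm_candidates for layer in flat_params.keys() if name in "".join(layer).lower()])
new = {layer[-2:] for name in layer_norm_candidates for layer in flat_params.keys() if name in "".join(layer).lower()}
assert old == new == {("layer_norm", "scale")}  # same members, no intermediate list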
examples/flax/language-modeling/run_bart_dlm_flax.py

@@ -756,14 +756,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
examples/flax/language-modeling/run_clm_flax.py

@@ -648,14 +648,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
examples/flax/language-modeling/run_mlm_flax.py

@@ -679,14 +679,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
examples/flax/language-modeling/run_t5_mlm_flax.py

@@ -791,14 +791,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
examples/flax/question-answering/run_qa.py

@@ -333,14 +333,12 @@ def create_train_state(
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

@@ -642,7 +640,7 @@ def main():
         return tokenized_examples

-    processed_raw_datasets = dict()
+    processed_raw_datasets = {}
     if training_args.do_train:
         if "train" not in raw_datasets:
             raise ValueError("--do_train requires a train dataset")
examples/flax/summarization/run_summarization_flax.py

@@ -742,14 +742,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
examples/flax/text-classification/run_flax_glue.py

@@ -229,14 +229,12 @@ def create_train_state(
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

@@ -449,7 +447,7 @@ def main():
     ):
         # Some have all caps in their config, some don't.
         label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+        if sorted(label_name_to_id.keys()) == sorted(label_list):
             logger.info(
                 f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                 "Using it!"

@@ -458,7 +456,7 @@ def main():
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                 "\nIgnoring the model labels as a result.",
             )
     elif data_args.task_name is None:
examples/flax/token-classification/run_flax_ner.py

@@ -290,14 +290,12 @@ def create_train_state(
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
examples/legacy/pytorch-lightning/run_glue.py

@@ -192,7 +192,7 @@ def main():
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
+        checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
         model = model.load_from_checkpoint(checkpoints[-1])
         return trainer.test(model)
examples/legacy/pytorch-lightning/run_ner.py

@@ -211,6 +211,6 @@ if __name__ == "__main__":
     # pl use this default format to create a checkpoint:
     # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
     # /pytorch_lightning/callbacks/model_checkpoint.py#L322
-    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
+    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
     model = model.load_from_checkpoint(checkpoints[-1])
     trainer.test(model)
examples/legacy/question-answering/run_squad.py

@@ -810,10 +810,10 @@ def main():
         logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
-                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            checkpoints = [
+                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
+            ]
     else:
         logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)

@@ -830,7 +830,7 @@ def main():
         # Evaluate
         result = evaluate(args, model, tokenizer, prefix=global_step)

-        result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+        result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
         results.update(result)

     logger.info("Results: {}".format(results))
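The dict((k, v) for ...) to {k: v for ...} change above is the C402 fix; both forms build the same mapping, the comprehension just skips the intermediate tuples. A small check with made-up values (the metric names and step are hypothetical, not from an actual eval run):

result = {"exact": 81.0, "f1": 88.5}  # hypothetical metrics
global_step = "500"
old = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
new = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
assert old == new == {"exact_500": 81.0, "f1_500": 88.5}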
examples/legacy/run_openai_gpt.py

@@ -189,7 +189,7 @@ def main():
             return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
         elif isinstance(obj, int):
             return obj
-        return list(tokenize_and_encode(o) for o in obj)
+        return [tokenize_and_encode(o) for o in obj]

     logger.info("Encoding dataset...")
     train_dataset = load_rocstories_dataset(args.train_dataset)
examples/legacy/run_swag.py

@@ -696,9 +696,9 @@ def main():
         checkpoints = [args.model_name_or_path]
         if args.eval_all_checkpoints:
-            checkpoints = list(
-                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            checkpoints = [
+                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)

@@ -712,7 +712,7 @@ def main():
         # Evaluate
         result = evaluate(args, model, tokenizer, prefix=global_step)

-        result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+        result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
         results.update(result)

     logger.info("Results: {}".format(results))
examples/legacy/seq2seq/run_distributed_eval.py

@@ -111,7 +111,7 @@ def eval_data_dir(
         if num_return_sequences > 1:
             preds = chunks(preds, num_return_sequences)  # batch size chunks, each of size num_return_seq
         for i, pred in enumerate(preds):
-            results.append(dict(pred=pred, id=ids[i].item()))
+            results.append({"pred": pred, "id": ids[i].item()})
         save_json(results, save_path)
     return results, sampler.num_replicas

@@ -232,7 +232,7 @@ def combine_partial_results(partial_results) -> List:
     records = []
     for partial_result in partial_results:
         records.extend(partial_result)
-    records = list(sorted(records, key=lambda x: x["id"]))
+    records = sorted(records, key=lambda x: x["id"])
     preds = [x["pred"] for x in records]
     return preds
examples/legacy/seq2seq/run_eval.py

@@ -76,7 +76,7 @@ def generate_summaries_or_translations(
     fout.close()
     runtime = int(time.time() - start_time)  # seconds
     n_obs = len(examples)
-    return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
+    return {"n_obs": n_obs, "runtime": runtime, "seconds_per_sample": round(runtime / n_obs, 4)}


 def datetime_now():
examples/legacy/seq2seq/run_eval_search.py

@@ -36,7 +36,7 @@ def parse_search_arg(search):
     groups = search.split()
     entries = {k: vs for k, vs in (g.split("=") for g in groups)}
     entry_names = list(entries.keys())
-    sets = [list(f"--{k} {v}" for v in vs.split(":")) for k, vs in entries.items()]
+    sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
     matrix = [list(x) for x in itertools.product(*sets)]
     return matrix, entry_names
examples/legacy/seq2seq/utils.py

@@ -456,7 +456,7 @@ def pickle_save(obj, path):
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))


 def save_git_info(folder_path: str) -> None:
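Note that this hunk runs in the opposite direction from most of the commit: rule C416 flags a comprehension that merely copies its iterable, so the fix is to call list() directly. An equivalence check on a toy nested list (values are illustrative):

import itertools

summary_ids = [[101, 102], [103]]  # toy nested ids
flat_old = [x for x in itertools.chain.from_iterable(summary_ids)]
flat_new = list(itertools.chain.from_iterable(summary_ids))
assert flat_old == flat_new == [101, 102, 103]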
examples/pytorch/audio-classification/run_audio_classification.py

@@ -293,7 +293,7 @@ def main():
                 audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
             )
             output_batch["input_values"].append(wav)
-        output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
+        output_batch["labels"] = list(batch[data_args.label_column_name])

         return output_batch

@@ -303,14 +303,14 @@ def main():
         for audio in batch[data_args.audio_column_name]:
             wav = audio["array"]
             output_batch["input_values"].append(wav)
-        output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
+        output_batch["labels"] = list(batch[data_args.label_column_name])

         return output_batch

     # Prepare label mappings.
     # We'll include these in the model's config to get human readable labels in the Inference API.
     labels = raw_datasets["train"].features[data_args.label_column_name].names
-    label2id, id2label = dict(), dict()
+    label2id, id2label = {}, {}
     for i, label in enumerate(labels):
         label2id[label] = str(i)
         id2label[str(i)] = label
examples/pytorch/benchmarking/plot_csv_file.py

@@ -83,7 +83,7 @@ def can_convert_to_float(string):
 class Plot:
     def __init__(self, args):
         self.args = args
-        self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))
+        self.result_dict = defaultdict(lambda: {"bsz": [], "seq_len": [], "result": {}})

         with open(self.args.csv_file, newline="") as csv_file:
             reader = csv.DictReader(csv_file)

@@ -116,8 +116,8 @@ class Plot:
             axis.set_major_formatter(ScalarFormatter())

         for model_name_idx, model_name in enumerate(self.result_dict.keys()):
-            batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
-            sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
+            batch_sizes = sorted(set(self.result_dict[model_name]["bsz"]))
+            sequence_lengths = sorted(set(self.result_dict[model_name]["seq_len"]))
             results = self.result_dict[model_name]["result"]

             (x_axis_array, inner_loop_array) = (