Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5e8c8eb5
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "364920e216c16d73c782a61a4cf6652e541fbe18"
Unverified
Commit
5e8c8eb5
authored
Feb 22, 2023
by
Aaron Gokaslan
Committed by
GitHub
Feb 22, 2023
Browse files
Apply ruff flake8-comprehensions (#21694)
parent
df06fb1f
Changes
230
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
77 additions
and
95 deletions
+77
-95
examples/flax/image-captioning/run_image_captioning_flax.py
examples/flax/image-captioning/run_image_captioning_flax.py
+6
-8
examples/flax/language-modeling/run_bart_dlm_flax.py
examples/flax/language-modeling/run_bart_dlm_flax.py
+6
-8
examples/flax/language-modeling/run_clm_flax.py
examples/flax/language-modeling/run_clm_flax.py
+6
-8
examples/flax/language-modeling/run_mlm_flax.py
examples/flax/language-modeling/run_mlm_flax.py
+6
-8
examples/flax/language-modeling/run_t5_mlm_flax.py
examples/flax/language-modeling/run_t5_mlm_flax.py
+6
-8
examples/flax/question-answering/run_qa.py
examples/flax/question-answering/run_qa.py
+7
-9
examples/flax/summarization/run_summarization_flax.py
examples/flax/summarization/run_summarization_flax.py
+6
-8
examples/flax/text-classification/run_flax_glue.py
examples/flax/text-classification/run_flax_glue.py
+8
-10
examples/flax/token-classification/run_flax_ner.py
examples/flax/token-classification/run_flax_ner.py
+6
-8
examples/legacy/pytorch-lightning/run_glue.py
examples/legacy/pytorch-lightning/run_glue.py
+1
-1
examples/legacy/pytorch-lightning/run_ner.py
examples/legacy/pytorch-lightning/run_ner.py
+1
-1
examples/legacy/question-answering/run_squad.py
examples/legacy/question-answering/run_squad.py
+3
-3
examples/legacy/run_openai_gpt.py
examples/legacy/run_openai_gpt.py
+1
-1
examples/legacy/run_swag.py
examples/legacy/run_swag.py
+3
-3
examples/legacy/seq2seq/run_distributed_eval.py
examples/legacy/seq2seq/run_distributed_eval.py
+2
-2
examples/legacy/seq2seq/run_eval.py
examples/legacy/seq2seq/run_eval.py
+1
-1
examples/legacy/seq2seq/run_eval_search.py
examples/legacy/seq2seq/run_eval_search.py
+1
-1
examples/legacy/seq2seq/utils.py
examples/legacy/seq2seq/utils.py
+1
-1
examples/pytorch/audio-classification/run_audio_classification.py
.../pytorch/audio-classification/run_audio_classification.py
+3
-3
examples/pytorch/benchmarking/plot_csv_file.py
examples/pytorch/benchmarking/plot_csv_file.py
+3
-3
No files found.
examples/flax/image-captioning/run_image_captioning_flax.py
View file @
5e8c8eb5
...
@@ -892,14 +892,12 @@ def main():
...
@@ -892,14 +892,12 @@ def main():
flat_params
=
traverse_util
.
flatten_dict
(
params
)
flat_params
=
traverse_util
.
flatten_dict
(
params
)
# find out all LayerNorm parameters
# find out all LayerNorm parameters
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_named_params
=
set
(
layer_norm_named_params
=
{
[
layer
[
-
2
:]
layer
[
-
2
:]
for
layer_norm_name
in
layer_norm_candidates
for
layer_norm_name
in
layer_norm_candidates
for
layer
in
flat_params
.
keys
()
for
layer
in
flat_params
.
keys
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
}
]
)
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
return
traverse_util
.
unflatten_dict
(
flat_mask
)
return
traverse_util
.
unflatten_dict
(
flat_mask
)
...
...
examples/flax/language-modeling/run_bart_dlm_flax.py
View file @
5e8c8eb5
...
@@ -756,14 +756,12 @@ def main():
...
@@ -756,14 +756,12 @@ def main():
flat_params
=
traverse_util
.
flatten_dict
(
params
)
flat_params
=
traverse_util
.
flatten_dict
(
params
)
# find out all LayerNorm parameters
# find out all LayerNorm parameters
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_named_params
=
set
(
layer_norm_named_params
=
{
[
layer
[
-
2
:]
layer
[
-
2
:]
for
layer_norm_name
in
layer_norm_candidates
for
layer_norm_name
in
layer_norm_candidates
for
layer
in
flat_params
.
keys
()
for
layer
in
flat_params
.
keys
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
}
]
)
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
return
traverse_util
.
unflatten_dict
(
flat_mask
)
return
traverse_util
.
unflatten_dict
(
flat_mask
)
...
...
examples/flax/language-modeling/run_clm_flax.py
View file @
5e8c8eb5
...
@@ -648,14 +648,12 @@ def main():
...
@@ -648,14 +648,12 @@ def main():
flat_params
=
traverse_util
.
flatten_dict
(
params
)
flat_params
=
traverse_util
.
flatten_dict
(
params
)
# find out all LayerNorm parameters
# find out all LayerNorm parameters
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_named_params
=
set
(
layer_norm_named_params
=
{
[
layer
[
-
2
:]
layer
[
-
2
:]
for
layer_norm_name
in
layer_norm_candidates
for
layer_norm_name
in
layer_norm_candidates
for
layer
in
flat_params
.
keys
()
for
layer
in
flat_params
.
keys
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
}
]
)
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
return
traverse_util
.
unflatten_dict
(
flat_mask
)
return
traverse_util
.
unflatten_dict
(
flat_mask
)
...
...
examples/flax/language-modeling/run_mlm_flax.py
View file @
5e8c8eb5
...
@@ -679,14 +679,12 @@ def main():
...
@@ -679,14 +679,12 @@ def main():
flat_params
=
traverse_util
.
flatten_dict
(
params
)
flat_params
=
traverse_util
.
flatten_dict
(
params
)
# find out all LayerNorm parameters
# find out all LayerNorm parameters
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_named_params
=
set
(
layer_norm_named_params
=
{
[
layer
[
-
2
:]
layer
[
-
2
:]
for
layer_norm_name
in
layer_norm_candidates
for
layer_norm_name
in
layer_norm_candidates
for
layer
in
flat_params
.
keys
()
for
layer
in
flat_params
.
keys
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
}
]
)
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
return
traverse_util
.
unflatten_dict
(
flat_mask
)
return
traverse_util
.
unflatten_dict
(
flat_mask
)
...
...
examples/flax/language-modeling/run_t5_mlm_flax.py
View file @
5e8c8eb5
...
@@ -791,14 +791,12 @@ def main():
...
@@ -791,14 +791,12 @@ def main():
flat_params
=
traverse_util
.
flatten_dict
(
params
)
flat_params
=
traverse_util
.
flatten_dict
(
params
)
# find out all LayerNorm parameters
# find out all LayerNorm parameters
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_named_params
=
set
(
layer_norm_named_params
=
{
[
layer
[
-
2
:]
layer
[
-
2
:]
for
layer_norm_name
in
layer_norm_candidates
for
layer_norm_name
in
layer_norm_candidates
for
layer
in
flat_params
.
keys
()
for
layer
in
flat_params
.
keys
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
}
]
)
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
return
traverse_util
.
unflatten_dict
(
flat_mask
)
return
traverse_util
.
unflatten_dict
(
flat_mask
)
...
...
examples/flax/question-answering/run_qa.py
View file @
5e8c8eb5
...
@@ -333,14 +333,12 @@ def create_train_state(
...
@@ -333,14 +333,12 @@ def create_train_state(
flat_params
=
traverse_util
.
flatten_dict
(
params
)
flat_params
=
traverse_util
.
flatten_dict
(
params
)
# find out all LayerNorm parameters
# find out all LayerNorm parameters
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_named_params
=
set
(
layer_norm_named_params
=
{
[
layer
[
-
2
:]
layer
[
-
2
:]
for
layer_norm_name
in
layer_norm_candidates
for
layer_norm_name
in
layer_norm_candidates
for
layer
in
flat_params
.
keys
()
for
layer
in
flat_params
.
keys
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
}
]
)
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
return
traverse_util
.
unflatten_dict
(
flat_mask
)
return
traverse_util
.
unflatten_dict
(
flat_mask
)
...
@@ -642,7 +640,7 @@ def main():
...
@@ -642,7 +640,7 @@ def main():
return
tokenized_examples
return
tokenized_examples
processed_raw_datasets
=
dict
()
processed_raw_datasets
=
{}
if
training_args
.
do_train
:
if
training_args
.
do_train
:
if
"train"
not
in
raw_datasets
:
if
"train"
not
in
raw_datasets
:
raise
ValueError
(
"--do_train requires a train dataset"
)
raise
ValueError
(
"--do_train requires a train dataset"
)
...
...
examples/flax/summarization/run_summarization_flax.py
View file @
5e8c8eb5
...
@@ -742,14 +742,12 @@ def main():
...
@@ -742,14 +742,12 @@ def main():
flat_params
=
traverse_util
.
flatten_dict
(
params
)
flat_params
=
traverse_util
.
flatten_dict
(
params
)
# find out all LayerNorm parameters
# find out all LayerNorm parameters
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_named_params
=
set
(
layer_norm_named_params
=
{
[
layer
[
-
2
:]
layer
[
-
2
:]
for
layer_norm_name
in
layer_norm_candidates
for
layer_norm_name
in
layer_norm_candidates
for
layer
in
flat_params
.
keys
()
for
layer
in
flat_params
.
keys
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
}
]
)
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
return
traverse_util
.
unflatten_dict
(
flat_mask
)
return
traverse_util
.
unflatten_dict
(
flat_mask
)
...
...
examples/flax/text-classification/run_flax_glue.py
View file @
5e8c8eb5
...
@@ -229,14 +229,12 @@ def create_train_state(
...
@@ -229,14 +229,12 @@ def create_train_state(
flat_params
=
traverse_util
.
flatten_dict
(
params
)
flat_params
=
traverse_util
.
flatten_dict
(
params
)
# find out all LayerNorm parameters
# find out all LayerNorm parameters
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_named_params
=
set
(
layer_norm_named_params
=
{
[
layer
[
-
2
:]
layer
[
-
2
:]
for
layer_norm_name
in
layer_norm_candidates
for
layer_norm_name
in
layer_norm_candidates
for
layer
in
flat_params
.
keys
()
for
layer
in
flat_params
.
keys
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
}
]
)
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
return
traverse_util
.
unflatten_dict
(
flat_mask
)
return
traverse_util
.
unflatten_dict
(
flat_mask
)
...
@@ -449,7 +447,7 @@ def main():
...
@@ -449,7 +447,7 @@ def main():
):
):
# Some have all caps in their config, some don't.
# Some have all caps in their config, some don't.
label_name_to_id
=
{
k
.
lower
():
v
for
k
,
v
in
model
.
config
.
label2id
.
items
()}
label_name_to_id
=
{
k
.
lower
():
v
for
k
,
v
in
model
.
config
.
label2id
.
items
()}
if
list
(
sorted
(
label_name_to_id
.
keys
())
)
==
list
(
sorted
(
label_list
)
)
:
if
sorted
(
label_name_to_id
.
keys
())
==
sorted
(
label_list
):
logger
.
info
(
logger
.
info
(
f
"The configuration of the model provided the following label correspondence:
{
label_name_to_id
}
. "
f
"The configuration of the model provided the following label correspondence:
{
label_name_to_id
}
. "
"Using it!"
"Using it!"
...
@@ -458,7 +456,7 @@ def main():
...
@@ -458,7 +456,7 @@ def main():
else
:
else
:
logger
.
warning
(
logger
.
warning
(
"Your model seems to have been trained with labels, but they don't match the dataset: "
,
"Your model seems to have been trained with labels, but they don't match the dataset: "
,
f
"model labels:
{
list
(
sorted
(
label_name_to_id
.
keys
())
)
}
, dataset labels:
{
list
(
sorted
(
label_list
)
)
}
."
f
"model labels:
{
sorted
(
label_name_to_id
.
keys
())
}
, dataset labels:
{
sorted
(
label_list
)
}
."
"
\n
Ignoring the model labels as a result."
,
"
\n
Ignoring the model labels as a result."
,
)
)
elif
data_args
.
task_name
is
None
:
elif
data_args
.
task_name
is
None
:
...
...
examples/flax/token-classification/run_flax_ner.py
View file @
5e8c8eb5
...
@@ -290,14 +290,12 @@ def create_train_state(
...
@@ -290,14 +290,12 @@ def create_train_state(
flat_params
=
traverse_util
.
flatten_dict
(
params
)
flat_params
=
traverse_util
.
flatten_dict
(
params
)
# find out all LayerNorm parameters
# find out all LayerNorm parameters
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_candidates
=
[
"layernorm"
,
"layer_norm"
,
"ln"
]
layer_norm_named_params
=
set
(
layer_norm_named_params
=
{
[
layer
[
-
2
:]
layer
[
-
2
:]
for
layer_norm_name
in
layer_norm_candidates
for
layer_norm_name
in
layer_norm_candidates
for
layer
in
flat_params
.
keys
()
for
layer
in
flat_params
.
keys
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
if
layer_norm_name
in
""
.
join
(
layer
).
lower
()
}
]
)
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
flat_mask
=
{
path
:
(
path
[
-
1
]
!=
"bias"
and
path
[
-
2
:]
not
in
layer_norm_named_params
)
for
path
in
flat_params
}
return
traverse_util
.
unflatten_dict
(
flat_mask
)
return
traverse_util
.
unflatten_dict
(
flat_mask
)
...
...
examples/legacy/pytorch-lightning/run_glue.py
View file @
5e8c8eb5
...
@@ -192,7 +192,7 @@ def main():
...
@@ -192,7 +192,7 @@ def main():
# Optionally, predict on dev set and write to output_dir
# Optionally, predict on dev set and write to output_dir
if
args
.
do_predict
:
if
args
.
do_predict
:
checkpoints
=
list
(
sorted
(
glob
.
glob
(
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint-epoch=*.ckpt"
),
recursive
=
True
))
)
checkpoints
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint-epoch=*.ckpt"
),
recursive
=
True
))
model
=
model
.
load_from_checkpoint
(
checkpoints
[
-
1
])
model
=
model
.
load_from_checkpoint
(
checkpoints
[
-
1
])
return
trainer
.
test
(
model
)
return
trainer
.
test
(
model
)
...
...
examples/legacy/pytorch-lightning/run_ner.py
View file @
5e8c8eb5
...
@@ -211,6 +211,6 @@ if __name__ == "__main__":
...
@@ -211,6 +211,6 @@ if __name__ == "__main__":
# pl use this default format to create a checkpoint:
# pl use this default format to create a checkpoint:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
# /pytorch_lightning/callbacks/model_checkpoint.py#L322
# /pytorch_lightning/callbacks/model_checkpoint.py#L322
checkpoints
=
list
(
sorted
(
glob
.
glob
(
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint-epoch=*.ckpt"
),
recursive
=
True
))
)
checkpoints
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint-epoch=*.ckpt"
),
recursive
=
True
))
model
=
model
.
load_from_checkpoint
(
checkpoints
[
-
1
])
model
=
model
.
load_from_checkpoint
(
checkpoints
[
-
1
])
trainer
.
test
(
model
)
trainer
.
test
(
model
)
examples/legacy/question-answering/run_squad.py
View file @
5e8c8eb5
...
@@ -810,10 +810,10 @@ def main():
...
@@ -810,10 +810,10 @@ def main():
logger
.
info
(
"Loading checkpoints saved during training for evaluation"
)
logger
.
info
(
"Loading checkpoints saved during training for evaluation"
)
checkpoints
=
[
args
.
output_dir
]
checkpoints
=
[
args
.
output_dir
]
if
args
.
eval_all_checkpoints
:
if
args
.
eval_all_checkpoints
:
checkpoints
=
list
(
checkpoints
=
[
os
.
path
.
dirname
(
c
)
os
.
path
.
dirname
(
c
)
for
c
in
sorted
(
glob
.
glob
(
args
.
output_dir
+
"/**/"
+
WEIGHTS_NAME
,
recursive
=
True
))
for
c
in
sorted
(
glob
.
glob
(
args
.
output_dir
+
"/**/"
+
WEIGHTS_NAME
,
recursive
=
True
))
)
]
else
:
else
:
logger
.
info
(
"Loading checkpoint %s for evaluation"
,
args
.
model_name_or_path
)
logger
.
info
(
"Loading checkpoint %s for evaluation"
,
args
.
model_name_or_path
)
...
@@ -830,7 +830,7 @@ def main():
...
@@ -830,7 +830,7 @@ def main():
# Evaluate
# Evaluate
result
=
evaluate
(
args
,
model
,
tokenizer
,
prefix
=
global_step
)
result
=
evaluate
(
args
,
model
,
tokenizer
,
prefix
=
global_step
)
result
=
dict
((
k
+
(
"_{}"
.
format
(
global_step
)
if
global_step
else
""
)
,
v
)
for
k
,
v
in
result
.
items
()
)
result
=
{
k
+
(
"_{}"
.
format
(
global_step
)
if
global_step
else
""
)
:
v
for
k
,
v
in
result
.
items
()
}
results
.
update
(
result
)
results
.
update
(
result
)
logger
.
info
(
"Results: {}"
.
format
(
results
))
logger
.
info
(
"Results: {}"
.
format
(
results
))
...
...
examples/legacy/run_openai_gpt.py
View file @
5e8c8eb5
...
@@ -189,7 +189,7 @@ def main():
...
@@ -189,7 +189,7 @@ def main():
return
tokenizer
.
convert_tokens_to_ids
(
tokenizer
.
tokenize
(
obj
))
return
tokenizer
.
convert_tokens_to_ids
(
tokenizer
.
tokenize
(
obj
))
elif
isinstance
(
obj
,
int
):
elif
isinstance
(
obj
,
int
):
return
obj
return
obj
return
list
(
tokenize_and_encode
(
o
)
for
o
in
obj
)
return
[
tokenize_and_encode
(
o
)
for
o
in
obj
]
logger
.
info
(
"Encoding dataset..."
)
logger
.
info
(
"Encoding dataset..."
)
train_dataset
=
load_rocstories_dataset
(
args
.
train_dataset
)
train_dataset
=
load_rocstories_dataset
(
args
.
train_dataset
)
...
...
examples/legacy/run_swag.py
View file @
5e8c8eb5
...
@@ -696,9 +696,9 @@ def main():
...
@@ -696,9 +696,9 @@ def main():
checkpoints
=
[
args
.
model_name_or_path
]
checkpoints
=
[
args
.
model_name_or_path
]
if
args
.
eval_all_checkpoints
:
if
args
.
eval_all_checkpoints
:
checkpoints
=
list
(
checkpoints
=
[
os
.
path
.
dirname
(
c
)
for
c
in
sorted
(
glob
.
glob
(
args
.
output_dir
+
"/**/"
+
WEIGHTS_NAME
,
recursive
=
True
))
os
.
path
.
dirname
(
c
)
for
c
in
sorted
(
glob
.
glob
(
args
.
output_dir
+
"/**/"
+
WEIGHTS_NAME
,
recursive
=
True
))
)
]
logger
.
info
(
"Evaluate the following checkpoints: %s"
,
checkpoints
)
logger
.
info
(
"Evaluate the following checkpoints: %s"
,
checkpoints
)
...
@@ -712,7 +712,7 @@ def main():
...
@@ -712,7 +712,7 @@ def main():
# Evaluate
# Evaluate
result
=
evaluate
(
args
,
model
,
tokenizer
,
prefix
=
global_step
)
result
=
evaluate
(
args
,
model
,
tokenizer
,
prefix
=
global_step
)
result
=
dict
((
k
+
(
"_{}"
.
format
(
global_step
)
if
global_step
else
""
)
,
v
)
for
k
,
v
in
result
.
items
()
)
result
=
{
k
+
(
"_{}"
.
format
(
global_step
)
if
global_step
else
""
)
:
v
for
k
,
v
in
result
.
items
()
}
results
.
update
(
result
)
results
.
update
(
result
)
logger
.
info
(
"Results: {}"
.
format
(
results
))
logger
.
info
(
"Results: {}"
.
format
(
results
))
...
...
examples/legacy/seq2seq/run_distributed_eval.py
View file @
5e8c8eb5
...
@@ -111,7 +111,7 @@ def eval_data_dir(
...
@@ -111,7 +111,7 @@ def eval_data_dir(
if
num_return_sequences
>
1
:
if
num_return_sequences
>
1
:
preds
=
chunks
(
preds
,
num_return_sequences
)
# batch size chunks, each of size num_return_seq
preds
=
chunks
(
preds
,
num_return_sequences
)
# batch size chunks, each of size num_return_seq
for
i
,
pred
in
enumerate
(
preds
):
for
i
,
pred
in
enumerate
(
preds
):
results
.
append
(
dict
(
pred
=
pred
,
id
=
ids
[
i
].
item
()
)
)
results
.
append
(
{
"
pred
"
:
pred
,
"
id
"
:
ids
[
i
].
item
()
}
)
save_json
(
results
,
save_path
)
save_json
(
results
,
save_path
)
return
results
,
sampler
.
num_replicas
return
results
,
sampler
.
num_replicas
...
@@ -232,7 +232,7 @@ def combine_partial_results(partial_results) -> List:
...
@@ -232,7 +232,7 @@ def combine_partial_results(partial_results) -> List:
records
=
[]
records
=
[]
for
partial_result
in
partial_results
:
for
partial_result
in
partial_results
:
records
.
extend
(
partial_result
)
records
.
extend
(
partial_result
)
records
=
list
(
sorted
(
records
,
key
=
lambda
x
:
x
[
"id"
])
)
records
=
sorted
(
records
,
key
=
lambda
x
:
x
[
"id"
])
preds
=
[
x
[
"pred"
]
for
x
in
records
]
preds
=
[
x
[
"pred"
]
for
x
in
records
]
return
preds
return
preds
...
...
examples/legacy/seq2seq/run_eval.py
View file @
5e8c8eb5
...
@@ -76,7 +76,7 @@ def generate_summaries_or_translations(
...
@@ -76,7 +76,7 @@ def generate_summaries_or_translations(
fout
.
close
()
fout
.
close
()
runtime
=
int
(
time
.
time
()
-
start_time
)
# seconds
runtime
=
int
(
time
.
time
()
-
start_time
)
# seconds
n_obs
=
len
(
examples
)
n_obs
=
len
(
examples
)
return
dict
(
n_obs
=
n_obs
,
runtime
=
runtime
,
seconds_per_sample
=
round
(
runtime
/
n_obs
,
4
)
)
return
{
"
n_obs
"
:
n_obs
,
"
runtime
"
:
runtime
,
"
seconds_per_sample
"
:
round
(
runtime
/
n_obs
,
4
)
}
def
datetime_now
():
def
datetime_now
():
...
...
examples/legacy/seq2seq/run_eval_search.py
View file @
5e8c8eb5
...
@@ -36,7 +36,7 @@ def parse_search_arg(search):
...
@@ -36,7 +36,7 @@ def parse_search_arg(search):
groups
=
search
.
split
()
groups
=
search
.
split
()
entries
=
{
k
:
vs
for
k
,
vs
in
(
g
.
split
(
"="
)
for
g
in
groups
)}
entries
=
{
k
:
vs
for
k
,
vs
in
(
g
.
split
(
"="
)
for
g
in
groups
)}
entry_names
=
list
(
entries
.
keys
())
entry_names
=
list
(
entries
.
keys
())
sets
=
[
list
(
f
"--
{
k
}
{
v
}
"
for
v
in
vs
.
split
(
":"
)
)
for
k
,
vs
in
entries
.
items
()]
sets
=
[
[
f
"--
{
k
}
{
v
}
"
for
v
in
vs
.
split
(
":"
)
]
for
k
,
vs
in
entries
.
items
()]
matrix
=
[
list
(
x
)
for
x
in
itertools
.
product
(
*
sets
)]
matrix
=
[
list
(
x
)
for
x
in
itertools
.
product
(
*
sets
)]
return
matrix
,
entry_names
return
matrix
,
entry_names
...
...
examples/legacy/seq2seq/utils.py
View file @
5e8c8eb5
...
@@ -456,7 +456,7 @@ def pickle_save(obj, path):
...
@@ -456,7 +456,7 @@ def pickle_save(obj, path):
def
flatten_list
(
summary_ids
:
List
[
List
]):
def
flatten_list
(
summary_ids
:
List
[
List
]):
return
[
x
for
x
in
itertools
.
chain
.
from_iterable
(
summary_ids
)
]
return
list
(
itertools
.
chain
.
from_iterable
(
summary_ids
)
)
def
save_git_info
(
folder_path
:
str
)
->
None
:
def
save_git_info
(
folder_path
:
str
)
->
None
:
...
...
examples/pytorch/audio-classification/run_audio_classification.py
View file @
5e8c8eb5
...
@@ -293,7 +293,7 @@ def main():
...
@@ -293,7 +293,7 @@ def main():
audio
[
"array"
],
max_length
=
data_args
.
max_length_seconds
,
sample_rate
=
feature_extractor
.
sampling_rate
audio
[
"array"
],
max_length
=
data_args
.
max_length_seconds
,
sample_rate
=
feature_extractor
.
sampling_rate
)
)
output_batch
[
"input_values"
].
append
(
wav
)
output_batch
[
"input_values"
].
append
(
wav
)
output_batch
[
"labels"
]
=
[
label
for
label
in
batch
[
data_args
.
label_column_name
]
]
output_batch
[
"labels"
]
=
list
(
batch
[
data_args
.
label_column_name
]
)
return
output_batch
return
output_batch
...
@@ -303,14 +303,14 @@ def main():
...
@@ -303,14 +303,14 @@ def main():
for
audio
in
batch
[
data_args
.
audio_column_name
]:
for
audio
in
batch
[
data_args
.
audio_column_name
]:
wav
=
audio
[
"array"
]
wav
=
audio
[
"array"
]
output_batch
[
"input_values"
].
append
(
wav
)
output_batch
[
"input_values"
].
append
(
wav
)
output_batch
[
"labels"
]
=
[
label
for
label
in
batch
[
data_args
.
label_column_name
]
]
output_batch
[
"labels"
]
=
list
(
batch
[
data_args
.
label_column_name
]
)
return
output_batch
return
output_batch
# Prepare label mappings.
# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels
=
raw_datasets
[
"train"
].
features
[
data_args
.
label_column_name
].
names
labels
=
raw_datasets
[
"train"
].
features
[
data_args
.
label_column_name
].
names
label2id
,
id2label
=
dict
(),
dict
()
label2id
,
id2label
=
{},
{}
for
i
,
label
in
enumerate
(
labels
):
for
i
,
label
in
enumerate
(
labels
):
label2id
[
label
]
=
str
(
i
)
label2id
[
label
]
=
str
(
i
)
id2label
[
str
(
i
)]
=
label
id2label
[
str
(
i
)]
=
label
...
...
examples/pytorch/benchmarking/plot_csv_file.py
View file @
5e8c8eb5
...
@@ -83,7 +83,7 @@ def can_convert_to_float(string):
...
@@ -83,7 +83,7 @@ def can_convert_to_float(string):
class
Plot
:
class
Plot
:
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
self
.
args
=
args
self
.
args
=
args
self
.
result_dict
=
defaultdict
(
lambda
:
dict
(
bsz
=
[],
seq_len
=
[],
result
=
{}
)
)
self
.
result_dict
=
defaultdict
(
lambda
:
{
"bsz"
:
[],
"
seq_len
"
:
[],
"
result
"
:
{}
}
)
with
open
(
self
.
args
.
csv_file
,
newline
=
""
)
as
csv_file
:
with
open
(
self
.
args
.
csv_file
,
newline
=
""
)
as
csv_file
:
reader
=
csv
.
DictReader
(
csv_file
)
reader
=
csv
.
DictReader
(
csv_file
)
...
@@ -116,8 +116,8 @@ class Plot:
...
@@ -116,8 +116,8 @@ class Plot:
axis
.
set_major_formatter
(
ScalarFormatter
())
axis
.
set_major_formatter
(
ScalarFormatter
())
for
model_name_idx
,
model_name
in
enumerate
(
self
.
result_dict
.
keys
()):
for
model_name_idx
,
model_name
in
enumerate
(
self
.
result_dict
.
keys
()):
batch_sizes
=
sorted
(
list
(
set
(
self
.
result_dict
[
model_name
][
"bsz"
]))
)
batch_sizes
=
sorted
(
set
(
self
.
result_dict
[
model_name
][
"bsz"
]))
sequence_lengths
=
sorted
(
list
(
set
(
self
.
result_dict
[
model_name
][
"seq_len"
]))
)
sequence_lengths
=
sorted
(
set
(
self
.
result_dict
[
model_name
][
"seq_len"
]))
results
=
self
.
result_dict
[
model_name
][
"result"
]
results
=
self
.
result_dict
[
model_name
][
"result"
]
(
x_axis_array
,
inner_loop_array
)
=
(
(
x_axis_array
,
inner_loop_array
)
=
(
...
...
Prev
1
2
3
4
5
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment