chenpangpang / transformers · Commits

Commit 5e8c8eb5 (unverified), authored Feb 22, 2023 by Aaron Gokaslan, committed via GitHub on Feb 22, 2023.

Apply ruff flake8-comprehensions (#21694)

Parent: df06fb1f
Changes: 230
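For context, flake8-comprehensions (the C4xx rules enforced by ruff) flags redundant list/set/dict construction and rewrites it as the equivalent literal or comprehension. The sketch below illustrates the main patterns touched on this page with invented data; it is not code from the diff itself.

    # Minimal before/after pairs for the rewrites applied in the hunks below (values are made up).
    items = {"a": 1, "b": 2}

    # list()/dict() around a generator expression -> comprehension (C400, C402)
    squares_old = list(x * x for x in range(5))
    squares_new = [x * x for x in range(5)]
    tagged_old = dict((k, v * 10) for k, v in items.items())
    tagged_new = {k: v * 10 for k, v in items.items()}

    # set() around a list comprehension -> set comprehension (C403)
    keys_old = set([k.upper() for k in items])
    keys_new = {k.upper() for k in items}

    # dict() called with keyword arguments -> dict literal (C408)
    cfg_old = dict(step=1, lr=0.01)
    cfg_new = {"step": 1, "lr": 0.01}

    assert squares_old == squares_new
    assert tagged_old == tagged_new
    assert keys_old == keys_new
    assert cfg_old == cfg_new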
Showing 20 changed files with 58 additions and 57 deletions (+58 / -57):

examples/research_projects/deebert/run_glue_deebert.py                            +3  -3
examples/research_projects/distillation/grouped_batch_sampler.py                  +1  -1
examples/research_projects/distillation/run_squad_w_distillation.py               +3  -3
examples/research_projects/jax-projects/big_bird/bigbird_flax.py                  +6  -3
examples/research_projects/jax-projects/big_bird/evaluate.py                      +3  -3
examples/research_projects/jax-projects/big_bird/prepare_natural_questions.py     +6  -6
examples/research_projects/jax-projects/model_parallel/partitions.py              +1  -1
examples/research_projects/longform-qa/eli5_utils.py                              +6  -6
examples/research_projects/lxmert/extracting_data.py                              +1  -1
examples/research_projects/lxmert/modeling_frcnn.py                               +1  -1
examples/research_projects/mm-imdb/run_mmimdb.py                                  +3  -3
examples/research_projects/movement-pruning/masked_run_glue.py                    +3  -3
examples/research_projects/movement-pruning/masked_run_squad.py                   +3  -3
examples/research_projects/onnx/summarization/bart_onnx/reduce_onnx_size.py       +3  -3
examples/research_projects/pplm/run_pplm.py                                       +2  -4
examples/research_projects/rag-end2end-retriever/finetune_rag.py                  +5  -5
examples/research_projects/rag-end2end-retriever/utils_rag.py                     +1  -1
examples/research_projects/rag/finetune_rag.py                                    +5  -5
examples/research_projects/rag/utils_rag.py                                       +1  -1
examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py  +1  -1
examples/research_projects/deebert/run_glue_deebert.py

@@ -685,9 +685,9 @@ def main():
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:

@@ -725,7 +725,7 @@ def main():
                 for i in range(model.num_layers):
                     info_str += " {:.2f}".format(100 * each_layer_results[i])
                 logger.info(info_str)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
             results.update(result)
     return results
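Both hunks above are behavior-preserving: `list(<generator>)` and the bracketed comprehension build the same list, and `dict((key, value) for ...)` builds the same mapping as a dict comprehension. A quick self-contained check, using made-up values for `result` and `global_step`:

    result = {"acc": 0.9, "f1": 0.8}
    global_step = 500

    old = dict((k + "_{}".format(global_step), v) for k, v in result.items())
    new = {k + "_{}".format(global_step): v for k, v in result.items()}
    assert old == new == {"acc_500": 0.9, "f1_500": 0.8}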
examples/research_projects/distillation/grouped_batch_sampler.py
View file @
5e8c8eb5
...
@@ -27,7 +27,7 @@ from utils import logger
...
@@ -27,7 +27,7 @@ from utils import logger
def
_quantize
(
x
,
bins
):
def
_quantize
(
x
,
bins
):
bins
=
copy
.
deepcopy
(
bins
)
bins
=
copy
.
deepcopy
(
bins
)
bins
=
sorted
(
bins
)
bins
=
sorted
(
bins
)
quantized
=
list
(
map
(
lambda
y
:
bisect
.
bisect_right
(
bins
,
y
)
,
x
))
quantized
=
[
bisect
.
bisect_right
(
bins
,
y
)
for
y
in
x
]
return
quantized
return
quantized
...
...
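`list(map(lambda y: f(y), x))` and `[f(y) for y in x]` produce the same list; the comprehension just avoids an extra lambda call per element (rule C417). A small check with the same `bisect_right` call and invented bins:

    import bisect

    bins = [0, 10, 20, 30]
    x = [5, 12, 25]

    old = list(map(lambda y: bisect.bisect_right(bins, y), x))
    new = [bisect.bisect_right(bins, y) for y in x]
    assert old == new == [1, 2, 3]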
examples/research_projects/distillation/run_squad_w_distillation.py

@@ -850,9 +850,9 @@ def main():
         logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)

@@ -865,7 +865,7 @@ def main():
             # Evaluate
             result = evaluate(args, model, tokenizer, prefix=global_step)
-            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+            result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
             results.update(result)
     logger.info("Results: {}".format(results))
examples/research_projects/jax-projects/big_bird/bigbird_flax.py

@@ -247,9 +247,12 @@ class Trainer:
                 lr = self.scheduler_fn(state_step - 1)

                 eval_loss = self.evaluate(state, val_dataset)
-                logging_dict = dict(
-                    step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
-                )
+                logging_dict = {
+                    "step": state_step.item(),
+                    "eval_loss": eval_loss.item(),
+                    "tr_loss": tr_loss,
+                    "lr": lr.item(),
+                }
                 tqdm.write(str(logging_dict))
                 self.logger.log(logging_dict, commit=True)
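`dict(step=..., lr=...)` and the literal with string keys build identical dictionaries; ruff prefers the literal (rule C408) because it skips the function call and also allows keys that are not valid identifiers. A minimal check with plain Python numbers standing in for the JAX scalars used in the hunk:

    state_step, eval_loss, tr_loss, lr = 100, 0.25, 0.31, 3e-5

    old = dict(step=state_step, eval_loss=eval_loss, tr_loss=tr_loss, lr=lr)
    new = {"step": state_step, "eval_loss": eval_loss, "tr_loss": tr_loss, "lr": lr}
    assert old == new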
examples/research_projects/jax-projects/big_bird/evaluate.py

@@ -144,9 +144,9 @@ def main():
     predictions = expand_to_aliases(example["output"])

     # some preprocessing to both prediction and answer
-    answers = set(["".join(a.split()) for a in answers])
-    predictions = set(["".join(p.split()) for p in predictions])
-    predictions = set([s for s in predictions if s not in ["``", "''", "`", "'"]])
+    answers = {"".join(a.split()) for a in answers}
+    predictions = {"".join(p.split()) for p in predictions}
+    predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}

     # if there is a common element, it's a exact match
     example["match"] = len(list(answers & predictions)) > 0
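`set([... for ...])` builds a throwaway list and then copies it into a set; the set comprehension builds the set directly and contains the same elements (rules C401/C403). A short check with invented answer strings:

    answers = ["New York", "new  york"]

    old = set(["".join(a.split()) for a in answers])
    new = {"".join(a.split()) for a in answers}
    assert old == new == {"NewYork", "newyork"}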
examples/research_projects/jax-projects/big_bird/prepare_natural_questions.py

@@ -314,12 +314,12 @@ if __name__ == "__main__":
     data = data["train" if PROCESS_TRAIN == "true" else "validation"]

-    fn_kwargs = dict(
-        tokenizer=tokenizer,
-        doc_stride=DOC_STRIDE,
-        max_length=MAX_LENGTH,
-        assertion=False,
-    )
+    fn_kwargs = {
+        "tokenizer": tokenizer,
+        "doc_stride": DOC_STRIDE,
+        "max_length": MAX_LENGTH,
+        "assertion": False,
+    }
     data = data.map(prepare_inputs, fn_kwargs=fn_kwargs)
     data = data.remove_columns(["annotations", "document", "id", "question"])
     print(data)
examples/research_projects/jax-projects/model_parallel/partitions.py

@@ -34,7 +34,7 @@ empty_dict = object()
 def _match(qs, ks):
     """Return True if regexes in qs match any window of strings in tuple ks."""
     # compile regexes and force complete match
-    qts = tuple(map(lambda x: re.compile(x + "$"), qs))
+    qts = tuple((re.compile(x + "$") for x in qs))
     for i in range(len(ks) - len(qs) + 1):
         matches = [x.match(y) for x, y in zip(qts, ks[i:])]
         if matches and all(matches):
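There is no tuple comprehension literal, so here the `map(lambda ...)` becomes a generator expression inside `tuple()`; the doubled parentheses that ruff leaves behind are redundant but harmless. Compiled patterns do not compare equal by value, so the sketch below compares their `pattern` strings (the regexes are invented):

    import re

    qs = ["encoder", "layer_[0-9]+"]

    old = tuple(map(lambda x: re.compile(x + "$"), qs))
    new = tuple((re.compile(x + "$") for x in qs))
    assert [p.pattern for p in old] == [p.pattern for p in new] == ["encoder$", "layer_[0-9]+$"]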
examples/research_projects/longform-qa/eli5_utils.py

@@ -78,7 +78,7 @@ def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_1
     )
     hits = response["hits"]["hits"]
     support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
-    res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
+    res_list = [{k: hit["_source"][k] for k in hit["_source"] if k != "passage_text"} for hit in hits]
     for r, hit in zip(res_list, hits):
         r["passage_id"] = hit["_id"]
         r["score"] = hit["_score"]

@@ -601,7 +601,7 @@ def make_qa_dense_index(
     fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
     n_batches = math.ceil(passages_dset.num_rows / batch_size)
     for i in range(n_batches):
-        passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
+        passages = list(passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"])
         reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
         fp[i * batch_size : (i + 1) * batch_size] = reps
         if i % 50 == 0:

@@ -634,7 +634,7 @@ def query_qa_dense_index(
     D, I = wiki_index.search(q_rep, 2 * n_results)
     res_passages = [wiki_passages[int(i)] for i in I[0]]
     support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
-    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+    res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
     res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
     for r, sc in zip(res_list, D[0]):
         r["score"] = float(sc)

@@ -650,7 +650,7 @@ def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages,
     ]
     all_res_lists = []
     for res_passages, dl in zip(res_passages_lst, D):
-        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+        res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
         for r, sc in zip(res_list, dl):
             r["score"] = float(sc)
         all_res_lists += [res_list[:]]

@@ -663,7 +663,7 @@ def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki
     D, I = wiki_index.search(a_rep, 2 * n_results)
     res_passages = [wiki_passages[int(i)] for i in I[0]]
     support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
-    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+    res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
     res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
     for r, sc, i in zip(res_list, D[0], I[0]):
         r["passage_id"] = int(i)

@@ -680,7 +680,7 @@ def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passage
     ]
     all_res_lists = []
     for res_passages, dl, il in zip(res_passages_lst, D, I):
-        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+        res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
         for r, sc, i in zip(res_list, dl, il):
             r["passage_id"] = int(i)
             r["score"] = float(sc)
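Two rules appear throughout this file: `dict([(k, v) for ...])` becomes a dict comprehension with no intermediate list of tuples (C404), and an identity comprehension such as `[p for p in seq]` becomes `list(seq)` (C416). A compact check with an invented `hit` record:

    hit = {"_source": {"passage_text": "...", "title": "Moon", "section": "Orbit"}}

    old = dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"])
    new = {k: hit["_source"][k] for k in hit["_source"] if k != "passage_text"}
    assert old == new == {"title": "Moon", "section": "Orbit"}

    texts = ("a", "b", "c")
    assert [t for t in texts] == list(texts) == ["a", "b", "c"]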
examples/research_projects/lxmert/extracting_data.py

@@ -61,7 +61,7 @@ class Extract:
         assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
         if subset_list is not None:
             with open(os.path.realpath(subset_list)) as f:
-                self.subset_list = set(map(lambda x: self._vqa_file_split()[0], tryload(f)))
+                self.subset_list = {self._vqa_file_split()[0] for x in tryload(f)}
         else:
             self.subset_list = None
examples/research_projects/lxmert/modeling_frcnn.py

@@ -1095,7 +1095,7 @@ class ROIPooler(nn.Module):
         Returns:
             A tensor of shape(N*B, Channels, output_size, output_size)
         """
-        x = [v for v in feature_maps.values()]
+        x = list(feature_maps.values())
         num_level_assignments = len(self.level_poolers)
         assert len(x) == num_level_assignments and len(boxes) == x[0].size(0)
examples/research_projects/mm-imdb/run_mmimdb.py

@@ -554,9 +554,9 @@ def main():
     if args.do_eval and args.local_rank in [-1, 0]:
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:

@@ -566,7 +566,7 @@ def main():
             model.load_state_dict(torch.load(checkpoint))
             model.to(args.device)
             result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
             results.update(result)
     return results
examples/research_projects/movement-pruning/masked_run_glue.py

@@ -941,9 +941,9 @@ def main():
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:

@@ -953,7 +953,7 @@ def main():
             model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
             result = evaluate(args, model, tokenizer, prefix=prefix)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
             results.update(result)
     return results
examples/research_projects/movement-pruning/masked_run_squad.py

@@ -1109,10 +1109,10 @@ def main():
         logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c)
                 for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         else:
             logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)

@@ -1129,7 +1129,7 @@ def main():
         # Evaluate
         result = evaluate(args, model, tokenizer, prefix=global_step)
-        result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+        result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
         results.update(result)
         logger.info("Results: {}".format(results))
examples/research_projects/onnx/summarization/bart_onnx/reduce_onnx_size.py

@@ -42,8 +42,8 @@ def _graph_replace_input_with(graph_proto, name, new_name):
 def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
-    inits_with_data = [i for i in model.graph.initializer]
-    inits = [i for i in model_without_ext.graph.initializer]
+    inits_with_data = list(model.graph.initializer)
+    inits = list(model_without_ext.graph.initializer)
     for i, ref_i in ind_to_replace:
         assert inits_with_data[i].name == inits[i].name
         assert inits_with_data[ref_i].name == inits[ref_i].name

@@ -69,7 +69,7 @@ def remove_dup_initializers(onnx_file_path):
     model = onnx.load(os.path.join(model_file_folder, model_file_name))
-    inits = [i for i in model.graph.initializer]
+    inits = list(model.graph.initializer)
     dup_set = set()
     dup_map = {}
examples/research_projects/pplm/run_pplm.py

@@ -127,11 +127,9 @@ def perturb_past(
     _, _, _, curr_length, _ = past[0].shape

     if curr_length > window_length and window_length > 0:
-        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
+        ones_key_val_shape = tuple(past[0].shape[:-2]) + (window_length,) + tuple(past[0].shape[-1:])

-        zeros_key_val_shape = (
-            tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
-        )
+        zeros_key_val_shape = tuple(past[0].shape[:-2]) + (curr_length - window_length,) + tuple(past[0].shape[-1:])

         ones_mask = torch.ones(ones_key_val_shape)
         ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
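`tuple([window_length])` allocates a one-element list only to convert it; `(window_length,)` is the direct tuple literal (rule C409), so the concatenated shape is unchanged. A quick check with made-up shape numbers:

    shape = (2, 12, 64, 7, 64)  # stand-in for past[0].shape
    window_length = 3

    old = tuple(shape[:-2]) + tuple([window_length]) + tuple(shape[-1:])
    new = tuple(shape[:-2]) + (window_length,) + tuple(shape[-1:])
    assert old == new == (2, 12, 64, 3, 64)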
examples/research_projects/rag-end2end-retriever/finetune_rag.py
View file @
5e8c8eb5
...
@@ -164,11 +164,11 @@ class GenerativeQAModule(BaseTransformer):
...
@@ -164,11 +164,11 @@ class GenerativeQAModule(BaseTransformer):
self
.
step_count
=
0
self
.
step_count
=
0
self
.
metrics
=
defaultdict
(
list
)
self
.
metrics
=
defaultdict
(
list
)
self
.
dataset_kwargs
:
dict
=
dict
(
self
.
dataset_kwargs
:
dict
=
{
data_dir
=
self
.
hparams
.
data_dir
,
"
data_dir
"
:
self
.
hparams
.
data_dir
,
max_source_length
=
self
.
hparams
.
max_source_length
,
"
max_source_length
"
:
self
.
hparams
.
max_source_length
,
prefix
=
prefix
or
""
,
"
prefix
"
:
prefix
or
""
,
)
}
n_observations_per_split
=
{
n_observations_per_split
=
{
"train"
:
self
.
hparams
.
n_train
,
"train"
:
self
.
hparams
.
n_train
,
"val"
:
self
.
hparams
.
n_val
,
"val"
:
self
.
hparams
.
n_val
,
...
...
examples/research_projects/rag-end2end-retriever/utils_rag.py
View file @
5e8c8eb5
...
@@ -137,7 +137,7 @@ logger = getLogger(__name__)
...
@@ -137,7 +137,7 @@ logger = getLogger(__name__)
def
flatten_list
(
summary_ids
:
List
[
List
]):
def
flatten_list
(
summary_ids
:
List
[
List
]):
return
[
x
for
x
in
itertools
.
chain
.
from_iterable
(
summary_ids
)
]
return
list
(
itertools
.
chain
.
from_iterable
(
summary_ids
)
)
def
save_git_info
(
folder_path
:
str
)
->
None
:
def
save_git_info
(
folder_path
:
str
)
->
None
:
...
...
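`[x for x in itertools.chain.from_iterable(...)]` only copies the iterator element by element, so `list(...)` is the direct equivalent (rule C416; the same rewrite recurs in examples/research_projects/rag/utils_rag.py below). A short check:

    import itertools

    summary_ids = [[1, 2], [3], [4, 5]]

    old = [x for x in itertools.chain.from_iterable(summary_ids)]
    new = list(itertools.chain.from_iterable(summary_ids))
    assert old == new == [1, 2, 3, 4, 5]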
examples/research_projects/rag/finetune_rag.py

@@ -162,11 +162,11 @@ class GenerativeQAModule(BaseTransformer):
         self.step_count = 0
         self.metrics = defaultdict(list)

-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=prefix or "",
-        )
+        self.dataset_kwargs: dict = {
+            "data_dir": self.hparams.data_dir,
+            "max_source_length": self.hparams.max_source_length,
+            "prefix": prefix or "",
+        }
         n_observations_per_split = {
             "train": self.hparams.n_train,
             "val": self.hparams.n_val,
examples/research_projects/rag/utils_rag.py

@@ -137,7 +137,7 @@ logger = getLogger(__name__)
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))


 def save_git_info(folder_path: str) -> None:
examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py

@@ -344,7 +344,7 @@ def create_vocabulary_from_data(
         lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
     )

-    vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
+    vocab_dict = {v: k for k, v in enumerate(sorted(vocab_set))}

     # replace white space with delimiter token
     if word_delimiter_token is not None:
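`sorted()` accepts any iterable and always returns a new list, so wrapping the set in `list()` first does nothing (rule C414). A last check with an invented vocabulary:

    vocab_set = {"b", "a", "c"}

    old = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
    new = {v: k for k, v in enumerate(sorted(vocab_set))}
    assert old == new == {"a": 0, "b": 1, "c": 2}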