chenpangpang / transformers / Commits / 5e8c8eb5

Unverified commit 5e8c8eb5, authored Feb 22, 2023 by Aaron Gokaslan, committed by GitHub on Feb 22, 2023
Apply ruff flake8-comprehensions (#21694)
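
This commit applies ruff's flake8-comprehensions checks, which rewrite unnecessary dict()/set()/list()/tuple()/map() calls into literals and comprehensions. The following is a minimal sketch of the before/after patterns that recur in the hunks below; the names and values are invented for illustration and are not code copied from the repository:

# Hypothetical data used only for this illustration.
result = {"acc": 0.9, "f1": 0.8}
global_step = 100

# dict() over a generator of (key, value) pairs -> dict comprehension
before = dict((k + "_{}".format(global_step), v) for k, v in result.items())
after = {k + "_{}".format(global_step): v for k, v in result.items()}
assert before == after

# set() around a list comprehension -> set comprehension
before = set([k.upper() for k in result])
after = {k.upper() for k in result}
assert before == after

# list(map(lambda ...)) -> list comprehension
before = list(map(lambda k: len(k), result))
after = [len(k) for k in result]
assert before == after

# tuple([x]) -> (x,) and [v for v in d.values()] -> list(d.values())
before = tuple([global_step]) + tuple([v for v in result.values()])
after = (global_step,) + tuple(result.values())
assert before == after

Each rewrite preserves behavior; only the spelling of the collection changes, which is why the diff below is a mechanical one-for-one substitution.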
Parent: df06fb1f

Changes: 230. This page shows 20 changed files with 58 additions and 57 deletions (+58, -57).
examples/research_projects/deebert/run_glue_deebert.py  +3 -3
examples/research_projects/distillation/grouped_batch_sampler.py  +1 -1
examples/research_projects/distillation/run_squad_w_distillation.py  +3 -3
examples/research_projects/jax-projects/big_bird/bigbird_flax.py  +6 -3
examples/research_projects/jax-projects/big_bird/evaluate.py  +3 -3
examples/research_projects/jax-projects/big_bird/prepare_natural_questions.py  +6 -6
examples/research_projects/jax-projects/model_parallel/partitions.py  +1 -1
examples/research_projects/longform-qa/eli5_utils.py  +6 -6
examples/research_projects/lxmert/extracting_data.py  +1 -1
examples/research_projects/lxmert/modeling_frcnn.py  +1 -1
examples/research_projects/mm-imdb/run_mmimdb.py  +3 -3
examples/research_projects/movement-pruning/masked_run_glue.py  +3 -3
examples/research_projects/movement-pruning/masked_run_squad.py  +3 -3
examples/research_projects/onnx/summarization/bart_onnx/reduce_onnx_size.py  +3 -3
examples/research_projects/pplm/run_pplm.py  +2 -4
examples/research_projects/rag-end2end-retriever/finetune_rag.py  +5 -5
examples/research_projects/rag-end2end-retriever/utils_rag.py  +1 -1
examples/research_projects/rag/finetune_rag.py  +5 -5
examples/research_projects/rag/utils_rag.py  +1 -1
examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py  +1 -1
examples/research_projects/deebert/run_glue_deebert.py  (View file @ 5e8c8eb5)

@@ -685,9 +685,9 @@ def main():
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:

@@ -725,7 +725,7 @@ def main():
                 for i in range(model.num_layers):
                     info_str += " {:.2f}".format(100 * each_layer_results[i])
                 logger.info(info_str)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
             results.update(result)

     return results
examples/research_projects/distillation/grouped_batch_sampler.py  (View file @ 5e8c8eb5)

@@ -27,7 +27,7 @@ from utils import logger
 def _quantize(x, bins):
     bins = copy.deepcopy(bins)
     bins = sorted(bins)
-    quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
+    quantized = [bisect.bisect_right(bins, y) for y in x]
     return quantized
examples/research_projects/distillation/run_squad_w_distillation.py  (View file @ 5e8c8eb5)

@@ -850,9 +850,9 @@ def main():
         logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)

@@ -865,7 +865,7 @@ def main():
             # Evaluate
             result = evaluate(args, model, tokenizer, prefix=global_step)
-            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+            result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
             results.update(result)

     logger.info("Results: {}".format(results))
examples/research_projects/jax-projects/big_bird/bigbird_flax.py  (View file @ 5e8c8eb5)

@@ -247,9 +247,12 @@ class Trainer:
                 lr = self.scheduler_fn(state_step - 1)

                 eval_loss = self.evaluate(state, val_dataset)
-                logging_dict = dict(
-                    step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
-                )
+                logging_dict = {
+                    "step": state_step.item(),
+                    "eval_loss": eval_loss.item(),
+                    "tr_loss": tr_loss,
+                    "lr": lr.item(),
+                }
                 tqdm.write(str(logging_dict))
                 self.logger.log(logging_dict, commit=True)
examples/research_projects/jax-projects/big_bird/evaluate.py  (View file @ 5e8c8eb5)

@@ -144,9 +144,9 @@ def main():
     predictions = expand_to_aliases(example["output"])

     # some preprocessing to both prediction and answer
-    answers = set(["".join(a.split()) for a in answers])
-    predictions = set(["".join(p.split()) for p in predictions])
-    predictions = set([s for s in predictions if s not in ["``", "''", "`", "'"]])
+    answers = {"".join(a.split()) for a in answers}
+    predictions = {"".join(p.split()) for p in predictions}
+    predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}

     # if there is a common element, it's a exact match
     example["match"] = len(list(answers & predictions)) > 0
examples/research_projects/jax-projects/big_bird/prepare_natural_questions.py  (View file @ 5e8c8eb5)

@@ -314,12 +314,12 @@ if __name__ == "__main__":
     data = data["train" if PROCESS_TRAIN == "true" else "validation"]

-    fn_kwargs = dict(
-        tokenizer=tokenizer,
-        doc_stride=DOC_STRIDE,
-        max_length=MAX_LENGTH,
-        assertion=False,
-    )
+    fn_kwargs = {
+        "tokenizer": tokenizer,
+        "doc_stride": DOC_STRIDE,
+        "max_length": MAX_LENGTH,
+        "assertion": False,
+    }
     data = data.map(prepare_inputs, fn_kwargs=fn_kwargs)
     data = data.remove_columns(["annotations", "document", "id", "question"])
     print(data)
examples/research_projects/jax-projects/model_parallel/partitions.py  (View file @ 5e8c8eb5)

@@ -34,7 +34,7 @@ empty_dict = object()
 def _match(qs, ks):
     """Return True if regexes in qs match any window of strings in tuple ks."""
     # compile regexes and force complete match
-    qts = tuple(map(lambda x: re.compile(x + "$"), qs))
+    qts = tuple((re.compile(x + "$") for x in qs))
     for i in range(len(ks) - len(qs) + 1):
         matches = [x.match(y) for x, y in zip(qts, ks[i:])]
         if matches and all(matches):
examples/research_projects/longform-qa/eli5_utils.py  (View file @ 5e8c8eb5)

@@ -78,7 +78,7 @@ def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_1
     )
     hits = response["hits"]["hits"]
     support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
-    res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
+    res_list = [{k: hit["_source"][k] for k in hit["_source"] if k != "passage_text"} for hit in hits]
     for r, hit in zip(res_list, hits):
         r["passage_id"] = hit["_id"]
         r["score"] = hit["_score"]

@@ -601,7 +601,7 @@ def make_qa_dense_index(
     fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
     n_batches = math.ceil(passages_dset.num_rows / batch_size)
     for i in range(n_batches):
-        passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
+        passages = list(passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"])
         reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
         fp[i * batch_size : (i + 1) * batch_size] = reps
         if i % 50 == 0:

@@ -634,7 +634,7 @@ def query_qa_dense_index(
     D, I = wiki_index.search(q_rep, 2 * n_results)
     res_passages = [wiki_passages[int(i)] for i in I[0]]
     support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
-    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+    res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
     res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
     for r, sc in zip(res_list, D[0]):
         r["score"] = float(sc)

@@ -650,7 +650,7 @@ def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages,
     ]
     all_res_lists = []
     for res_passages, dl in zip(res_passages_lst, D):
-        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+        res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
         for r, sc in zip(res_list, dl):
             r["score"] = float(sc)
         all_res_lists += [res_list[:]]

@@ -663,7 +663,7 @@ def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki
     D, I = wiki_index.search(a_rep, 2 * n_results)
     res_passages = [wiki_passages[int(i)] for i in I[0]]
     support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
-    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+    res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
     res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
     for r, sc, i in zip(res_list, D[0], I[0]):
         r["passage_id"] = int(i)

@@ -680,7 +680,7 @@ def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passage
     ]
     all_res_lists = []
     for res_passages, dl, il in zip(res_passages_lst, D, I):
-        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+        res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
         for r, sc, i in zip(res_list, dl, il):
             r["passage_id"] = int(i)
             r["score"] = float(sc)
examples/research_projects/lxmert/extracting_data.py  (View file @ 5e8c8eb5)

@@ -61,7 +61,7 @@ class Extract:
         assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
         if subset_list is not None:
             with open(os.path.realpath(subset_list)) as f:
-                self.subset_list = set(map(lambda x: self._vqa_file_split()[0], tryload(f)))
+                self.subset_list = {self._vqa_file_split()[0] for x in tryload(f)}
         else:
             self.subset_list = None
examples/research_projects/lxmert/modeling_frcnn.py  (View file @ 5e8c8eb5)

@@ -1095,7 +1095,7 @@ class ROIPooler(nn.Module):
         Returns:
             A tensor of shape(N*B, Channels, output_size, output_size)
         """
-        x = [v for v in feature_maps.values()]
+        x = list(feature_maps.values())
         num_level_assignments = len(self.level_poolers)
         assert len(x) == num_level_assignments and len(boxes) == x[0].size(0)
examples/research_projects/mm-imdb/run_mmimdb.py  (View file @ 5e8c8eb5)

@@ -554,9 +554,9 @@ def main():
     if args.do_eval and args.local_rank in [-1, 0]:
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:

@@ -566,7 +566,7 @@ def main():
             model.load_state_dict(torch.load(checkpoint))
             model.to(args.device)
             result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
             results.update(result)

     return results
examples/research_projects/movement-pruning/masked_run_glue.py  (View file @ 5e8c8eb5)

@@ -941,9 +941,9 @@ def main():
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:

@@ -953,7 +953,7 @@ def main():
             model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
             result = evaluate(args, model, tokenizer, prefix=prefix)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
             results.update(result)

     return results
examples/research_projects/movement-pruning/masked_run_squad.py  (View file @ 5e8c8eb5)

@@ -1109,10 +1109,10 @@ def main():
             logger.info("Loading checkpoints saved during training for evaluation")
             checkpoints = [args.output_dir]
             if args.eval_all_checkpoints:
-                checkpoints = list(
+                checkpoints = [
                     os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-                )
+                ]
         else:
             logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)

@@ -1129,7 +1129,7 @@ def main():
             # Evaluate
             result = evaluate(args, model, tokenizer, prefix=global_step)
-            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+            result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
             results.update(result)

     logger.info("Results: {}".format(results))
examples/research_projects/onnx/summarization/bart_onnx/reduce_onnx_size.py  (View file @ 5e8c8eb5)

@@ -42,8 +42,8 @@ def _graph_replace_input_with(graph_proto, name, new_name):
 def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
-    inits_with_data = [i for i in model.graph.initializer]
-    inits = [i for i in model_without_ext.graph.initializer]
+    inits_with_data = list(model.graph.initializer)
+    inits = list(model_without_ext.graph.initializer)
     for i, ref_i in ind_to_replace:
         assert inits_with_data[i].name == inits[i].name
         assert inits_with_data[ref_i].name == inits[ref_i].name

@@ -69,7 +69,7 @@ def remove_dup_initializers(onnx_file_path):
     model = onnx.load(os.path.join(model_file_folder, model_file_name))
-    inits = [i for i in model.graph.initializer]
+    inits = list(model.graph.initializer)
     dup_set = set()
     dup_map = {}
examples/research_projects/pplm/run_pplm.py  (View file @ 5e8c8eb5)

@@ -127,11 +127,9 @@ def perturb_past(
     _, _, _, curr_length, _ = past[0].shape

     if curr_length > window_length and window_length > 0:
-        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
+        ones_key_val_shape = tuple(past[0].shape[:-2]) + (window_length,) + tuple(past[0].shape[-1:])

-        zeros_key_val_shape = (
-            tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
-        )
+        zeros_key_val_shape = tuple(past[0].shape[:-2]) + (curr_length - window_length,) + tuple(past[0].shape[-1:])

         ones_mask = torch.ones(ones_key_val_shape)
         ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
examples/research_projects/rag-end2end-retriever/finetune_rag.py  (View file @ 5e8c8eb5)

@@ -164,11 +164,11 @@ class GenerativeQAModule(BaseTransformer):
         self.step_count = 0
         self.metrics = defaultdict(list)

-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=prefix or "",
-        )
+        self.dataset_kwargs: dict = {
+            "data_dir": self.hparams.data_dir,
+            "max_source_length": self.hparams.max_source_length,
+            "prefix": prefix or "",
+        }
         n_observations_per_split = {
             "train": self.hparams.n_train,
             "val": self.hparams.n_val,
examples/research_projects/rag-end2end-retriever/utils_rag.py  (View file @ 5e8c8eb5)

@@ -137,7 +137,7 @@ logger = getLogger(__name__)
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))


 def save_git_info(folder_path: str) -> None:
examples/research_projects/rag/finetune_rag.py  (View file @ 5e8c8eb5)

@@ -162,11 +162,11 @@ class GenerativeQAModule(BaseTransformer):
         self.step_count = 0
         self.metrics = defaultdict(list)

-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=prefix or "",
-        )
+        self.dataset_kwargs: dict = {
+            "data_dir": self.hparams.data_dir,
+            "max_source_length": self.hparams.max_source_length,
+            "prefix": prefix or "",
+        }
         n_observations_per_split = {
             "train": self.hparams.n_train,
             "val": self.hparams.n_val,
examples/research_projects/rag/utils_rag.py  (View file @ 5e8c8eb5)

@@ -137,7 +137,7 @@ logger = getLogger(__name__)
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))


 def save_git_info(folder_path: str) -> None:
examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py  (View file @ 5e8c8eb5)

@@ -344,7 +344,7 @@ def create_vocabulary_from_data(
         lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
     )

-    vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
+    vocab_dict = {v: k for k, v in enumerate(sorted(vocab_set))}

     # replace white space with delimiter token
     if word_delimiter_token is not None: