Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a1126237
Commit
a1126237
authored
Nov 06, 2018
by
thomwolf
Browse files
clean up logits extraction logic
parent
2a97fe22
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
16 deletions
+13
-16
run_squad.py
run_squad.py
+13
-16
No files found.
run_squad.py
View file @
a1126237
...
@@ -908,7 +908,7 @@ def main():
...
@@ -908,7 +908,7 @@ def main():
model
.
eval
()
model
.
eval
()
all_results
=
[]
all_results
=
[]
logger
.
info
(
"Start evaluating"
)
logger
.
info
(
"Start evaluating"
)
for
input_ids
,
input_mask
,
segment_ids
,
example_ind
ex
in
tqdm
(
eval_dataloader
,
desc
=
"Evaluating"
):
for
input_ids
,
input_mask
,
segment_ids
,
example_ind
ices
in
tqdm
(
eval_dataloader
,
desc
=
"Evaluating"
):
if
len
(
all_results
)
%
1000
==
0
:
if
len
(
all_results
)
%
1000
==
0
:
logger
.
info
(
"Processing example: %d"
%
(
len
(
all_results
)))
logger
.
info
(
"Processing example: %d"
%
(
len
(
all_results
)))
...
@@ -916,21 +916,18 @@ def main():
...
@@ -916,21 +916,18 @@ def main():
input_mask
=
input_mask
.
to
(
device
)
input_mask
=
input_mask
.
to
(
device
)
segment_ids
=
segment_ids
.
to
(
device
)
segment_ids
=
segment_ids
.
to
(
device
)
start_logits
,
end_logits
=
model
(
input_ids
,
segment_ids
,
input_mask
)
with
torch
.
no_grad
():
batch_start_logits
,
batch_end_logits
=
model
(
input_ids
,
segment_ids
,
input_mask
)
unique_id
=
[
int
(
eval_features
[
e
.
item
()].
unique_id
)
for
e
in
example_index
]
start_logits
=
[
x
.
view
(
-
1
).
detach
().
cpu
().
numpy
()
for
x
in
start_logits
]
for
i
,
example_index
in
enumerate
(
example_indices
):
end_logits
=
[
x
.
view
(
-
1
).
detach
().
cpu
().
numpy
()
for
x
in
end_logits
]
start_logits
=
batch_start_logits
[
i
].
detach
().
cpu
().
tolist
()
for
idx
,
i
in
enumerate
(
unique_id
):
end_logits
=
batch_end_logits
[
i
].
detach
().
cpu
().
tolist
()
s
=
[
float
(
x
)
for
x
in
start_logits
[
idx
]]
e
=
[
float
(
x
)
for
x
in
end_logits
[
idx
]]
eval_feature
=
eval_features
[
example_index
.
item
()]
all_results
.
append
(
unique_id
=
int
(
eval_feature
.
unique_id
)
RawResult
(
all_results
.
append
(
RawResult
(
unique_id
=
unique_id
,
unique_id
=
i
,
start_logits
=
start_logits
,
start_logits
=
s
,
end_logits
=
end_logits
))
end_logits
=
e
)
)
output_prediction_file
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions.json"
)
output_prediction_file
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions.json"
)
output_nbest_file
=
os
.
path
.
join
(
args
.
output_dir
,
"nbest_predictions.json"
)
output_nbest_file
=
os
.
path
.
join
(
args
.
output_dir
,
"nbest_predictions.json"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment