Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8cbe7d6a
Commit
8cbe7d6a
authored
Nov 02, 2018
by
VictorSanh
Browse files
FIX errors in loading eval Dataset in `run_squad_pytorch`
parent
833c3a7a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
3 deletions
+5
-3
run_squad_pytorch.py
run_squad_pytorch.py
+5
-3
No files found.
run_squad_pytorch.py
View file @
8cbe7d6a
...
@@ -865,10 +865,11 @@ def main():
...
@@ -865,10 +865,11 @@ def main():
all_input_ids
=
torch
.
tensor
([
f
.
input_ids
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_input_ids
=
torch
.
tensor
([
f
.
input_ids
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_input_mask
=
torch
.
tensor
([
f
.
input_mask
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_input_mask
=
torch
.
tensor
([
f
.
input_mask
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_segment_ids
=
torch
.
tensor
([
f
.
segment_ids
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_segment_ids
=
torch
.
tensor
([
f
.
segment_ids
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_label_ids
=
torch
.
tensor
([
f
.
label_id
for
f
in
eval_features
],
dtype
=
torch
.
long
)
#
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
all_example_index
=
torch
.
arange
(
all_input_ids
.
size
(
0
),
dtype
=
torch
.
long
)
all_example_index
=
torch
.
arange
(
all_input_ids
.
size
(
0
),
dtype
=
torch
.
long
)
eval_data
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_label_ids
,
all_example_index
)
#eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index)
eval_data
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_example_index
)
if
args
.
local_rank
==
-
1
:
if
args
.
local_rank
==
-
1
:
eval_sampler
=
SequentialSampler
(
eval_data
)
eval_sampler
=
SequentialSampler
(
eval_data
)
else
:
else
:
...
@@ -877,7 +878,8 @@ def main():
...
@@ -877,7 +878,8 @@ def main():
model
.
eval
()
model
.
eval
()
all_results
=
[]
all_results
=
[]
for
input_ids
,
input_mask
,
segment_ids
,
label_ids
,
example_index
in
eval_dataloader
:
#for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
for
input_ids
,
input_mask
,
segment_ids
,
example_index
in
eval_dataloader
:
if
len
(
all_results
)
%
1000
==
0
:
if
len
(
all_results
)
%
1000
==
0
:
logger
.
info
(
"Processing example: %d"
%
(
len
(
all_results
)))
logger
.
info
(
"Processing example: %d"
%
(
len
(
all_results
)))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment