Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e0bf01d9
Commit
e0bf01d9
authored
Mar 16, 2019
by
Ananya Harsh Jha
Browse files
added hack for mismatched MNLI
parent
4c721c6b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
65 additions
and
1 deletion
+65
-1
examples/run_classifier.py
examples/run_classifier.py
+65
-1
No files found.
examples/run_classifier.py
View file @
e0bf01d9
...
...
@@ -679,7 +679,6 @@ def main():
output_modes
=
{
"cola"
:
"classification"
,
"mnli"
:
"classification"
,
"mnli-mm"
:
"classification"
,
"mrpc"
:
"classification"
,
"sst-2"
:
"classification"
,
"sts-b"
:
"regression"
,
...
...
@@ -930,6 +929,8 @@ def main():
preds
=
preds
[
0
]
if
output_mode
==
"classification"
:
preds
=
np
.
argmax
(
preds
,
axis
=
1
)
elif
output_mode
==
"regression"
:
preds
=
np
.
squeeze
(
preds
)
result
=
compute_metrics
(
task_name
,
preds
,
all_label_ids
.
numpy
())
loss
=
tr_loss
/
nb_tr_steps
if
args
.
do_train
else
None
...
...
@@ -943,6 +944,69 @@ def main():
for
key
in
sorted
(
result
.
keys
()):
logger
.
info
(
" %s = %s"
,
key
,
str
(
result
[
key
]))
writer
.
write
(
"%s = %s
\n
"
%
(
key
,
str
(
result
[
key
])))
# hack for MNLI-MM
if
task_name
==
"mnli"
:
task_name
=
"mnli-mm"
processor
=
processors
[
task_name
]()
eval_examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
eval_features
=
convert_examples_to_features
(
eval_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
)
logger
.
info
(
"***** Running evaluation *****"
)
logger
.
info
(
" Num examples = %d"
,
len
(
eval_examples
))
logger
.
info
(
" Batch size = %d"
,
args
.
eval_batch_size
)
all_input_ids
=
torch
.
tensor
([
f
.
input_ids
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_input_mask
=
torch
.
tensor
([
f
.
input_mask
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_segment_ids
=
torch
.
tensor
([
f
.
segment_ids
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_label_ids
=
torch
.
tensor
([
f
.
label_id
for
f
in
eval_features
],
dtype
=
torch
.
long
)
eval_data
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_label_ids
)
# Run prediction for full data
eval_sampler
=
SequentialSampler
(
eval_data
)
eval_dataloader
=
DataLoader
(
eval_data
,
sampler
=
eval_sampler
,
batch_size
=
args
.
eval_batch_size
)
model
.
eval
()
eval_loss
=
0
nb_eval_steps
=
0
preds
=
[]
for
input_ids
,
input_mask
,
segment_ids
,
label_ids
in
tqdm
(
eval_dataloader
,
desc
=
"Evaluating"
):
input_ids
=
input_ids
.
to
(
device
)
input_mask
=
input_mask
.
to
(
device
)
segment_ids
=
segment_ids
.
to
(
device
)
label_ids
=
label_ids
.
to
(
device
)
with
torch
.
no_grad
():
logits
=
model
(
input_ids
,
segment_ids
,
input_mask
,
labels
=
None
)
loss_fct
=
CrossEntropyLoss
()
tmp_eval_loss
=
loss_fct
(
logits
.
view
(
-
1
,
num_labels
),
label_ids
.
view
(
-
1
))
eval_loss
+=
tmp_eval_loss
.
mean
().
item
()
nb_eval_steps
+=
1
if
len
(
preds
)
==
0
:
preds
.
append
(
logits
.
detach
().
cpu
().
numpy
())
else
:
preds
[
0
]
=
np
.
append
(
preds
[
0
],
logits
.
detach
().
cpu
().
numpy
(),
axis
=
0
)
eval_loss
=
eval_loss
/
nb_eval_steps
preds
=
preds
[
0
]
preds
=
np
.
argmax
(
preds
,
axis
=
1
)
result
=
compute_metrics
(
task_name
,
preds
,
all_label_ids
.
numpy
())
loss
=
tr_loss
/
nb_tr_steps
if
args
.
do_train
else
None
result
[
'eval_loss'
]
=
eval_loss
result
[
'global_step'
]
=
global_step
result
[
'loss'
]
=
loss
output_eval_file
=
os
.
path
.
join
(
args
.
output_dir
+
'-MM'
,
"eval_results.txt"
)
with
open
(
output_eval_file
,
"w"
)
as
writer
:
logger
.
info
(
"***** Eval results *****"
)
for
key
in
sorted
(
result
.
keys
()):
logger
.
info
(
" %s = %s"
,
key
,
str
(
result
[
key
]))
writer
.
write
(
"%s = %s
\n
"
%
(
key
,
str
(
result
[
key
])))
if
__name__
==
"__main__"
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment