Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9ddc3f1a
Commit
9ddc3f1a
authored
Dec 04, 2019
by
LysandreJik
Browse files
Naming update + XLNet/XLM evaluation
parent
de276de1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
85 additions
and
18 deletions
+85
-18
examples/run_squad.py
examples/run_squad.py
+3
-3
transformers/data/metrics/squad_metrics.py
transformers/data/metrics/squad_metrics.py
+82
-15
No files found.
examples/run_squad.py
View file @
9ddc3f1a
...
...
@@ -17,7 +17,7 @@
from
__future__
import
absolute_import
,
division
,
print_function
from
transformers.data.processors.squad
import
SquadV1Processor
,
SquadV2Processor
,
SquadResult
from
transformers.data.metrics.squad_metrics
import
compute_predictions
,
compute_predictions_
extended
,
squad_evaluate
from
transformers.data.metrics.squad_metrics
import
compute_predictions
_logits
,
compute_predictions_
log_probs
,
squad_evaluate
import
argparse
import
logging
...
...
@@ -264,13 +264,13 @@ def evaluate(args, model, tokenizer, prefix=""):
if
args
.
model_type
in
[
'xlnet'
,
'xlm'
]:
# XLNet uses a more complex post-processing procedure
predictions
=
compute_predictions_
extended
(
examples
,
features
,
all_results
,
args
.
n_best_size
,
predictions
=
compute_predictions_
log_probs
(
examples
,
features
,
all_results
,
args
.
n_best_size
,
args
.
max_answer_length
,
output_prediction_file
,
output_nbest_file
,
output_null_log_odds_file
,
args
.
predict_file
,
model
.
config
.
start_n_top
,
model
.
config
.
end_n_top
,
args
.
version_2_with_negative
,
tokenizer
,
args
.
verbose_logging
)
else
:
predictions
=
compute_predictions
(
examples
,
features
,
all_results
,
args
.
n_best_size
,
predictions
=
compute_predictions
_logits
(
examples
,
features
,
all_results
,
args
.
n_best_size
,
args
.
max_answer_length
,
args
.
do_lower_case
,
output_prediction_file
,
output_nbest_file
,
output_null_log_odds_file
,
args
.
verbose_logging
,
args
.
version_2_with_negative
,
args
.
null_score_diff_threshold
)
...
...
transformers/data/metrics/squad_metrics.py
View file @
9ddc3f1a
...
...
@@ -125,6 +125,53 @@ def merge_eval(main_eval, new_eval, prefix):
main_eval
[
'%s_%s'
%
(
prefix
,
k
)]
=
new_eval
[
k
]
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
    """Sweep no-answer thresholds and pick the one maximizing the metric.

    Questions are visited in order of increasing no-answer probability;
    at each step the running score is updated as if that question (and all
    earlier ones) were answered, and the best cutoff seen so far is kept.

    Args:
        preds: dict mapping question id -> predicted answer text ("" = no answer).
        scores: dict mapping question id -> per-question metric (e.g. exact or F1).
        na_probs: dict mapping question id -> model's no-answer probability.
        qid_to_has_ans: dict mapping question id -> bool (gold has an answer).

    Returns:
        A 3-tuple: (best score as a percentage over all of ``scores``,
        the no-answer probability threshold achieving it,
        the mean metric restricted to has-answer questions).
    """
    # Baseline: predict "no answer" everywhere; each truly unanswerable
    # question then counts as correct.
    unanswerable_total = sum(1 for qid in qid_to_has_ans if not qid_to_has_ans[qid])
    running = unanswerable_total
    best_score = running
    best_thresh = 0.0
    # Visit questions from most to least confident that an answer exists.
    ordered_qids = sorted(na_probs, key=na_probs.get)
    for qid in ordered_qids:
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            delta = scores[qid]
        elif preds[qid]:
            # Answered an unanswerable question: lose the baseline point.
            delta = -1
        else:
            delta = 0
        running += delta
        if running > best_score:
            best_score = running
            best_thresh = na_probs[qid]

    # Average the metric over questions that actually have a gold answer.
    answered_sum = 0
    answered_cnt = 0
    for qid in ordered_qids:
        if not qid_to_has_ans[qid]:
            continue
        answered_cnt += 1
        if qid in scores:
            answered_sum += scores[qid]

    return (
        100.0 * best_score / len(scores),
        best_thresh,
        1.0 * answered_sum / answered_cnt,
    )
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    """Fill ``main_eval`` with best-threshold exact-match and F1 statistics.

    Runs the threshold sweep once for the exact-match scores and once for
    the F1 scores, then records both results (best value, the threshold
    achieving it, and the has-answer average) in ``main_eval`` in place.

    Args:
        main_eval: dict of evaluation results, mutated in place.
        preds: dict mapping question id -> predicted answer text.
        exact_raw: dict mapping question id -> exact-match score.
        f1_raw: dict mapping question id -> F1 score.
        na_probs: dict mapping question id -> no-answer probability.
        qid_to_has_ans: dict mapping question id -> bool (gold has an answer).
    """
    best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(
        preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(
        preds, f1_raw, na_probs, qid_to_has_ans)
    # Insertion order matches the historical one-by-one assignments.
    main_eval.update({
        'best_exact': best_exact,
        'best_exact_thresh': exact_thresh,
        'best_f1': best_f1,
        'best_f1_thresh': f1_thresh,
        'has_ans_exact': has_ans_exact,
        'has_ans_f1': has_ans_f1,
    })
def
find_best_thresh
(
preds
,
scores
,
na_probs
,
qid_to_has_ans
):
num_no_ans
=
sum
(
1
for
k
in
qid_to_has_ans
if
not
qid_to_has_ans
[
k
])
cur_score
=
num_no_ans
...
...
@@ -318,10 +365,20 @@ def _compute_softmax(scores):
return
probs
def
compute_predictions
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
max_answer_length
,
do_lower_case
,
output_prediction_file
,
output_nbest_file
,
output_null_log_odds_file
,
verbose_logging
,
version_2_with_negative
,
null_score_diff_threshold
):
def
compute_predictions_logits
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
max_answer_length
,
do_lower_case
,
output_prediction_file
,
output_nbest_file
,
output_null_log_odds_file
,
verbose_logging
,
version_2_with_negative
,
null_score_diff_threshold
):
"""Write final predictions to the json file and log-odds of null if needed."""
logger
.
info
(
"Writing predictions to: %s"
%
(
output_prediction_file
))
logger
.
info
(
"Writing nbest to: %s"
%
(
output_nbest_file
))
...
...
@@ -453,7 +510,7 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,
# In very rare edge cases we could only have single null prediction.
# So we just create a nonce prediction in this case to avoid failure.
if
len
(
nbest
)
==
1
:
if
len
(
nbest
)
==
1
:
nbest
.
insert
(
0
,
_NbestPrediction
(
text
=
"empty"
,
start_logit
=
0.0
,
end_logit
=
0.0
))
...
...
@@ -512,12 +569,22 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,
return
all_predictions
def
compute_predictions_extended
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
max_answer_length
,
output_prediction_file
,
def
compute_predictions_log_probs
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
max_answer_length
,
output_prediction_file
,
output_nbest_file
,
output_null_log_odds_file
,
orig_data_file
,
start_n_top
,
end_n_top
,
version_2_with_negative
,
tokenizer
,
verbose_logging
):
output_null_log_odds_file
,
orig_data_file
,
start_n_top
,
end_n_top
,
version_2_with_negative
,
tokenizer
,
verbose_logging
):
""" XLNet write prediction logic (more complex than Bert's).
Write final predictions to the json file and log-odds of null if needed.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment