Commit 9ddc3f1a authored by LysandreJik's avatar LysandreJik
Browse files

Naming update + XLNet/XLM evaluation

parent de276de1
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate from transformers.data.metrics.squad_metrics import compute_predictions_logits, compute_predictions_log_probs, squad_evaluate
import argparse import argparse
import logging import logging
...@@ -264,13 +264,13 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -264,13 +264,13 @@ def evaluate(args, model, tokenizer, prefix=""):
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ['xlnet', 'xlm']:
# XLNet uses a more complex post-processing procedure # XLNet uses a more complex post-processing procedure
predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size, predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
args.max_answer_length, output_prediction_file, args.max_answer_length, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.predict_file, output_nbest_file, output_null_log_odds_file, args.predict_file,
model.config.start_n_top, model.config.end_n_top, model.config.start_n_top, model.config.end_n_top,
args.version_2_with_negative, tokenizer, args.verbose_logging) args.version_2_with_negative, tokenizer, args.verbose_logging)
else: else:
predictions = compute_predictions(examples, features, all_results, args.n_best_size, predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
args.max_answer_length, args.do_lower_case, output_prediction_file, args.max_answer_length, args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging, output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold) args.version_2_with_negative, args.null_score_diff_threshold)
......
...@@ -125,6 +125,53 @@ def merge_eval(main_eval, new_eval, prefix): ...@@ -125,6 +125,53 @@ def merge_eval(main_eval, new_eval, prefix):
main_eval['%s_%s' % (prefix, k)] = new_eval[k] main_eval['%s_%s' % (prefix, k)] = new_eval[k]
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for i, qid in enumerate(qid_list):
if qid not in scores:
continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if preds[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
has_ans_score, has_ans_cnt = 0, 0
for qid in qid_list:
if not qid_to_has_ans[qid]:
continue
has_ans_cnt += 1
if qid not in scores:
continue
has_ans_score += scores[qid]
return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(
preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(
preds, f1_raw, na_probs, qid_to_has_ans)
main_eval['best_exact'] = best_exact
main_eval['best_exact_thresh'] = exact_thresh
main_eval['best_f1'] = best_f1
main_eval['best_f1_thresh'] = f1_thresh
main_eval['has_ans_exact'] = has_ans_exact
main_eval['has_ans_f1'] = has_ans_f1
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans cur_score = num_no_ans
...@@ -318,10 +365,20 @@ def _compute_softmax(scores): ...@@ -318,10 +365,20 @@ def _compute_softmax(scores):
return probs return probs
def compute_predictions(all_examples, all_features, all_results, n_best_size, def compute_predictions_logits(
max_answer_length, do_lower_case, output_prediction_file, all_examples,
output_nbest_file, output_null_log_odds_file, verbose_logging, all_features,
version_2_with_negative, null_score_diff_threshold): all_results,
n_best_size,
max_answer_length,
do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
verbose_logging,
version_2_with_negative,
null_score_diff_threshold
):
"""Write final predictions to the json file and log-odds of null if needed.""" """Write final predictions to the json file and log-odds of null if needed."""
logger.info("Writing predictions to: %s" % (output_prediction_file)) logger.info("Writing predictions to: %s" % (output_prediction_file))
logger.info("Writing nbest to: %s" % (output_nbest_file)) logger.info("Writing nbest to: %s" % (output_nbest_file))
...@@ -453,7 +510,7 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -453,7 +510,7 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,
# In very rare edge cases we could only have single null prediction. # In very rare edge cases we could only have single null prediction.
# So we just create a nonce prediction in this case to avoid failure. # So we just create a nonce prediction in this case to avoid failure.
if len(nbest)==1: if len(nbest) == 1:
nbest.insert(0, nbest.insert(0,
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
...@@ -512,12 +569,22 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -512,12 +569,22 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,
return all_predictions return all_predictions
def compute_predictions_extended(all_examples, all_features, all_results, n_best_size, def compute_predictions_log_probs(
max_answer_length, output_prediction_file, all_examples,
all_features,
all_results,
n_best_size,
max_answer_length,
output_prediction_file,
output_nbest_file, output_nbest_file,
output_null_log_odds_file, orig_data_file, output_null_log_odds_file,
start_n_top, end_n_top, version_2_with_negative, orig_data_file,
tokenizer, verbose_logging): start_n_top,
end_n_top,
version_2_with_negative,
tokenizer,
verbose_logging
):
""" XLNet write prediction logic (more complex than Bert's). """ XLNet write prediction logic (more complex than Bert's).
Write final predictions to the json file and log-odds of null if needed. Write final predictions to the json file and log-odds of null if needed.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment