Commit de276de1 authored by LysandreJik

Working evaluation

parent c835bc85
@@ -16,7 +16,8 @@
 """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
 from __future__ import absolute_import, division, print_function
-from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor
+from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
+from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate
 import argparse
 import logging
@@ -230,9 +231,11 @@ def evaluate(args, model, tokenizer, prefix=""):
         model.eval()
         batch = tuple(t.to(args.device) for t in batch)
         with torch.no_grad():
-            inputs = {'input_ids': batch[0],
+            inputs = {
+                'input_ids': batch[0],
                 'attention_mask': batch[1]
             }
             if args.model_type != 'distilbert':
                 inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
             example_indices = batch[3]
@@ -244,18 +247,8 @@ def evaluate(args, model, tokenizer, prefix=""):
         for i, example_index in enumerate(example_indices):
             eval_feature = features[example_index.item()]
             unique_id = int(eval_feature.unique_id)
-            if args.model_type in ['xlnet', 'xlm']:
-                # XLNet uses a more complex post-processing procedure
-                result = RawResultExtended(unique_id = unique_id,
-                                           start_top_log_probs = to_list(outputs[0][i]),
-                                           start_top_index = to_list(outputs[1][i]),
-                                           end_top_log_probs = to_list(outputs[2][i]),
-                                           end_top_index = to_list(outputs[3][i]),
-                                           cls_logits = to_list(outputs[4][i]))
-            else:
-                result = RawResult(unique_id = unique_id,
-                                   start_logits = to_list(outputs[0][i]),
-                                   end_logits = to_list(outputs[1][i]))
+            result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])
             all_results.append(result)

     evalTime = timeit.default_timer() - start_time
@@ -271,22 +264,18 @@ def evaluate(args, model, tokenizer, prefix=""):
     if args.model_type in ['xlnet', 'xlm']:
         # XLNet uses a more complex post-processing procedure
-        write_predictions_extended(examples, features, all_results, args.n_best_size,
+        predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size,
                             args.max_answer_length, output_prediction_file,
                             output_nbest_file, output_null_log_odds_file, args.predict_file,
                             model.config.start_n_top, model.config.end_n_top,
                             args.version_2_with_negative, tokenizer, args.verbose_logging)
     else:
-        write_predictions(examples, features, all_results, args.n_best_size,
+        predictions = compute_predictions(examples, features, all_results, args.n_best_size,
                             args.max_answer_length, args.do_lower_case, output_prediction_file,
                             output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                             args.version_2_with_negative, args.null_score_diff_threshold)
-    # Evaluate with the official SQuAD script
-    evaluate_options = EVAL_OPTS(data_file=args.predict_file,
-                                 pred_file=output_prediction_file,
-                                 na_prob_file=output_null_log_odds_file)
-    results = evaluate_on_squad(evaluate_options)
+    results = squad_evaluate(examples, predictions)
     return results

 def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
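A minimal sketch of why the result-collection loop in evaluate() no longer branches on the model type: `outputs` is simply a tuple of tensors (2 for BERT-style heads, 5 for XLNet-style heads), so the same list comprehension covers both. The namedtuple and plain lists below are hypothetical stand-ins for SquadResult and torch tensors, whose exact shapes are not spelled out in this diff.

import collections

FakeSquadResult = collections.namedtuple("FakeSquadResult", ["fields"])  # stand-in for SquadResult

def to_list(tensor_like):
    # In run_squad.py this converts a torch tensor to a Python list;
    # here the "tensors" are already lists, so it is effectively a copy.
    return list(tensor_like)

# Two outputs per example (BERT-style): start_logits, end_logits for a batch of 2.
outputs_bert = ([[0.1, 0.9], [0.3, 0.7]], [[0.8, 0.2], [0.4, 0.6]])
# Five outputs per example (XLNet-style): top log-probs/indices and cls_logits.
outputs_xlnet = tuple([[k, k], [k, k]] for k in range(5))

for outputs in (outputs_bert, outputs_xlnet):
    i, unique_id = 0, 1000000001
    # The same expression works for both model families, which is why the
    # per-model-type branches were dropped from evaluate().
    result = FakeSquadResult([to_list(output[i]) for output in outputs] + [unique_id])
    print(len(result.fields))  # 3 for BERT-style outputs, 6 for XLNet-style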
@@ -306,7 +295,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
         logger.info("Creating features from dataset file at %s", input_file)
         processor = SquadV2Processor()
-        examples = processor.get_dev_examples("examples/squad") if evaluate else processor.get_train_examples("examples/squad")
+        examples = processor.get_dev_examples("examples/squad", only_first=100) if evaluate else processor.get_train_examples("examples/squad")
+
+        # import tensorflow_datasets as tfds
+        # tfds_examples = tfds.load("squad")
+        # examples = SquadV1Processor().get_examples_from_dataset(tfds_examples["validation"])
         features = squad_convert_examples_to_features(
             examples=examples,
             tokenizer=tokenizer,
...
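The commented-out lines above hint at loading SQuAD through tensorflow_datasets instead of the local JSON files. A hedged sketch of that path, assuming `tensorflow_datasets` is installed and that `get_examples_from_dataset` accepts a TFDS split as the comment suggests (neither is actually exercised by this commit):

# Sketch only: requires `pip install tensorflow_datasets` and a transformers
# version exposing SquadV1Processor.get_examples_from_dataset.
import tensorflow_datasets as tfds
from transformers.data.processors.squad import SquadV1Processor

tfds_examples = tfds.load("squad")  # downloads SQuAD v1.1 via TFDS
processor = SquadV1Processor()
examples = processor.get_examples_from_dataset(tfds_examples["validation"])
print(len(examples), examples[0].question_text)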
""" Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was
modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
import json
import logging
import math
import collections
from io import open
from tqdm import tqdm
import string
import re
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize

logger = logging.getLogger(__name__)
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
return re.sub(regex, ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def get_tokens(s):
if not s:
return []
return normalize_answer(s).split()
def compute_exact(a_gold, a_pred):
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
def compute_f1(a_gold, a_pred):
gold_toks = get_tokens(a_gold)
pred_toks = get_tokens(a_pred)
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
if len(gold_toks) == 0 or len(pred_toks) == 0:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
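A quick, runnable sanity check of the answer-matching helpers defined above; the import path mirrors the one the run_squad.py change at the top of this commit uses, and the strings are illustrative only.

from transformers.data.metrics.squad_metrics import compute_exact, compute_f1

print(compute_exact("the Eiffel Tower", "Eiffel Tower!"))  # 1: articles, case and punctuation are stripped
print(compute_f1("Steve Smith", "Mr. Steve"))              # 0.5: one shared token out of two on each side
print(compute_f1("Steve Smith", "Smith Steve"))            # 1.0: token-bag F1 ignores word order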
def get_raw_scores(examples, preds):
"""
Computes the exact and f1 scores from the examples and the model predictions
"""
exact_scores = {}
f1_scores = {}
for example in examples:
qas_id = example.qas_id
gold_answers = [answer['text'] for answer in example.answers if normalize_answer(answer['text'])]
if not gold_answers:
# For unanswerable questions, only correct answer is empty string
gold_answers = ['']
if qas_id not in preds:
print('Missing prediction for %s' % qas_id)
continue
prediction = preds[qas_id]
exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)
return exact_scores, f1_scores
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
new_scores = {}
for qid, s in scores.items():
pred_na = na_probs[qid] > na_prob_thresh
if pred_na:
new_scores[qid] = float(not qid_to_has_ans[qid])
else:
new_scores[qid] = s
return new_scores
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
if not qid_list:
total = len(exact_scores)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores.values()) / total),
('f1', 100.0 * sum(f1_scores.values()) / total),
('total', total),
])
else:
total = len(qid_list)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
('total', total),
])
def merge_eval(main_eval, new_eval, prefix):
for k in new_eval:
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for _, qid in enumerate(qid_list):
if qid not in scores:
continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if preds[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
return 100.0 * best_score / len(scores), best_thresh
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
main_eval['best_exact'] = best_exact
main_eval['best_exact_thresh'] = exact_thresh
main_eval['best_f1'] = best_f1
main_eval['best_f1_thresh'] = f1_thresh
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]
if no_answer_probs is None:
no_answer_probs = {k: 0.0 for k in preds}
exact, f1 = get_raw_scores(examples, preds)
exact_threshold = apply_no_ans_threshold(exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
evaluation = make_eval_dict(exact_threshold, f1_threshold)
if has_answer_qids:
has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
merge_eval(evaluation, has_ans_eval, 'HasAns')
if no_answer_qids:
no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
merge_eval(evaluation, no_ans_eval, 'NoAns')
if no_answer_probs:
find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
return evaluation
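A small end-to-end illustration of `squad_evaluate` above. The `Example` namedtuple is a stand-in for SquadExample (only `.qas_id` and `.answers` are read by the evaluation code), and the question/answer values are made up.

import collections
from transformers.data.metrics.squad_metrics import squad_evaluate

Example = collections.namedtuple("Example", ["qas_id", "answers"])  # stand-in for SquadExample

examples = [
    Example("q1", [{"text": "Denver Broncos"}]),  # answerable question
    Example("q2", []),                            # unanswerable (SQuAD v2 style)
]
preds = {"q1": "Denver Broncos", "q2": ""}

results = squad_evaluate(examples, preds)
print(results["exact"], results["f1"])              # 100.0 100.0 on this toy input
print(results["HasAns_exact"], results["NoAns_exact"])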
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
"""Project the tokenized prediction back to the original text."""
# When we created the data, we kept track of the alignment between original
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
# now `orig_text` contains the span of our original text corresponding to the
# span that we predicted.
#
# However, `orig_text` may contain extra characters that we don't want in
# our prediction.
#
# For example, let's say:
# pred_text = steve smith
# orig_text = Steve Smith's
#
# We don't want to return `orig_text` because it contains the extra "'s".
#
# We don't want to return `pred_text` because it's already been normalized
# (the SQuAD eval script also does punctuation stripping/lower casing but
# our tokenizer does additional normalization like stripping accent
# characters).
#
# What we really want to return is "Steve Smith".
#
# Therefore, we have to apply a semi-complicated alignment heuristic between
# `pred_text` and `orig_text` to get a character-to-character alignment. This
# can fail in certain cases in which case we just return `orig_text`.
def _strip_spaces(text):
ns_chars = []
ns_to_s_map = collections.OrderedDict()
for (i, c) in enumerate(text):
if c == " ":
continue
ns_to_s_map[len(ns_chars)] = i
ns_chars.append(c)
ns_text = "".join(ns_chars)
return (ns_text, ns_to_s_map)
# We first tokenize `orig_text`, strip whitespace from the result
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text))
start_position = tok_text.find(pred_text)
if start_position == -1:
if verbose_logging:
logger.info(
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
return orig_text
end_position = start_position + len(pred_text) - 1
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
if len(orig_ns_text) != len(tok_ns_text):
if verbose_logging:
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
orig_ns_text, tok_ns_text)
return orig_text
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map = {}
for (i, tok_index) in tok_ns_to_s_map.items():
tok_s_to_ns_map[tok_index] = i
orig_start_position = None
if start_position in tok_s_to_ns_map:
ns_start_position = tok_s_to_ns_map[start_position]
if ns_start_position in orig_ns_to_s_map:
orig_start_position = orig_ns_to_s_map[ns_start_position]
if orig_start_position is None:
if verbose_logging:
logger.info("Couldn't map start position")
return orig_text
orig_end_position = None
if end_position in tok_s_to_ns_map:
ns_end_position = tok_s_to_ns_map[end_position]
if ns_end_position in orig_ns_to_s_map:
orig_end_position = orig_ns_to_s_map[ns_end_position]
if orig_end_position is None:
if verbose_logging:
logger.info("Couldn't map end position")
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
return output_text
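The "Steve Smith" scenario from the comments above, made concrete. This is just a usage sketch of the function as defined here; it needs transformers installed for BasicTokenizer.

from transformers.data.metrics.squad_metrics import get_final_text

pred_text = "steve smith"    # normalized, lower-cased model prediction
orig_text = "Steve Smith's"  # raw span recovered from the original document
print(get_final_text(pred_text, orig_text, do_lower_case=True))  # -> "Steve Smith"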
def _get_best_indexes(logits, n_best_size):
"""Get the n-best logits from a list."""
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
best_indexes = []
for i in range(len(index_and_score)):
if i >= n_best_size:
break
best_indexes.append(index_and_score[i][0])
return best_indexes
def _compute_softmax(scores):
"""Compute softmax probability over raw logits."""
if not scores:
return []
max_score = None
for score in scores:
if max_score is None or score > max_score:
max_score = score
exp_scores = []
total_sum = 0.0
for score in scores:
x = math.exp(score - max_score)
exp_scores.append(x)
total_sum += x
probs = []
for score in exp_scores:
probs.append(score / total_sum)
return probs
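A brief, runnable illustration of the two helpers above. They are underscore-prefixed module functions, so importing them here is purely for demonstration; the logits are made up.

from transformers.data.metrics.squad_metrics import _get_best_indexes, _compute_softmax

start_logits = [0.1, 2.3, 0.4, 1.7]
print(_get_best_indexes(start_logits, n_best_size=2))  # [1, 3]: positions of the two largest logits
probs = _compute_softmax([2.3, 1.7])
print([round(p, 3) for p in probs])                    # [0.646, 0.354]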
def compute_predictions(all_examples, all_features, all_results, n_best_size,
                        max_answer_length, do_lower_case, output_prediction_file,
                        output_nbest_file, output_null_log_odds_file, verbose_logging,
@@ -204,132 +512,192 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,
    return all_predictions

-   (removed here: the previous copies of get_final_text, _get_best_indexes and
-    _compute_softmax, which now live above compute_predictions and are identical
-    to the definitions shown earlier in this file)

+def compute_predictions_extended(all_examples, all_features, all_results, n_best_size,
+                                 max_answer_length, output_prediction_file,
+                                 output_nbest_file,
+                                 output_null_log_odds_file, orig_data_file,
+                                 start_n_top, end_n_top, version_2_with_negative,
+                                 tokenizer, verbose_logging):
+    """ XLNet write prediction logic (more complex than Bert's).
+        Write final predictions to the json file and log-odds of null if needed.
+
+        Requires utils_squad_evaluate.py
+    """
+    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
+        "PrelimPrediction",
+        ["feature_index", "start_index", "end_index",
+         "start_log_prob", "end_log_prob"])
+
+    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
+        "NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
+
+    logger.info("Writing predictions to: %s", output_prediction_file)
+    # logger.info("Writing nbest to: %s" % (output_nbest_file))
+
+    example_index_to_features = collections.defaultdict(list)
+    for feature in all_features:
+        example_index_to_features[feature.example_index].append(feature)
+
+    unique_id_to_result = {}
+    for result in all_results:
+        unique_id_to_result[result.unique_id] = result
+
+    all_predictions = collections.OrderedDict()
+    all_nbest_json = collections.OrderedDict()
+    scores_diff_json = collections.OrderedDict()
+
+    for (example_index, example) in enumerate(all_examples):
+        features = example_index_to_features[example_index]
+
+        prelim_predictions = []
+        # keep track of the minimum score of null start+end of position 0
+        score_null = 1000000  # large and positive
+
+        for (feature_index, feature) in enumerate(features):
+            result = unique_id_to_result[feature.unique_id]
+
+            cur_null_score = result.cls_logits
+
+            # if we could have irrelevant answers, get the min score of irrelevant
+            score_null = min(score_null, cur_null_score)
+
+            for i in range(start_n_top):
+                for j in range(end_n_top):
+                    start_log_prob = result.start_top_log_probs[i]
+                    start_index = result.start_top_index[i]
+
+                    j_index = i * end_n_top + j
+
+                    end_log_prob = result.end_top_log_probs[j_index]
+                    end_index = result.end_top_index[j_index]
+
+                    # We could hypothetically create invalid predictions, e.g., predict
+                    # that the start of the span is in the question. We throw out all
+                    # invalid predictions.
+                    if start_index >= feature.paragraph_len - 1:
+                        continue
+                    if end_index >= feature.paragraph_len - 1:
+                        continue
+
+                    if not feature.token_is_max_context.get(start_index, False):
+                        continue
+                    if end_index < start_index:
+                        continue
+                    length = end_index - start_index + 1
+                    if length > max_answer_length:
+                        continue
+
+                    prelim_predictions.append(
+                        _PrelimPrediction(
+                            feature_index=feature_index,
+                            start_index=start_index,
+                            end_index=end_index,
+                            start_log_prob=start_log_prob,
+                            end_log_prob=end_log_prob))
+
+        prelim_predictions = sorted(
+            prelim_predictions,
+            key=lambda x: (x.start_log_prob + x.end_log_prob),
+            reverse=True)
+
+        seen_predictions = {}
+        nbest = []
+        for pred in prelim_predictions:
+            if len(nbest) >= n_best_size:
+                break
+            feature = features[pred.feature_index]
+
+            # XLNet un-tokenizer
+            # Let's keep it simple for now and see if we need all this later.
+            #
+            # tok_start_to_orig_index = feature.tok_start_to_orig_index
+            # tok_end_to_orig_index = feature.tok_end_to_orig_index
+            # start_orig_pos = tok_start_to_orig_index[pred.start_index]
+            # end_orig_pos = tok_end_to_orig_index[pred.end_index]
+            # paragraph_text = example.paragraph_text
+            # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()

+            # Previously used Bert untokenizer
+            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
+            orig_doc_start = feature.token_to_orig_map[pred.start_index]
+            orig_doc_end = feature.token_to_orig_map[pred.end_index]
+            orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
+            tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+
+            # Clean whitespace
+            tok_text = tok_text.strip()
+            tok_text = " ".join(tok_text.split())
+            orig_text = " ".join(orig_tokens)
+
+            final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
+                                        verbose_logging)
+
+            if final_text in seen_predictions:
+                continue
+
+            seen_predictions[final_text] = True
+
+            nbest.append(
+                _NbestPrediction(
+                    text=final_text,
+                    start_log_prob=pred.start_log_prob,
+                    end_log_prob=pred.end_log_prob))
+
+        # In very rare edge cases we could have no valid predictions. So we
+        # just create a nonce prediction in this case to avoid failure.
+        if not nbest:
+            nbest.append(
+                _NbestPrediction(text="", start_log_prob=-1e6,
+                                 end_log_prob=-1e6))
+
+        total_scores = []
+        best_non_null_entry = None
+        for entry in nbest:
+            total_scores.append(entry.start_log_prob + entry.end_log_prob)
+            if not best_non_null_entry:
+                best_non_null_entry = entry
+
+        probs = _compute_softmax(total_scores)
+
+        nbest_json = []
+        for (i, entry) in enumerate(nbest):
+            output = collections.OrderedDict()
+            output["text"] = entry.text
+            output["probability"] = probs[i]
+            output["start_log_prob"] = entry.start_log_prob
+            output["end_log_prob"] = entry.end_log_prob
+            nbest_json.append(output)
+
+        assert len(nbest_json) >= 1
+        assert best_non_null_entry is not None
+
+        score_diff = score_null
+        scores_diff_json[example.qas_id] = score_diff
+        # note(zhiliny): always predict best_non_null_entry
+        # and the evaluation script will search for the best threshold
+        all_predictions[example.qas_id] = best_non_null_entry.text
+
+        all_nbest_json[example.qas_id] = nbest_json
+
+    with open(output_prediction_file, "w") as writer:
+        writer.write(json.dumps(all_predictions, indent=4) + "\n")
+
+    with open(output_nbest_file, "w") as writer:
+        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
+
+    if version_2_with_negative:
+        with open(output_null_log_odds_file, "w") as writer:
+            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+
+    with open(orig_data_file, "r", encoding='utf-8') as reader:
+        orig_data = json.load(reader)["data"]
+
+    qid_to_has_ans = make_qid_to_has_ans(orig_data)
+    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
+    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
+    exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
+    out_eval = {}
+
+    find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw, scores_diff_json, qid_to_has_ans)
+
+    return out_eval
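The nested start_n_top/end_n_top loop above pairs each top start candidate with its own block of top end candidates via `j_index = i * end_n_top + j`, because XLNet returns the end candidates conditioned on each start candidate as one flat list. A tiny runnable illustration of that flattening (all numbers are made up):

start_n_top, end_n_top = 2, 3

# Flat list of end candidates: the first end_n_top entries belong to start
# candidate 0, the next end_n_top entries to start candidate 1, and so on.
end_top_index = [17, 25, 40,   # conditioned on the best start
                 18, 26, 41]   # conditioned on the second-best start

for i in range(start_n_top):
    for j in range(end_n_top):
        j_index = i * end_n_top + j
        print(i, j, "->", end_top_index[j_index])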
@@ -306,13 +306,13 @@ class SquadProcessor(DataProcessor):
                 else:
                     is_impossible = False

-                if not is_impossible and is_training:
-                    if (len(qa["answers"]) != 1):
-                        raise ValueError(
-                            "For training, each question should have exactly 1 answer.")
+                if not is_impossible:
+                    if is_training:
                         answer = qa["answers"][0]
                         answer_text = answer['text']
                         start_position_character = answer['answer_start']
+                    else:
+                        answers = qa["answers"]

                 example = SquadExample(
                     qas_id=qas_id,
@@ -321,7 +321,8 @@ class SquadProcessor(DataProcessor):
                     answer_text=answer_text,
                     start_position_character=start_position_character,
                     title=title,
-                    is_impossible=is_impossible
+                    is_impossible=is_impossible,
+                    answers=answers
                 )
                 examples.append(example)
@@ -352,6 +353,7 @@ class SquadExample(object):
                  answer_text,
                  start_position_character,
                  title,
+                 answers=None,
                  is_impossible=False):
         self.qas_id = qas_id
         self.question_text = question_text
@@ -359,6 +361,7 @@ class SquadExample(object):
         self.answer_text = answer_text
         self.title = title
         self.is_impossible = is_impossible
+        self.answers = answers
         self.start_position, self.end_position = 0, 0
...
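A hedged sketch of the new `answers` field on SquadExample, constructed directly. The `context_text` parameter name is an assumption here (that part of the constructor is hidden in the collapsed diff); the remaining keyword arguments mirror the ones visible above, and all values are toy data.

from transformers.data.processors.squad import SquadExample

example = SquadExample(
    qas_id="toy-0001",
    question_text="Who wrote Hamlet?",
    context_text="Hamlet is a tragedy written by William Shakespeare.",  # assumed parameter name
    answer_text="William Shakespeare",
    start_position_character=31,
    title="Hamlet",
    answers=[{"text": "William Shakespeare", "answer_start": 31}],
    is_impossible=False,
)
print(example.answers)  # the raw gold answers are now kept on the example for evaluation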