Commit 0265f59c authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 303225001
parent 1bd89dac
@@ -19,9 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 import json
 import os
 import tempfile
+import time
 
 from absl import app
 from absl import flags
 from absl import logging
@@ -126,24 +127,22 @@ def main(_):
   if 'predict' in FLAGS.mode:
     predict_squad(strategy, input_meta_data)
   if 'eval' in FLAGS.mode:
-    if input_meta_data.get('version_2_with_negative', False):
-      logging.error('SQuAD v2 eval is not supported. '
-                    'Falling back to predict mode.')
-      predict_squad(strategy, input_meta_data)
-    else:
-      eval_metrics = eval_squad(strategy, input_meta_data)
-      f1_score = eval_metrics['f1']
-      logging.info('SQuAD eval F1-score: %f', f1_score)
-      if (not strategy) or strategy.extended.should_save_summary:
-        summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
-      else:
-        summary_dir = tempfile.mkdtemp()
-      summary_writer = tf.summary.create_file_writer(
-          os.path.join(summary_dir, 'eval'))
-      with summary_writer.as_default():
-        # TODO(lehou): write to the correct step number.
-        tf.summary.scalar('F1-score', f1_score, step=0)
-        summary_writer.flush()
+    eval_metrics = eval_squad(strategy, input_meta_data)
+    f1_score = eval_metrics['final_f1']
+    logging.info('SQuAD eval F1-score: %f', f1_score)
+    if (not strategy) or strategy.extended.should_save_summary:
+      summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
+    else:
+      summary_dir = tempfile.mkdtemp()
+    summary_writer = tf.summary.create_file_writer(
+        os.path.join(summary_dir, 'eval'))
+    with summary_writer.as_default():
+      # TODO(lehou): write to the correct step number.
+      tf.summary.scalar('F1-score', f1_score, step=0)
+      summary_writer.flush()
+    # Wait for some time, for the depending mldash/tensorboard jobs to finish
+    # exporting the final F1-score.
+    time.sleep(60)
 
 
 if __name__ == '__main__':
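The rewritten eval branch always computes metrics in memory and exports the F1-score as a TF2 scalar summary before sleeping. Below is a minimal sketch, not part of this commit, of how a downstream job could read that scalar back from the event files; read_final_f1 is a hypothetical helper, and the summaries/eval layout follows the diff above.

# Hypothetical helper (not in the commit): read the exported 'F1-score' scalar
# back from the event files written under <model_dir>/summaries/eval.
import glob
import os

import tensorflow as tf


def read_final_f1(model_dir):
  pattern = os.path.join(model_dir, 'summaries', 'eval', 'events.out.tfevents.*')
  f1 = None
  for event_file in sorted(glob.glob(pattern)):
    for event in tf.compat.v1.train.summary_iterator(event_file):
      for value in event.summary.value:
        if value.tag == 'F1-score':
          # TF2 scalar summaries are stored as tensor protos.
          f1 = float(tf.make_ndarray(value.tensor))
  return f1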
@@ -31,6 +31,7 @@ from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
 from official.nlp.bert import model_training_utils
 from official.nlp.bert import squad_evaluate_v1_1
+from official.nlp.bert import squad_evaluate_v2_0
 from official.nlp.data import squad_lib_sp
 from official.utils.misc import keras_utils
@@ -373,16 +374,16 @@ def eval_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
   dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
                 input_meta_data.get('version_2_with_negative', False))
 
+  with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
+    dataset_json = json.load(reader)
+    pred_dataset = dataset_json['data']
   if input_meta_data.get('version_2_with_negative', False):
-    # TODO(lehou): support in memory evaluation for SQuAD v2.
-    logging.error('SQuAD v2 eval is not supported. Skipping eval')
-    return None
+    eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset,
+                                                all_predictions,
+                                                scores_diff_json)
   else:
-    with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
-      dataset_json = json.load(reader)
-      pred_dataset = dataset_json['data']
     eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
-    return eval_metrics
+  return eval_metrics
 
 
 def export_squad(model_export_path, input_meta_data, bert_config):
@@ -105,4 +105,4 @@ def evaluate(dataset, predictions):
   exact_match = exact_match / total
   f1 = f1 / total
 
-  return {"exact_match": exact_match, "f1": f1}
+  return {"exact_match": exact_match, "final_f1": f1}
new file: official/nlp/bert/squad_evaluate_v2_0.py

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation script for SQuAD version 2.0.
The functions are copied and modified from
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import string

from absl import logging


def _make_qid_to_has_ans(dataset):
  qid_to_has_ans = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid_to_has_ans[qa['id']] = bool(qa['answers'])
  return qid_to_has_ans


def _normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)

  def white_space_fix(text):
    return ' '.join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _get_tokens(s):
  if not s: return []
  return _normalize_answer(s).split()


def _compute_exact(a_gold, a_pred):
  return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))


def _compute_f1(a_gold, a_pred):
  """Compute F1-score."""
  gold_toks = _get_tokens(a_gold)
  pred_toks = _get_tokens(a_pred)
  common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(common.values())
  if not gold_toks or not pred_toks:
    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
    return int(gold_toks == pred_toks)
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def _get_raw_scores(dataset, predictions):
  """Compute raw scores."""
  exact_scores = {}
  f1_scores = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid = qa['id']
        gold_answers = [a['text'] for a in qa['answers']
                        if _normalize_answer(a['text'])]
        if not gold_answers:
          # For unanswerable questions, the only correct answer is the empty
          # string.
          gold_answers = ['']
        if qid not in predictions:
          logging.error('Missing prediction for %s', qid)
          continue
        a_pred = predictions[qid]
        # Take max over all gold answers
        exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
        f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
  return exact_scores, f1_scores


def _apply_no_ans_threshold(
    scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
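  # Added comment (not in the original file): predictions whose no-answer
  # probability exceeds the threshold are treated as abstentions; they score
  # 1.0 only when the question truly has no answer, 0.0 otherwise.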
  new_scores = {}
  for qid, s in scores.items():
    pred_na = na_probs[qid] > na_prob_thresh
    if pred_na:
      new_scores[qid] = float(not qid_to_has_ans[qid])
    else:
      new_scores[qid] = s
  return new_scores


def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
  """Make evaluation result dictionary."""
  if not qid_list:
    total = len(exact_scores)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores.values()) / total),
        ('f1', 100.0 * sum(f1_scores.values()) / total),
        ('total', total),
    ])
  else:
    total = len(qid_list)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
        ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
        ('total', total),
    ])


def _merge_eval(main_eval, new_eval, prefix):
  for k in new_eval:
    main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
  """Make evaluation dictionary containing average precision/recall."""
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  true_pos = 0.0
  cur_p = 1.0
  cur_r = 0.0
  precisions = [1.0]
  recalls = [0.0]
  avg_prec = 0.0
  for i, qid in enumerate(qid_list):
    if qid_to_has_ans[qid]:
      true_pos += scores[qid]
    cur_p = true_pos / float(i+1)
    cur_r = true_pos / float(num_true_pos)
    if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
      # i.e., if we can put a threshold after this point
      avg_prec += cur_p * (cur_r - recalls[-1])
      precisions.append(cur_p)
      recalls.append(cur_r)
  return {'ap': 100.0 * avg_prec}


def _run_precision_recall_analysis(
    main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  """Run precision recall analysis and merge results into main_eval."""
  num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
  if num_true_pos == 0:
    return
  pr_exact = _make_precision_recall_eval(
      exact_raw, na_probs, num_true_pos, qid_to_has_ans)
  pr_f1 = _make_precision_recall_eval(
      f1_raw, na_probs, num_true_pos, qid_to_has_ans)
  oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
  pr_oracle = _make_precision_recall_eval(
      oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
  _merge_eval(main_eval, pr_exact, 'pr_exact')
  _merge_eval(main_eval, pr_f1, 'pr_f1')
  _merge_eval(main_eval, pr_oracle, 'pr_oracle')


def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
  """Find the best threshold for no answer probability."""
  num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  cur_score = num_no_ans
  best_score = cur_score
  best_thresh = 0.0
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  for qid in qid_list:
    if qid not in scores: continue
    if qid_to_has_ans[qid]:
      diff = scores[qid]
    else:
      if predictions[qid]:
        diff = -1
      else:
        diff = 0
    cur_score += diff
    if cur_score > best_score:
      best_score = cur_score
      best_thresh = na_probs[qid]
  return 100.0 * best_score / len(scores), best_thresh


def _find_all_best_thresh(
    main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  best_exact, exact_thresh = _find_best_thresh(
      predictions, exact_raw, na_probs, qid_to_has_ans)
  best_f1, f1_thresh = _find_best_thresh(
      predictions, f1_raw, na_probs, qid_to_has_ans)
  main_eval['final_exact'] = best_exact
  main_eval['final_exact_thresh'] = exact_thresh
  main_eval['final_f1'] = best_f1
  main_eval['final_f1_thresh'] = f1_thresh


def evaluate(dataset, predictions, na_probs=None):
  """Evaluate prediction results."""
  new_orig_data = []
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        if qa['id'] in predictions:
          new_para = {'qas': [qa]}
          new_article = {'paragraphs': [new_para]}
          new_orig_data.append(new_article)
  dataset = new_orig_data

  if na_probs is None:
    na_probs = {k: 0.0 for k in predictions}
  qid_to_has_ans = _make_qid_to_has_ans(dataset)  # maps qid to True/False
  has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
  exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
  exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
  f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
  out_eval = _make_eval_dict(exact_thresh, f1_thresh)
  if has_ans_qids:
    has_ans_eval = _make_eval_dict(
        exact_thresh, f1_thresh, qid_list=has_ans_qids)
    _merge_eval(out_eval, has_ans_eval, 'HasAns')
  if no_ans_qids:
    no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
    _merge_eval(out_eval, no_ans_eval, 'NoAns')
  _find_all_best_thresh(
      out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  _run_precision_recall_analysis(
      out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  return out_eval
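
A minimal usage sketch, not part of the commit, showing how the new module can be driven end to end. The file names dev-v2.0.json, predictions.json and na_prob.json are assumptions following the usual SQuAD v2 layout; within the repo, eval_squad above builds these inputs itself.

# Hypothetical standalone driver for squad_evaluate_v2_0; file names are
# assumptions, not taken from the commit.
import json

from official.nlp.bert import squad_evaluate_v2_0

with open('dev-v2.0.json') as f:
  dataset = json.load(f)['data']  # SQuAD v2 dev-set articles.
with open('predictions.json') as f:
  predictions = json.load(f)  # Maps qid -> predicted answer text.
with open('na_prob.json') as f:
  na_probs = json.load(f)  # Maps qid -> predicted no-answer probability.

metrics = squad_evaluate_v2_0.evaluate(dataset, predictions, na_probs)
print(metrics['final_exact'], metrics['final_f1'])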