Commit 25160730 authored by Saurabh Saxena's avatar Saurabh Saxena Committed by A. Unique TensorFlower
Browse files

Handle corner case in squad_lib when best_non_null_entry is None.

PiperOrigin-RevId: 312193729
parent 3ef7bbcf
......@@ -36,6 +36,17 @@ class SquadExample(object):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
Attributes:
qas_id: ID of the question-answer pair.
question_text: Original text for the question.
doc_tokens: The list of tokens in the context obtained by splitting
on whitespace only.
orig_answer_text: Original text for the answer.
start_position: Starting index of the answer in `doc_tokens`.
end_position: Ending index of the answer in `doc_tokens`.
is_impossible: Whether the question is impossible to answer given the
context. Only used in SQuAD 2.0.
"""
def __init__(self,
......@@ -695,6 +706,7 @@ def postprocess_output(all_examples,
else:
# pytype: disable=attribute-error
# predict "" iff the null score - the score of best non-null > threshold
if best_non_null_entry is not None:
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
......@@ -702,6 +714,10 @@ def postprocess_output(all_examples,
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
else:
logging.warning("best_non_null_entry is None")
scores_diff_json[example.qas_id] = score_null
all_predictions[example.qas_id] = ""
# pytype: enable=attribute-error
all_nbest_json[example.qas_id] = nbest_json
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment