Commit 285b1241 authored by LysandreJik's avatar LysandreJik
Browse files

Added SquadResult

parent 1e9ac5a7
......@@ -425,3 +425,74 @@ class SquadFeatures(object):
self.start_position = start_position
self.end_position = end_position
class SquadResult(object):
"""
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
Args:
result: The result output by a model on a SQuAD inference. These results may be complex (5 values) as the ones output by
XLNet or XLM or may be simple like the other models (2 values). They may be passed as a list or as a dict, with the
following accepted formats:
`dict` output by a simple model:
{
"start_logits": int,
"end_logits": int,
"unique_id": string
}
`list` output by a simple model:
[start_logits, end_logits, unique_id]
`dict` output by a complex model:
{
"start_top_log_probs": float,
"start_top_index": int,
"end_top_log_probs": float,
"end_top_index": int,
"cls_logits": int,
"unique_id": string
}
`list` output by a complex model:
[start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, unique_id]
See `run_squad.py` for an example.
"""
def __init__(self, result):
if isinstance(result, dict):
if "start_logits" in result and "end_logits" in result:
self.start_logits = result["start_logits"]
self.end_logits = result["end_logits"]
elif "start_top_log_probs" in result and "start_top_index" in result:
self.start_top_log_probs = result["start_top_log_probs"]
self.start_top_index = result["start_top_index"]
self.end_top_log_probs = result["end_top_log_probs"]
self.end_top_index = result["end_top_index"]
self.cls_logits = result["cls_logits"]
else:
raise ValueError("SquadResult instantiated with wrong values.")
self.unique_id = result["unique_id"]
elif isinstance(result, list):
if len(result) == 3:
self.start_logits = result[0]
self.end_logits = result[1]
elif len(result) == 6:
self.start_top_log_probs = result[0]
self.start_top_index = result[1]
self.end_top_log_probs = result[2]
self.end_top_index = result[3]
self.cls_logits = result[4]
else:
raise ValueError("SquadResult instantiated with wrong values.")
self.unique_id = result[-1]
else:
raise ValueError("SquadResult instantiated with wrong values. Should be a dictionary or a list.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment