Commit cca75e78 authored by LysandreJik's avatar LysandreJik
Browse files

Kill the demon spawn

parent bf119c05
......@@ -248,7 +248,28 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id)
result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])
output = [to_list(output[i]) for output in outputs]
if len(output) >= 5:
start_logits = output[0]
start_top_index = output[1]
end_logits = output[2]
end_top_index = output[3],
cls_logits = output[4]
result = SquadResult(
unique_id, start_logits, end_logits,
start_top_index=start_top_index,
end_top_index=end_top_index,
cls_logits=cls_logits
)
else:
start_logits, end_logits = output
result = SquadResult(
unique_id, start_logits, end_logits
)
all_results.append(result)
evalTime = timeit.default_timer() - start_time
......
......@@ -446,72 +446,21 @@ class SquadFeatures(object):
self.end_position = end_position
class SquadResult(object):
"""
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
Args:
result: The result output by a model on a SQuAD inference. These results may be complex (5 values) as the ones output by
XLNet or XLM or may be simple like the other models (2 values). They may be passed as a list or as a dict, with the
following accepted formats:
`dict` output by a simple model:
{
"start_logits": int,
"end_logits": int,
"unique_id": string
}
`list` output by a simple model:
[start_logits, end_logits, unique_id]
`dict` output by a complex model:
{
"start_top_log_probs": float,
"start_top_index": int,
"end_top_log_probs": float,
"end_top_index": int,
"cls_logits": int,
"unique_id": string
}
`list` output by a complex model:
[start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, unique_id]
See `run_squad.py` for an example.
unique_id: The unique identifier corresponding to that example.
start_logits: The logits corresponding to the start of the answer
end_logits: The logits corresponding to the end of the answer
"""
def __init__(self, result):
if isinstance(result, dict):
if "start_logits" in result and "end_logits" in result:
self.start_logits = result["start_logits"]
self.end_logits = result["end_logits"]
elif "start_top_log_probs" in result and "start_top_index" in result:
self.start_top_log_probs = result["start_top_log_probs"]
self.start_top_index = result["start_top_index"]
self.end_top_log_probs = result["end_top_log_probs"]
self.end_top_index = result["end_top_index"]
self.cls_logits = result["cls_logits"]
else:
raise ValueError("SquadResult instantiated with wrong values.")
self.unique_id = result["unique_id"]
elif isinstance(result, list):
if len(result) == 3:
self.start_logits = result[0]
self.end_logits = result[1]
elif len(result) == 6:
self.start_top_log_probs = result[0]
self.start_top_index = result[1]
self.end_top_log_probs = result[2]
self.end_top_index = result[3]
self.cls_logits = result[4]
else:
raise ValueError("SquadResult instantiated with wrong values.")
self.unique_id = result[-1]
else:
raise ValueError("SquadResult instantiated with wrong values. Should be a dictionary or a list.")
def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
self.start_top_log_probs = start_logits
self.end_top_log_probs = end_logits
self.unique_id = unique_id
if start_top_index:
self.start_top_index = start_top_index
self.end_top_index = end_top_index
self.cls_logits = cls_logits
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment