"include/vscode:/vscode.git/clone" did not exist on "716860e37dc6c3286610772c8b942b86dc809d6c"
Commit cca75e78 authored by LysandreJik's avatar LysandreJik
Browse files

Kill the demon spawn

parent bf119c05
......@@ -248,7 +248,28 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id)
result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])
output = [to_list(output[i]) for output in outputs]
if len(output) >= 5:
start_logits = output[0]
start_top_index = output[1]
end_logits = output[2]
end_top_index = output[3],
cls_logits = output[4]
result = SquadResult(
unique_id, start_logits, end_logits,
start_top_index=start_top_index,
end_top_index=end_top_index,
cls_logits=cls_logits
)
else:
start_logits, end_logits = output
result = SquadResult(
unique_id, start_logits, end_logits
)
all_results.append(result)
evalTime = timeit.default_timer() - start_time
......
......@@ -446,72 +446,21 @@ class SquadFeatures(object):
self.end_position = end_position
class SquadResult(object):
"""
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
Args:
result: The result output by a model on a SQuAD inference. These results may be complex (5 values) as the ones output by
XLNet or XLM or may be simple like the other models (2 values). They may be passed as a list or as a dict, with the
following accepted formats:
`dict` output by a simple model:
{
"start_logits": int,
"end_logits": int,
"unique_id": string
}
`list` output by a simple model:
[start_logits, end_logits, unique_id]
`dict` output by a complex model:
{
"start_top_log_probs": float,
"start_top_index": int,
"end_top_log_probs": float,
"end_top_index": int,
"cls_logits": int,
"unique_id": string
}
`list` output by a complex model:
[start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, unique_id]
See `run_squad.py` for an example.
unique_id: The unique identifier corresponding to that example.
start_logits: The logits corresponding to the start of the answer
end_logits: The logits corresponding to the end of the answer
"""
def __init__(self, result):
if isinstance(result, dict):
if "start_logits" in result and "end_logits" in result:
self.start_logits = result["start_logits"]
self.end_logits = result["end_logits"]
elif "start_top_log_probs" in result and "start_top_index" in result:
self.start_top_log_probs = result["start_top_log_probs"]
self.start_top_index = result["start_top_index"]
self.end_top_log_probs = result["end_top_log_probs"]
self.end_top_index = result["end_top_index"]
self.cls_logits = result["cls_logits"]
else:
raise ValueError("SquadResult instantiated with wrong values.")
self.unique_id = result["unique_id"]
elif isinstance(result, list):
if len(result) == 3:
self.start_logits = result[0]
self.end_logits = result[1]
elif len(result) == 6:
self.start_top_log_probs = result[0]
self.start_top_index = result[1]
self.end_top_log_probs = result[2]
self.end_top_index = result[3]
self.cls_logits = result[4]
else:
raise ValueError("SquadResult instantiated with wrong values.")
self.unique_id = result[-1]
def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
self.start_top_log_probs = start_logits
self.end_top_log_probs = end_logits
self.unique_id = unique_id
else:
raise ValueError("SquadResult instantiated with wrong values. Should be a dictionary or a list.")
if start_top_index:
self.start_top_index = start_top_index
self.end_top_index = end_top_index
self.cls_logits = cls_logits
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment