Commit f555a583 authored by lintangsutawika's avatar lintangsutawika
Browse files

fix formatting

parent 64d4600c
......@@ -2,6 +2,7 @@ from itertools import zip_longest
import transformers.data.metrics.squad_metrics as squad_metrics
def doc_to_text(doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai
......@@ -41,14 +42,13 @@ def em(gold_list, pred):
for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum
em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
return em_sum / max(1, len(gold_list))
def compute_scores(gold_list, pred):
# tests for exact match and on the normalised answer (compute_exact)
# test for overlap (compute_f1)
......@@ -58,9 +58,7 @@ def compute_scores(gold_list, pred):
for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum
em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
......@@ -71,6 +69,7 @@ def compute_scores(gold_list, pred):
"f1": f1_sum / max(1, len(gold_list)),
}
def process_results(doc, results):
gold_list = doc_to_target(doc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment