Commit f555a583 authored by lintangsutawika's avatar lintangsutawika
Browse files

fix formatting

parent 64d4600c
...@@ -2,6 +2,7 @@ from itertools import zip_longest ...@@ -2,6 +2,7 @@ from itertools import zip_longest
import transformers.data.metrics.squad_metrics as squad_metrics import transformers.data.metrics.squad_metrics as squad_metrics
def doc_to_text(doc): def doc_to_text(doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1} # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai # and a question qi, the task is to predict the answer ai
...@@ -41,14 +42,13 @@ def em(gold_list, pred): ...@@ -41,14 +42,13 @@ def em(gold_list, pred):
for i in range(len(gold_list)): for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1 :] gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum # predictions compared against (n) golds and take maximum
em_sum += max( em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
else: else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list) em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
return em_sum / max(1, len(gold_list)) return em_sum / max(1, len(gold_list))
def compute_scores(gold_list, pred): def compute_scores(gold_list, pred):
# tests for exact match and on the normalised answer (compute_exact) # tests for exact match and on the normalised answer (compute_exact)
# test for overlap (compute_f1) # test for overlap (compute_f1)
...@@ -58,9 +58,7 @@ def compute_scores(gold_list, pred): ...@@ -58,9 +58,7 @@ def compute_scores(gold_list, pred):
for i in range(len(gold_list)): for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1 :] gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum # predictions compared against (n) golds and take maximum
em_sum += max( em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else: else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list) em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
...@@ -71,6 +69,7 @@ def compute_scores(gold_list, pred): ...@@ -71,6 +69,7 @@ def compute_scores(gold_list, pred):
"f1": f1_sum / max(1, len(gold_list)), "f1": f1_sum / max(1, len(gold_list)),
} }
def process_results(doc, results): def process_results(doc, results):
gold_list = doc_to_target(doc) gold_list = doc_to_target(doc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment