Commit efa810f0 authored by thefazzer

Score computation, use SQuAD metrics

parent 5552c8dc
@@ -6,6 +6,7 @@ import numpy as np
 from lm_eval.base import Dataset, rf, mean
 from ..utils import sh
 from itertools import zip_longest
+import transformers.data.metrics.squad_metrics as squad_metrics
 
 class CoQA(Dataset):
     def download(self):
@@ -39,16 +40,18 @@ class CoQA(Dataset):
         return "Given a passage and a conversation so far, answer the next question in the conversation."
 
     def doc_to_text(self, doc):
         # Each "doc" is a story and conversation (Q and A pairs).
         doc_text = doc["story"] + '\n\n'
         for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer
             question = f"Q: {q['input_text']}" + '\n\n'
             answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:\n\n"
             doc_text += question + answer
         return doc_text
 
     @classmethod
     def get_answers(cls, doc, turn_id):
-        # get answers and valid alternatives
+        # Return the gold answer for the given turn, plus any valid alternatives.
         answers = []
         answer_forturn = doc["answers"][turn_id - 1]["input_text"]
         answers.append(answer_forturn)
@@ -62,12 +65,27 @@ class CoQA(Dataset):
         return answers
 
     def doc_to_target(self, doc, turnid=None):
-        # default to predict last turn
+        # Default to predicting the last turn.
         if turnid is None:
             turnid = len(doc["questions"])
         all_answers = self.get_answers(doc, turnid)
         return all_answers[0]  # ignore alternative answers for now
 
+    @staticmethod
+    def compute_scores(gold_list, pred):
+        # SQuAD-style scoring against multiple references: when there is more
+        # than one gold answer, hold out each reference in turn, score the
+        # prediction against the best match among the rest, then average.
+        f1_sum = 0.0
+        em_sum = 0.0
+        if len(gold_list) > 1:
+            for i in range(len(gold_list)):
+                gold_answers = gold_list[0:i] + gold_list[i + 1:]
+                em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
+                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
+        else:
+            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
+            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
+        return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
+
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
@@ -80,8 +98,9 @@
             part of the document for `doc`.
         """
         requests = []
+        # One loglikelihood request per answer candidate; get_answers returns a flat list of strings.
         for answer in self.get_answers(doc, len(doc["questions"])):
             requests.append(rf.loglikelihood(ctx, " " + answer))
         return requests
 
     def process_results(self, doc, results):
@@ -94,28 +113,26 @@
         :param results:
             The results of the requests created in construct_requests.
         """
-        gold = self.get_answers(doc, len(doc["questions"]))
-        pred = np.argmax(results)
+        turn_id = len(doc["questions"])
+        gold_list = self.get_answers(doc, turn_id)
+        # The prediction is the candidate the model assigns the highest
+        # loglikelihood; compute_scores returns a dict keyed by metric name.
+        pred = gold_list[np.argmax(results)]
+        scores = self.compute_scores(gold_list, pred)
         return {
-            "acc": int(pred == gold)
+            "f1": scores["f1"],
+            "em": scores["em"],
         }
 
-    def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
+    def higher_is_better(self):
         return {
-            "acc": mean
+            "f1": True,
+            "em": True,
         }
 
-    def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
+    def aggregation(self):
         return {
-            "acc": True
+            "f1": mean,
+            "em": mean,
         }
\ No newline at end of file
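
For readers unfamiliar with the SQuAD-style multi-reference scoring that `compute_scores` implements, here is a minimal standalone sketch of the same leave-one-out computation. It assumes only that `transformers` is installed (`compute_exact` and `compute_f1` are existing helpers in `transformers.data.metrics.squad_metrics`); the answer strings and the resulting numbers are invented for illustration.

```python
# Standalone sketch of the leave-one-out EM/F1 computation in compute_scores.
# Requires `pip install transformers`; the answer strings are invented.
import transformers.data.metrics.squad_metrics as squad_metrics

gold_list = ["brown dog", "big brown dog", "dog"]
pred = "brown dog"

em_sum = f1_sum = 0.0
for i in range(len(gold_list)):
    # Hold out reference i and score pred against the best of the rest.
    others = gold_list[:i] + gold_list[i + 1:]
    em_sum += max(squad_metrics.compute_exact(a, pred) for a in others)
    f1_sum += max(squad_metrics.compute_f1(a, pred) for a in others)

print({"em": em_sum / len(gold_list), "f1": f1_sum / len(gold_list)})
# -> approximately {'em': 0.67, 'f1': 0.93}: pred exactly matches a held-out
#    reference in two of three folds and partially overlaps "big brown dog".
```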
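The `process_results` hunk relies on `results` lining up one-to-one with the candidates returned by `get_answers`. A toy illustration of that argmax selection step, with invented loglikelihood values:

```python
import numpy as np

# Invented loglikelihoods, one per candidate continuation, in the same
# order construct_requests emitted the requests.
gold_list = ["brown dog", "big brown dog", "dog"]
results = [-4.2, -6.9, -5.1]

# Mirrors `pred = gold_list[np.argmax(results)]`: take the candidate
# the model scored as most likely.
pred = gold_list[int(np.argmax(results))]
print(pred)  # -> brown dog
```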