"testing/vscode:/vscode.git/clone" did not exist on "29051439dbed90583bfad1d16dfca88a95e78709"
Commit efa810f0 authored by thefazzer's avatar thefazzer
Browse files

Score computation, use squad metrics

parent 5552c8dc
...@@ -6,6 +6,7 @@ import numpy as np ...@@ -6,6 +6,7 @@ import numpy as np
from lm_eval.base import Dataset, rf, mean from lm_eval.base import Dataset, rf, mean
from ..utils import sh from ..utils import sh
from itertools import zip_longest from itertools import zip_longest
import transformers.data.metrics.squad_metrics as squad_metrics
class CoQA(Dataset): class CoQA(Dataset):
def download(self): def download(self):
...@@ -39,16 +40,18 @@ class CoQA(Dataset): ...@@ -39,16 +40,18 @@ class CoQA(Dataset):
return "Given a passage and a conversation so far, answer the next question in the conversation." return "Given a passage and a conversation so far, answer the next question in the conversation."
def doc_to_text(self, doc): def doc_to_text(self, doc):
# Each "doc" is a story and conversation (Q and A pairs).
doc_text = doc["story"] + '\n\n' doc_text = doc["story"] + '\n\n'
for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]): # omit target answer for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]): # omit target answer
question = f"Q: {q['input_text']}" + '\n\n' question = f"Q: {q['input_text']}" + '\n\n'
answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:\n\n" answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:\n\n"
doc_text += question + answer doc_text += question + answer
print(doc_text)
return doc_text return doc_text
@classmethod @classmethod
def get_answers(cls, doc, turn_id): def get_answers(cls, doc, turn_id):
# get answers and valid alternatives # This function returns an answer and valid alternatives.
answers = [] answers = []
answer_forturn = doc["answers"][turn_id - 1]["input_text"] answer_forturn = doc["answers"][turn_id - 1]["input_text"]
answers.append(answer_forturn) answers.append(answer_forturn)
...@@ -62,12 +65,27 @@ class CoQA(Dataset): ...@@ -62,12 +65,27 @@ class CoQA(Dataset):
return answers return answers
def doc_to_target(self, doc, turnid=None): def doc_to_target(self, doc, turnid=None):
# default to predict last turn # Default to predict last turn.
if turnid is None: if turnid is None:
turnid = len(doc["questions"]) turnid = len(doc["questions"])
all_answers = self.get_answers(doc, turnid) all_answers = self.get_answers(doc, turnid)
return all_answers[0] # ignore alternative answers for now return all_answers[0] # ignore alternative answers for now
@staticmethod
def compute_scores(gold_list, pred):
    """Return SQuAD-style exact-match and F1 for one prediction.

    Follows the official CoQA evaluation protocol: when multiple gold
    answers exist, each gold answer is held out in turn and the
    prediction is scored against the best match among the *remaining*
    answers; the per-holdout scores are then averaged. With a single
    gold answer the prediction is scored against it directly.

    :param gold_list: list of gold answer strings (reference answer plus
        valid alternatives). An empty list returns 0.0 for both metrics
        instead of raising ``ValueError`` from ``max()`` on an empty
        sequence.
    :param pred: predicted answer string.
        NOTE(review): the visible caller passes ``np.argmax(results)``
        (an integer index) here and tuple-unpacks the returned dict —
        confirm the prediction is mapped back to an answer string and
        the result is consumed as a dict.
    :returns: ``{'em': float, 'f1': float}``.
    """
    # Guard: without this, an empty gold_list reaches max() over an
    # empty generator in the else branch and raises ValueError.
    if not gold_list:
        return {'em': 0.0, 'f1': 0.0}

    f1_sum = 0.0
    em_sum = 0.0
    if len(gold_list) > 1:
        for i in range(len(gold_list)):
            # Hold out answer i; score pred against the rest.
            gold_answers = gold_list[:i] + gold_list[i + 1:]
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
    else:
        em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
        f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

    # len(gold_list) >= 1 here, so plain division is safe.
    n = len(gold_list)
    return {'em': em_sum / n, 'f1': f1_sum / n}
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
...@@ -80,8 +98,9 @@ class CoQA(Dataset): ...@@ -80,8 +98,9 @@ class CoQA(Dataset):
part of the document for `doc`. part of the document for `doc`.
""" """
requests = [] requests = []
for answer in self.get_answers(doc, len(doc["questions"])): for answers in self.get_answers(doc, len(doc["questions"])):
requests.append(rf.loglikelihood(ctx, " " + answer)) for a in answers:
requests.append(rf.loglikelihood(ctx, " " + a))
return requests return requests
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -94,28 +113,26 @@ class CoQA(Dataset): ...@@ -94,28 +113,26 @@ class CoQA(Dataset):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
gold = self.get_answers(doc, len(doc["questions"]))
turn_id = len(doc["questions"])
gold_list = self.get_answers(doc, turn_id)
pred = np.argmax(results) pred = np.argmax(results)
(em, f1) = self.compute_scores(gold_list, pred)
return { return {
"acc": int(pred == gold) "f1": f1,
"em": em,
} }
def aggregation(self): def higher_is_better(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return { return {
"acc": mean "f1": True,
"em": True,
} }
def higher_is_better(self): def aggregation(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return { return {
"acc": True "f1": mean,
"em": mean,
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment