Commit c0862026 authored by thefazzer's avatar thefazzer
Browse files

Text & target impl, support fns, refactoring

parent 6738b241
...@@ -2,25 +2,21 @@ ...@@ -2,25 +2,21 @@
import json import json
import random import random
from lm_eval.base import Dataset import numpy as np
from lm_eval.base import Dataset, rf, mean
from ..utils import sh from ..utils import sh
import itertools from itertools import zip_longest
class CoQA(Dataset): class CoQA(Dataset):
def __init__(self):
self.download()
def download(self): def download(self):
pass
# -N only overwrites if the remote file has changed
sh (""" sh ("""
mkdir -p data/coqa mkdir -p data/coqa
wget --no-clobber http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json wget -N http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
wget --no-clobber http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json wget -N http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json
""") """)
@classmethod
def get_answers(cls, doc, turn_id):
answers = zip(doc["answers"], zip(doc["additional_answers"]))
return answers[turn_id - 1]
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -43,12 +39,34 @@ class CoQA(Dataset): ...@@ -43,12 +39,34 @@ class CoQA(Dataset):
return "Given a passage and a conversation so far, answer the next question in the conversation." return "Given a passage and a conversation so far, answer the next question in the conversation."
def doc_to_text(self, doc): def doc_to_text(self, doc):
qa_pairs = [(q, a) in zip_longest(doc["questions"], doc["answers"][:-1])] # truncate target answer doc_text = doc["story"] + '\n\n'
return "{}\n\n{}".format(doc["story"], f"Q: {q}"+ '\n\n' + f"A: {a}") for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]): # omit target answer
question = f"Q: {q['input_text']}" + '\n\n'
answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:\n\n"
doc_text += question + answer
return doc_text
def doc_to_target(self, doc): @classmethod
# TODO: all distinct answers taking into account whitespace? def get_answers(cls, doc, turn_id):
return get_answers(doc, len(doc["questions"])) # get answers and valid alternatives
answers = []
answer_forturn = doc["answers"][turn_id - 1]["input_text"]
answers.append(answer_forturn)
additionals = doc.get("additional_answers")
if additionals:
for key in additionals:
additional_answer_for_turn = additionals[key][turn_id - 1]["input_text"]
if additional_answer_for_turn.upper() not in map(str.upper, answers):
answers.append(additional_answer_for_turn)
return answers
def doc_to_target(self, doc, turnid=None):
# default to predict last turn
if turnid is None:
turnid = len(doc["questions"])
all_answers = self.get_answers(doc, turnid)
return all_answers[0] # ignore alternative answers for now
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
...@@ -61,11 +79,10 @@ class CoQA(Dataset): ...@@ -61,11 +79,10 @@ class CoQA(Dataset):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
ll_alternative_answers = [ requests = []
rf.loglikelihood(ctx, " " + answer) for answer in get_answers(doc, len(doc["questions"])) for answer in self.get_answers(doc, len(doc["questions"])):
] requests.append(rf.loglikelihood(ctx, " " + answer))
return requests
return ll_alternative_answers
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
...@@ -77,11 +94,10 @@ class CoQA(Dataset): ...@@ -77,11 +94,10 @@ class CoQA(Dataset):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
golds = get_answers(doc, len(doc["questions"])) gold = self.get_answers(doc, len(doc["questions"]))
pred = np.argmax(results) pred = np.argmax(results)
return { return {
"acc": pred in golds, "acc": int(pred == gold)
# "f1": (golds, pred), # TODO: Fix
} }
def aggregation(self): def aggregation(self):
...@@ -90,8 +106,9 @@ class CoQA(Dataset): ...@@ -90,8 +106,9 @@ class CoQA(Dataset):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
# TODO: implement evaluation. return {
raise NotImplementedError('Evaluation not implemented') "acc": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -99,5 +116,6 @@ class CoQA(Dataset): ...@@ -99,5 +116,6 @@ class CoQA(Dataset):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
# TODO: implement evaluation. return {
raise NotImplementedError('Evaluation not implemented') "acc": True
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment