Commit 12880f1c authored by thefazzer's avatar thefazzer
Browse files

Initial skeleton refactoring

parent 66558b35
...@@ -4,19 +4,23 @@ import json ...@@ -4,19 +4,23 @@ import json
import random import random
from lm_eval.base import Dataset from lm_eval.base import Dataset
from ..utils import sh from ..utils import sh
import itertools
class CoQA(Dataset): class CoQA(Dataset):
def __init__(self): def __init__(self):
self.download() self.download()
def download(self): def download(self):
#TODO: don't download if files already there sh ("""
sh("""
mkdir -p data/coqa mkdir -p data/coqa
wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json wget --no-clobber http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json wget --no-clobber http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json
""") """)
@classmethod
def get_answers(cls, doc, turn_id):
answers = zip(doc["answers"], zip(doc["additional_answers"]))
return answers[turn_id - 1]
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -36,16 +40,15 @@ class CoQA(Dataset): ...@@ -36,16 +40,15 @@ class CoQA(Dataset):
pass pass
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out description return "Given a passage and a conversation so far, answer the next question in the conversation."
return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
# TODO: implement. qa_pairs = [(q, a) in zip_longest(doc["questions"], doc["answers"][:-1])] # truncate target answer
raise NotImplementedError('doc_to_text not implemented') return "{}\n\n{}".format(doc["story"], f"Q: {q}"+ '\n\n' + f"A: {a}")
def doc_to_target(self, doc): def doc_to_target(self, doc):
# TODO: implement. # TODO: all distinct answers taking into account whitespace?
raise NotImplementedError('doc_to_target not implemented') return get_answers(doc, len(doc["questions"]))
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
...@@ -58,8 +61,11 @@ class CoQA(Dataset): ...@@ -58,8 +61,11 @@ class CoQA(Dataset):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
# TODO: implement evaluation. ll_alternative_answers = [
raise NotImplementedError('Evaluation not implemented') rf.loglikelihood(ctx, " " + answer) for answer in get_answers(doc, len(doc["questions"]))
]
return ll_alternative_answers
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
...@@ -71,8 +77,12 @@ class CoQA(Dataset): ...@@ -71,8 +77,12 @@ class CoQA(Dataset):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
# TODO: implement evaluation. golds = get_answers(doc, len(doc["questions"]))
raise NotImplementedError('Evaluation not implemented') pred = np.argmax(results)
return {
"acc": pred in golds,
# "f1": (golds, pred), # TODO: Fix
}
def aggregation(self): def aggregation(self):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment