Commit 79c9b68a authored by thefazzer's avatar thefazzer
Browse files

Merge remote-tracking branch 'origin/master' into fazz/refactor-task-coqa

parents b9b3159b 1b467c57
...@@ -18,6 +18,8 @@ from . import lambada ...@@ -18,6 +18,8 @@ from . import lambada
from . import race from . import race
from . import piqa from . import piqa
from . import triviaqa from . import triviaqa
from . import pubmedqa
from . import sciq
from . import webqs from . import webqs
...@@ -47,6 +49,9 @@ TASK_REGISTRY = { ...@@ -47,6 +49,9 @@ TASK_REGISTRY = {
"lambada": lambada.LAMBADA, "lambada": lambada.LAMBADA,
"piqa": piqa.PiQA, "piqa": piqa.PiQA,
"pubmedqa" : pubmedqa.Pubmed_QA,
"sciq" : sciq.SciQ,
#"triviaqa": triviaqa.TriviaQA, #"triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy, "arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge, "arc_challenge": arc.ARCChallenge,
......
import json import numpy as np
import random from lm_eval.base import rf, mean
from lm_eval.base import Task, rf, mean from . common import HFTask
from ..utils import sh
import os
class PiQA(Task):
def download(self): class PiQA(HFTask):
if not os.path.exists('data/piqa'): DATASET_PATH = "piqa"
#TODO: use best_download DATASET_NAME = None
sh("""
mkdir -p data/piqa
wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl
wget https://yonatanbisk.com/piqa/data/train-labels.lst -O data/piqa/piqa-train-labels.lst
wget https://yonatanbisk.com/piqa/data/valid.jsonl -O data/piqa/piqa-valid.jsonl
wget https://yonatanbisk.com/piqa/data/valid-labels.lst -O data/piqa/piqa-valid-labels.lst
wget https://yonatanbisk.com/piqa/data/tests.jsonl -O data/piqa/piqa-test.jsonl
""")
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -26,44 +16,25 @@ class PiQA(Task): ...@@ -26,44 +16,25 @@ class PiQA(Task):
def has_test_docs(self): def has_test_docs(self):
return False return False
def load_docs(self, textfilename, labelfilename):
if labelfilename != None:
return zip([json.loads(entry) for entry in list(open(textfilename,'r'))],list(map(lambda x: x.strip(), open(labelfilename, 'r'))))
else:
return [json.loads(entry) for entry in list(open(textfilename,'r'))]
def training_docs(self):
return self.load_docs('data/piqa/piqa-train.jsonl', 'data/piqa/piqa-train-labels.lst')
def validation_docs(self):
return self.load_docs('data/piqa/piqa-valid.jsonl', 'data/piqa/piqa-valid-labels.lst')
#def test_docs(self):
# return self.load_docs('data/piqa/piqa-test.jsonl', None)
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out fewshot description # TODO: figure out fewshot description
return "" return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc[0]['goal'] + "\n" return doc["goal"] + "\n"
def doc_to_target(self, doc): def doc_to_target(self, doc):
#TODO: check if oa uses newline solutions = [doc["sol1"], doc["sol2"]]
rightanswer = int(doc[1]) + 1 return solutions[doc["label"]]
return ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1']) ll_1, _ = rf.loglikelihood(ctx, doc['sol1'])
ll_2, _ = rf.loglikelihood(ctx, doc[0]['sol2']) ll_2, _ = rf.loglikelihood(ctx, doc['sol2'])
return ll_1, ll_2 return ll_1, ll_2
def process_results(self, doc, results): def process_results(self, doc, results):
ll_1, ll_2 = results
return { return {
'acc': (ll_1 > ll_2) == (int(doc[1]) == 0) 'acc': np.argmax(results) == doc["label"]
} }
def aggregation(self): def aggregation(self):
......
import numpy as np
import json
import random
from .common import HFTask
from lm_eval.base import rf, mean
class Pubmed_QA(HFTask):
    """PubMedQA (pqa_labeled): yes/no/maybe question answering over
    biomedical abstracts, scored by the log-likelihood of each answer
    continuation."""

    DATASET_PATH = "pubmed_qa"
    DATASET_NAME = "pqa_labeled"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def test_docs(self):
        if self.has_test_docs():
            # The HF split is labelled "train" but it is really the
            # labelled evaluation set.
            return self.data["train"]

    def fewshot_description(self):
        # Average ctx length in the labelled dataset is 238.9;
        # 2 few-shot examples already push past the context window,
        # so no description is used.
        return ""

    def doc_to_text(self, doc):
        ctxs = "\n".join(doc["context"]["contexts"])
        # Fix: the original also passed doc["final_decision"] as a third,
        # unused format argument — removed so the gold label cannot be
        # mistaken for part of the prompt.
        return "abstract: {}\nquestion: {}\nanswer:".format(ctxs, doc["question"])

    def doc_to_target(self, doc):
        return " {}".format(doc["final_decision"])

    def fewshot_examples(self, k):
        # Only test docs exist, so few-shot examples are sampled from them.
        if self._training_docs is None:
            self._training_docs = list(self.test_docs())
        return random.sample(self._training_docs, k)

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns
        an iterable of Requests which will be sent to the LM.
        """
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
        ll_no, _ = rf.loglikelihood(ctx, " no")
        ll_maybe, _ = rf.loglikelihood(ctx, " maybe")
        return ll_yes, ll_no, ll_maybe

    def process_results(self, doc, results):
        gold = doc["final_decision"]
        # Predict the answer whose continuation got the highest
        # log-likelihood; results are ordered (yes, no, maybe).
        pred = np.argmax(results)
        return {
            "acc": ["yes", "no", "maybe"][pred] == gold,
        }

    def aggregation(self):
        return {
            "acc": mean
        }

    def higher_is_better(self):
        return {
            "acc": True
        }
import os
import json
from ..utils import sh
from lm_eval.base import MultipleChoiceTask, rf, mean
import zipfile
class SciQ(MultipleChoiceTask):
    """SciQ: crowdsourced multiple-choice science exam questions, each
    with three distractors, one correct answer, and a supporting passage."""

    def download(self):
        # Fetch and unpack the official SciQ archive on first use.
        if not os.path.exists('data/sciq'):
            os.mkdir('data/sciq')
            sh("wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip -O data/sciq/SciQ.zip")
            with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
                zf.extractall("data/sciq/")

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def _convert_standard(self, doc):
        """Convert a raw SciQ record into the MultipleChoiceTask format.

        The correct answer is always placed last in `choices`, so the
        gold index is fixed at 3.
        """
        choices = [
            doc["distractor1"],
            doc["distractor2"],
            doc["distractor3"],
            doc["correct_answer"],
        ]
        out_doc = {
            "source": doc["support"],
            "query": doc["question"],
            "choices": choices,
            "gold": 3,
        }
        return out_doc

    def load_docs(self, textfilename):
        # Each split is a single JSON array of question records.
        with open(textfilename, 'r') as j:
            docs = json.load(j)
        for record in docs:
            yield self._convert_standard(record)

    def fewshot_description(self):
        # TODO: figure out an appropriate few-shot description for SciQ.
        return ""

    def training_docs(self):
        return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")

    def validation_docs(self):
        return self.load_docs("data/sciq/SciQ dataset-2 3/valid.json")

    def test_docs(self):
        return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")

    def doc_to_text(self, doc):
        return "{}\n{}".format(doc["source"], doc["query"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment