Commit 1b35f6b9 authored by jeffhsu3

pubmedqa

parent 5b6182d5
......@@ -17,6 +17,7 @@ from . import lambada
from . import race
from . import piqa
from . import triviaqa
from . import pubmedqa
TASK_REGISTRY = {
......@@ -45,6 +46,8 @@ TASK_REGISTRY = {
"lambada": lambada.LAMBADA,
"piqa": piqa.PiQA,
"pubmedqa" : pubmedqa.Pubmed_QA,
#"triviaqa": triviaqa.TriviaQA,
# "arc_easy": arc.ARCEasy, # not implemented yet
# "arc_challenge": arc.ARCChallenge, # not implemented yet
......
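With the import and registry entry above, the new task becomes addressable by the "pubmedqa" key. A minimal sketch of looking it up, assuming the same get_task_dict helper that the evaluator changes below call (the exact import path is an assumption):

    # Sketch only: the lm_eval.tasks / get_task_dict usage is inferred from the
    # evaluator changes further down in this commit.
    from lm_eval import tasks

    task_dict = tasks.get_task_dict(["pubmedqa"])
    pubmedqa_task = task_dict["pubmedqa"]
    print(pubmedqa_task.has_training_docs())  # True: this task only provides training docs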
"""
"""
import numpy as np
from ..utils import sh
from . common import HFTask, yesno
from lm_eval.base import Dataset, rf, mean
class Pubmed_QA(HFTask):
DATASET_PATH = "pubmed_qa"
DATASET_NAME = "pqa_labeled"
def has_training_docs(self):
return True
def has_test_docs(self):
return False
def has_validation_docs(self):
return False
def fewshot_description(self):
# Average ctx length in labelled dataset is 238.9
return ""
def doc_to_text(self, doc):
ctxs = "\n".join(doc['context']['contexts'])
return "abstract: {}\nquestion: {}\nanswer:".format(
ctxs,
doc['question'],
doc['final_decision']
)
def doc_to_target(self, doc):
return " {}".format(doc['final_decision'])
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns
an iterable of Requests which will be sent to the LM.
"""
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
ll_maybe, _ = rf.loglikelihood(ctx, " maybe")
return ll_yes, ll_no, ll_maybe
def process_results(self, doc, results):
gold = doc['final_decision']
ll_yes, ll_no, ll_maybe = results
pred = np.argmax(results)
return {
'acc': ["yes", "no", "maybe"][pred] == gold,
}
def aggregation(self):
return {
'acc' : mean
}
def higher_is_better(self):
return {
'acc' : True
}
......@@ -32,9 +32,15 @@ def main():
    task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)
    # TODO: fall back to test docs
    task_dict_items = [(name, task) for name, task in task_dict.items() if task.has_validation_docs()]
    task_dict_items = []
    for name, task in task_dict.items():
        if task.has_validation_docs():
            task_dict_items.append((name, task, 'validation'))
        elif task.has_test_docs():
            task_dict_items.append((name, task, 'test'))
        elif task.has_training_docs():
            task_dict_items.append((name, task, 'training'))
    results = collections.defaultdict(dict)
    requests = collections.defaultdict(list)
......@@ -49,8 +55,15 @@ def main():
    docs = {}
    # get lists of each type of request
    for task_name, task in task_dict_items:
        for doc_id, doc in enumerate(itertools.islice(task.validation_docs(), 0, args.limit)):
    for task_name, task, dset in task_dict_items:
        if dset == 'training':
            temp = task.training_docs()
        elif dset == 'test':
            temp = task.test_docs()
        else:
            temp = task.validation_docs()
        for doc_id, doc in enumerate(itertools.islice(temp, 0, args.limit)):
            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
......
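Taken together, the two hunks above pick one document set per task with a fixed preference order: validation, then test, then training. A standalone sketch of that selection logic (hypothetical helper name, assuming only the has_*_docs / *_docs methods visible in this diff):

    def pick_doc_set(task):
        # Hypothetical helper mirroring the fallback order added above:
        # validation -> test -> training.
        if task.has_validation_docs():
            return 'validation', task.validation_docs()
        elif task.has_test_docs():
            return 'test', task.test_docs()
        elif task.has_training_docs():
            return 'training', task.training_docs()
        raise ValueError("task provides no document set")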
......@@ -37,14 +37,14 @@ def main():
        iters = []
        for set in args.sets.split(","):
            if set == 'train' and task.has_train_docs():
                docs = task.train_docs()
            if set == 'train' and task.has_training_docs():
                docs = task.training_docs()
            if set == 'val' and task.has_validation_docs():
                docs = task.validation_docs()
            if set == 'test' and task.has_test_docs():
                docs = task.test_docs()
            iters.append(docs)
        docs = join_iters(iters)
        with open(os.path.join(args.output_base_path, task_name), "w") as f:
......