"src/instruction.cpp" did not exist on "d2d5fd1960ceaae0079000a4ee9711c42d2f5a62"
Commit 0971cffe authored by jeffhsu3's avatar jeffhsu3
Browse files

separated sciq and pubmedqa to different files

parent 620c0d26
...@@ -18,6 +18,7 @@ from . import race ...@@ -18,6 +18,7 @@ from . import race
from . import piqa from . import piqa
from . import triviaqa from . import triviaqa
from . import pubmedqa from . import pubmedqa
from . import sciq
from . import webqs from . import webqs
...@@ -48,7 +49,7 @@ TASK_REGISTRY = { ...@@ -48,7 +49,7 @@ TASK_REGISTRY = {
"piqa": piqa.PiQA, "piqa": piqa.PiQA,
"pubmedqa" : pubmedqa.Pubmed_QA, "pubmedqa" : pubmedqa.Pubmed_QA,
"sciq" : pubmedqa.SciQ, "sciq" : sciq.SciQ,
#"triviaqa": triviaqa.TriviaQA, #"triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy, "arc_easy": arc.ARCEasy,
......
"""
"""
import os
import numpy as np import numpy as np
import json import json
from ..utils import sh import random
from .common import HFTask, yesno from .common import HFTask
from lm_eval.base import MultipleChoiceTask, rf, mean from lm_eval.base import rf, mean
import zipfile
class Pubmed_QA(HFTask): class Pubmed_QA(HFTask):
...@@ -23,6 +18,11 @@ class Pubmed_QA(HFTask): ...@@ -23,6 +18,11 @@ class Pubmed_QA(HFTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def test_docs(self):
if self.has_test_docs():
# HF is labelled as train but its really just for testing
return self.data["train"]
def fewshot_description(self): def fewshot_description(self):
# Average ctx length in labelled dataset is 238.9 # Average ctx length in labelled dataset is 238.9
# 2 few-shot exmamples pushes it beyond context window # 2 few-shot exmamples pushes it beyond context window
...@@ -39,6 +39,12 @@ class Pubmed_QA(HFTask): ...@@ -39,6 +39,12 @@ class Pubmed_QA(HFTask):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " {}".format(doc["final_decision"]) return " {}".format(doc["final_decision"])
def fewshot_examples(self, k):
# Since only test docs sample from test docs
if self._training_docs is None:
self._training_docs = list(self.test_docs())
return random.sample(self._training_docs, k)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns """ Uses RequestFactory to construct Requests and returns
an iterable of Requests which will be sent to the LM. an iterable of Requests which will be sent to the LM.
...@@ -65,71 +71,3 @@ class Pubmed_QA(HFTask): ...@@ -65,71 +71,3 @@ class Pubmed_QA(HFTask):
return { return {
"acc" : True "acc" : True
} }
def test_docs(self):
if self.has_test_docs():
# HF is labelled as train but its really just for testing
return self.data["train"]
class SciQ(MultipleChoiceTask):
def download(self):
if not os.path.exists('data/sciq'):
os.mkdir('data/sciq')
sh((
"wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip -O data/sciq/SciQ.zip"
))
with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
zf.extractall("data/sciq/")
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def _convert_standard(self, doc):
choices = [
doc["distractor1"],
doc["distractor2"],
doc["distractor3"],
doc["correct_answer"],
]
src = doc['support']
out_doc = {
"source" : src,
"query" : doc['question'],
"choices" : choices,
"gold" : 3,
}
return out_doc
def load_docs(self, textfilename):
with open(textfilename, 'r') as j:
docs = json.loads(j.read())
for record in docs:
yield self._convert_standard(record)
def fewshot_description(self):
# Average ctx length in labelled dataset is 238.9
# 2 few-shot exmamples pushes it beyond context window
return ""
def training_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")
def validation_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/valid.json")
def test_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
def doc_to_text(self, doc):
return " {}\n{}".format(doc["source"], doc["query"])
class EmrQA():
def load_docs(self, textfilename):
pass
import os
import json
from ..utils import sh
from lm_eval.base import MultipleChoiceTask, rf, mean
import zipfile
class SciQ(MultipleChoiceTask):
# Multiple languages and multiple years
def download(self):
if not os.path.exists('data/sciq'):
os.mkdir('data/sciq')
sh((
"wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip -O data/sciq/SciQ.zip"
))
with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
zf.extractall("data/sciq/")
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def _convert_standard(self, doc):
choices = [
doc["distractor1"],
doc["distractor2"],
doc["distractor3"],
doc["correct_answer"],
]
src = doc['support']
out_doc = {
"source" : src,
"query" : doc['question'],
"choices" : choices,
"gold" : 3,
}
return out_doc
def load_docs(self, textfilename):
with open(textfilename, 'r') as j:
docs = json.loads(j.read())
for record in docs:
yield self._convert_standard(record)
def fewshot_description(self):
# Average ctx length in labelled dataset is 238.9
# 2 few-shot exmamples pushes it beyond context window
return ""
def training_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")
def validation_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/valid.json")
def test_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
def doc_to_text(self, doc):
return " {}\n{}".format(doc["source"], doc["query"])
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment