Commit ab41f089 authored by jeffhsu3's avatar jeffhsu3
Browse files

merge

parents 695d633d 30c2fe23
...@@ -71,3 +71,74 @@ class Pubmed_QA(HFTask): ...@@ -71,3 +71,74 @@ class Pubmed_QA(HFTask):
return { return {
"acc" : True "acc" : True
} }
<<<<<<< HEAD
=======
def test_docs(self):
if self.has_test_docs():
# HF is labelled as train but its really just for testing
return self.data["train"]
class SciQ(MultipleChoiceTask):
def download(self):
if not os.path.exists('data/sciq'):
os.mkdir('data/sciq')
sh((
"wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip -O data/sciq/SciQ.zip"
))
with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
zf.extractall("data/sciq/")
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def _convert_standard(self, doc):
choices = [
doc["distractor1"],
doc["distractor2"],
doc["distractor3"],
doc["correct_answer"],
]
src = doc['support']
out_doc = {
"source" : src,
"query" : doc['question'],
"choices" : choices,
"gold" : 3,
}
return out_doc
def load_docs(self, textfilename):
with open(textfilename, 'r') as j:
docs = json.loads(j.read())
for record in docs:
yield self._convert_standard(record)
def fewshot_description(self):
# Average ctx length in labelled dataset is 238.9
# 2 few-shot exmamples pushes it beyond context window
return ""
def training_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")
def validation_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/valid.json")
def test_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
def doc_to_text(self, doc):
return "{}\n{}".format(doc["source"], doc["query"])
class EmrQA():
def load_docs(self, textfilename):
pass
>>>>>>> 30c2fe23657981d8a6155da1bd8e8098487b771a
...@@ -62,4 +62,4 @@ class SciQ(MultipleChoiceTask): ...@@ -62,4 +62,4 @@ class SciQ(MultipleChoiceTask):
return self.load_docs("data/sciq/SciQ dataset-2 3/test.json") return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
def doc_to_text(self, doc): def doc_to_text(self, doc):
return " {}\n{}".format(doc["source"], doc["query"]) return "{}\n{}".format(doc["source"], doc["query"])
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment