Commit 620c0d26 authored by jeffhsu3
Browse files

sciq

parent 9b933d96
...@@ -66,6 +66,11 @@ class Pubmed_QA(HFTask): ...@@ -66,6 +66,11 @@ class Pubmed_QA(HFTask):
"acc" : True "acc" : True
} }
def test_docs(self):
if self.has_test_docs():
# HF is labelled as train but it's really just for testing
return self.data["train"]
class SciQ(MultipleChoiceTask): class SciQ(MultipleChoiceTask):
def download(self): def download(self):
...@@ -86,7 +91,7 @@ class SciQ(MultipleChoiceTask): ...@@ -86,7 +91,7 @@ class SciQ(MultipleChoiceTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def _convert_strandard(doc): def _convert_standard(self, doc):
choices = [ choices = [
doc["distractor1"], doc["distractor1"],
doc["distractor2"], doc["distractor2"],
...@@ -103,11 +108,10 @@ class SciQ(MultipleChoiceTask): ...@@ -103,11 +108,10 @@ class SciQ(MultipleChoiceTask):
return out_doc return out_doc
def load_docs(self, textfilename): def load_docs(self, textfilename):
if labelfilename != None:
with open(textfilename, 'r') as j: with open(textfilename, 'r') as j:
docs = json.loads(j.read()) docs = json.loads(j.read())
for record in docs: for record in docs:
yield _convert_standard(record) yield self._convert_standard(record)
def fewshot_description(self): def fewshot_description(self):
# Average ctx length in labelled dataset is 238.9 # Average ctx length in labelled dataset is 238.9
...@@ -115,13 +119,13 @@ class SciQ(MultipleChoiceTask): ...@@ -115,13 +119,13 @@ class SciQ(MultipleChoiceTask):
return "" return ""
def training_docs(self): def training_docs(self):
return self.load_docs("data/sciq/Sci-Q\ dataset-2\ 3/train.json") return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")
def validation_docs(self): def validation_docs(self):
return self.load_docs("data/sciq/Sci-Q\ dataset-2\ 3/valid.json") return self.load_docs("data/sciq/SciQ dataset-2 3/valid.json")
def test_docs(self): def test_docs(self):
return self.load_docs("data/sciq/Sci-Q\ dataset-2\ 3/test.json") return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
def doc_to_text(self, doc): def doc_to_text(self, doc):
return " {}\n{}".format(doc["source"], doc["query"]) return " {}\n{}".format(doc["source"], doc["query"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment