Commit ec16d81d authored by Leo Gao's avatar Leo Gao
Browse files

Factor _convert_standard into MultipleChoiceTask

parent 44f03593
......@@ -28,22 +28,6 @@ class ARCEasy(HFTask, MultipleChoiceTask):
}
return out_doc
def _load_docs(self, docs):
for record in docs:
yield self._convert_standard(record)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def test_docs(self):
docs = super().test_docs()
return self._load_docs(docs)
def fewshot_description(self):
# TODO: figure out description
return ""
......
......@@ -25,21 +25,24 @@ class HFTask(Task):
"""Whether the task has a test set"""
return True if "test" in self.data.keys() else False
def _convert_standard(self, doc):
return doc
def training_docs(self):
# Cache training for faster few-shot.
# If data is too large to fit in memory, override this method.
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.data["train"])
self._training_docs = list(map(self._convert_standard(self.data["train"]))
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.data["validation"]
return map(self._convert_standard(self.data["validation"])
def test_docs(self):
if self.has_test_docs():
return self.data["test"]
return map(self._convert_standard(self.data["test"])
def yesno(x):
......
......@@ -24,22 +24,6 @@ class HeadQA(HFTask, MultipleChoiceTask):
}
return out_doc
def _load_docs(self, docs):
for doc in docs:
yield self._convert_standard(doc)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def test_docs(self):
docs = super().test_docs()
return self._load_docs(docs)
def fewshot_description(self):
# TODO: figure out description
return ""
......
......@@ -34,18 +34,6 @@ class HellaSwag(HFTask, MultipleChoiceTask):
}
return out_doc
def _load_docs(self, docs):
for record in docs:
yield self._convert_standard(record)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def fewshot_description(self):
return "Label for the relevant action: Sentences describing the " \
"context, with an incomplete sentence trailing\nanswer that " \
......
......@@ -28,22 +28,6 @@ class MathQA(HFTask, MultipleChoiceTask):
}
return out_doc
def _load_docs(self, docs):
for record in docs:
yield self._convert_standard(record)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def test_docs(self):
docs = super().test_docs()
return self._load_docs(docs)
def fewshot_description(self):
# TODO: figure out description
return ""
......
......@@ -24,22 +24,6 @@ class OpenBookQA(HFTask, MultipleChoiceTask):
}
return out_doc
def _load_docs(self, docs):
for record in docs:
yield self._convert_standard(record)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def test_docs(self):
docs = super().test_docs()
return self._load_docs(docs)
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment