Commit ec16d81d authored by Leo Gao's avatar Leo Gao
Browse files

Factor _convert_standard into MultipleChoiceTask

parent 44f03593
...@@ -28,22 +28,6 @@ class ARCEasy(HFTask, MultipleChoiceTask): ...@@ -28,22 +28,6 @@ class ARCEasy(HFTask, MultipleChoiceTask):
} }
return out_doc return out_doc
def _load_docs(self, docs):
for record in docs:
yield self._convert_standard(record)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def test_docs(self):
docs = super().test_docs()
return self._load_docs(docs)
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out description # TODO: figure out description
return "" return ""
......
...@@ -25,21 +25,24 @@ class HFTask(Task): ...@@ -25,21 +25,24 @@ class HFTask(Task):
"""Whether the task has a test set""" """Whether the task has a test set"""
return True if "test" in self.data.keys() else False return True if "test" in self.data.keys() else False
def _convert_standard(self, doc):
return doc
def training_docs(self): def training_docs(self):
# Cache training for faster few-shot. # Cache training for faster few-shot.
# If data is too large to fit in memory, override this method. # If data is too large to fit in memory, override this method.
if self.has_training_docs(): if self.has_training_docs():
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.data["train"]) self._training_docs = list(map(self._convert_standard(self.data["train"]))
return self._training_docs return self._training_docs
def validation_docs(self): def validation_docs(self):
if self.has_validation_docs(): if self.has_validation_docs():
return self.data["validation"] return map(self._convert_standard(self.data["validation"])
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
return self.data["test"] return map(self._convert_standard(self.data["test"])
def yesno(x): def yesno(x):
......
...@@ -24,22 +24,6 @@ class HeadQA(HFTask, MultipleChoiceTask): ...@@ -24,22 +24,6 @@ class HeadQA(HFTask, MultipleChoiceTask):
} }
return out_doc return out_doc
def _load_docs(self, docs):
for doc in docs:
yield self._convert_standard(doc)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def test_docs(self):
docs = super().test_docs()
return self._load_docs(docs)
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out description # TODO: figure out description
return "" return ""
......
...@@ -34,18 +34,6 @@ class HellaSwag(HFTask, MultipleChoiceTask): ...@@ -34,18 +34,6 @@ class HellaSwag(HFTask, MultipleChoiceTask):
} }
return out_doc return out_doc
def _load_docs(self, docs):
for record in docs:
yield self._convert_standard(record)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def fewshot_description(self): def fewshot_description(self):
return "Label for the relevant action: Sentences describing the " \ return "Label for the relevant action: Sentences describing the " \
"context, with an incomplete sentence trailing\nanswer that " \ "context, with an incomplete sentence trailing\nanswer that " \
......
...@@ -28,22 +28,6 @@ class MathQA(HFTask, MultipleChoiceTask): ...@@ -28,22 +28,6 @@ class MathQA(HFTask, MultipleChoiceTask):
} }
return out_doc return out_doc
def _load_docs(self, docs):
for record in docs:
yield self._convert_standard(record)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def test_docs(self):
docs = super().test_docs()
return self._load_docs(docs)
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out description # TODO: figure out description
return "" return ""
......
...@@ -24,22 +24,6 @@ class OpenBookQA(HFTask, MultipleChoiceTask): ...@@ -24,22 +24,6 @@ class OpenBookQA(HFTask, MultipleChoiceTask):
} }
return out_doc return out_doc
def _load_docs(self, docs):
for record in docs:
yield self._convert_standard(record)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def test_docs(self):
docs = super().test_docs()
return self._load_docs(docs)
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out fewshot description # TODO: figure out fewshot description
return "" return ""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment