Commit 39c8c87f authored by Jon Tow

Adopt `_process_doc` method

parent d79a4389
@@ -386,8 +386,26 @@ class Task(abc.ABC):
        """ Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

-       :param load_dataset_kwargs: Extra kwargs to pass to `datasets.load_dataset`
-           if needed.
+       :param data_dir: str
+           Stores the path to a local folder containing the `Task`'s data files.
+           Use this to specify the path to manually downloaded data (usually when
+           the dataset is not publicly accessible).
+       :param cache_dir: str
+           The directory to read/write the `Task` dataset. This follows the
+           HuggingFace `datasets` API with the default cache directory located at:
+               `~/.cache/huggingface/datasets`
+           NOTE: You can change the cache location globally for a given process
+           by setting the shell environment variable, `HF_DATASETS_CACHE`,
+           to another directory:
+               `export HF_DATASETS_CACHE="/path/to/another/directory"`
+       :param download_mode: datasets.DownloadMode
+           How to treat pre-existing `Task` downloads and data.
+           - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
+               Reuse download and reuse dataset.
+           - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
+               Reuse download with fresh dataset.
+           - `datasets.DownloadMode.FORCE_REDOWNLOAD`
+               Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
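For reference, the three new parameters line up with the arguments of the same names on `datasets.load_dataset`. A minimal sketch of the equivalent direct call, with a placeholder dataset path standing in for `self.DATASET_PATH` (this snippet is illustrative and not part of the commit):

    import datasets

    # Sketch only: roughly what the default download() does with the new kwargs.
    dataset = datasets.load_dataset(
        path="sciq",                   # placeholder for self.DATASET_PATH
        name=None,                     # placeholder for self.DATASET_NAME
        data_dir=None,                 # point at a local folder for manually downloaded data
        cache_dir="/tmp/hf_datasets",  # defaults to ~/.cache/huggingface/datasets
        download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
    )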
@@ -433,6 +451,17 @@ class Task(abc.ABC):
        """
        return []

+   def _process_doc(self, doc):
+       """
+       Override this to process (detokenize, strip, replace, etc.) individual
+       documents. This can be used in a map over documents of a data split.
+       E.g. `map(self._process_doc, self.validation_docs)`
+
+       :return: dict
+           The processed version of the specified `doc`.
+       """
+       return doc
+
    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
......
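The pattern the remaining files in this commit converge on is: override `_process_doc` in the task subclass and map it over each split. A minimal sketch under stated assumptions (the class, dataset path, and field names below are illustrative, the base class is assumed to live in `lm_eval.base`, and the other abstract `Task` methods a real task must implement are omitted):

    from lm_eval.base import Task

    class MyTask(Task):
        VERSION = 0
        DATASET_PATH = "my_dataset"  # placeholder, not a real dataset

        def training_docs(self):
            # Cache the processed training split, as the updated tasks do.
            if self._training_docs is None:
                self._training_docs = list(map(self._process_doc, self.dataset["train"]))
            return self._training_docs

        def validation_docs(self):
            return map(self._process_doc, self.dataset["validation"])

        def _process_doc(self, doc):
            # Normalize one raw example into the dict consumed by
            # doc_to_text / doc_to_target.
            return {"query": doc["question"].strip(), "answer": doc["answer"]}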
@@ -42,16 +42,16 @@ class ARCEasy(MultipleChoiceTask):
    def training_docs(self):
        if self._training_docs is None:
-           self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
+           self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
-       return map(self._convert_standard, self.dataset["test"])
+       return map(self._process_doc, self.dataset["test"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        # NOTE: Some `doc["answerKey"]`s are in numeric string format being one
        # of {'1', '2', '3', '4', '5'}. We map them back to letters.
        num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
......
@@ -9,7 +9,6 @@ appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/
"""
-import json
import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa
......
@@ -53,13 +53,13 @@ class DROP(Task):
    def training_docs(self):
        if self._training_docs is None:
-           self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
+           self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        return {
            "id": doc["query_id"],
            "passage": doc["passage"],
......
@@ -40,16 +40,16 @@ class HeadQABase(MultipleChoiceTask):
    def training_docs(self):
        if self._training_docs is None:
-           self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
+           self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
-       return map(self._convert_standard, self.dataset["test"])
+       return map(self._process_doc, self.dataset["test"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        out_doc = {
            "id": doc["qid"],
            "query": "Question: " + doc["qtext"] + "\nAnswer:",
......
@@ -43,13 +43,13 @@ class HellaSwag(MultipleChoiceTask):
    def training_docs(self):
        if self._training_docs is None:
-           self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
+           self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        out_doc = {
            "query": self.preprocess(doc['activity_label'] + ': ' + ctx),
......
@@ -279,7 +279,7 @@ class EthicsUtilitarianism(Ethics):
    def training_docs(self):
        rnd = random.Random()
        for doc in self.dataset["train"]:
-           yield self.process_doc(doc, rnd)
+           yield self._process_doc(doc, rnd)

    def validation_docs(self):
        raise NotImplementedError
@@ -287,9 +287,9 @@ class EthicsUtilitarianism(Ethics):
    def test_docs(self):
        rnd = random.Random()
        for doc in self.dataset["test"]:
-           yield self.process_doc(doc, rnd)
+           yield self._process_doc(doc, rnd)

-   def process_doc(self, doc, rnd):
+   def _process_doc(self, doc, rnd):
        rnd.seed(doc["activity"])
        scenarios = [doc["activity"], doc["baseline"]]
        ordering = [0, 1]
@@ -336,7 +336,7 @@ class EthicsVirtue(Ethics):
    VERSION = 0
    DATASET_NAME = "virtue"

-   def process_doc(self, doc):
+   def _process_doc(self, doc):
        return doc

    def doc_to_text(self, doc):
......
@@ -73,12 +73,12 @@ class GeneralHendrycksTest(MultipleChoiceTask):
        return True

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
-       return map(self._convert_standard, self.dataset["test"])
+       return map(self._process_doc, self.dataset["test"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        def format_example(doc, keys):
            """
            Question: <prompt>
@@ -105,7 +105,7 @@ class GeneralHendrycksTest(MultipleChoiceTask):
        # in the same distribution as val/test but auxiliary_train isn't
        if self._fewshot_docs is None:
-           self._fewshot_docs = list(map(self._convert_standard, self.dataset["dev"]))
+           self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))

        return rnd.sample(list(self._fewshot_docs), k)
......
@@ -42,15 +42,17 @@ class LogiQA(MultipleChoiceTask):
        return True

    def training_docs(self):
-       return map(self._convert_standard, self.dataset["train"])
+       if self._training_docs is None:
+           self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+       return self._training_docs

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
-       return map(self._convert_standard, self.dataset["test"])
+       return map(self._process_doc, self.dataset["test"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        def format_example(doc, choices):
            """
            Passage: <passage>
......
@@ -40,16 +40,16 @@ class MathQA(MultipleChoiceTask):
    def training_docs(self):
        if self._training_docs is None:
-           self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
+           self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
-       return map(self._convert_standard, self.dataset["test"])
+       return map(self._process_doc, self.dataset["test"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        answer_idx = ['a', 'b', 'c', 'd', 'e'].index(doc['correct'])
        choices = [c[4:].rstrip(" ,") for c in re.findall(r"[abcd] \) .*?, |e \) .*?$", doc['options'])]
......
@@ -43,16 +43,16 @@ class OpenBookQA(MultipleChoiceTask):
    def training_docs(self):
        if self._training_docs is None:
-           self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
+           self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
-       return map(self._convert_standard, self.dataset["test"])
+       return map(self._process_doc, self.dataset["test"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        out_doc = {
            "id": doc["id"],
            "query": doc["question_stem"],
......
@@ -42,13 +42,13 @@ class PiQA(MultipleChoiceTask):
    def training_docs(self):
        if self._training_docs is None:
-           self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
+           self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        out_doc = {
            "goal": doc["goal"],
            "choices": [doc["sol1"], doc["sol2"]],
......
@@ -50,7 +50,7 @@ class PROST(MultipleChoiceTask):
        return True

    def test_docs(self):
-       return map(self._convert_standard, self.dataset["test"])
+       return map(self._process_doc, self.dataset["test"])

    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
        assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
@@ -61,7 +61,7 @@ class PROST(MultipleChoiceTask):
            description=description
        )

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        out_doc = {
            "query": f"{doc['context']}\nQuestion: {doc['ex_question']}\nAnswer:",
            "choices": [doc['A'], doc['B'], doc['C'], doc['D']],
......
@@ -42,9 +42,9 @@ class QA4MRE(MultipleChoiceTask):
    def test_docs(self):
        # `qa4mre` only has train data so we use it for the test docs.
-       return map(self._convert_standard, self.dataset["train"])
+       return map(self._process_doc, self.dataset["train"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        choices = doc["answer_options"]["answer_str"]
        out_doc = {
            "source": doc["document_str"].strip().replace("\'", "'"),
......
@@ -137,13 +137,13 @@ class QASPER(Task):
    def training_docs(self):
        for doc in self.dataset["train"]:
-           yield from self.process_doc(doc)
+           yield from self._process_doc(doc)

    def validation_docs(self):
        for doc in self.dataset["validation"]:
-           yield from self.process_doc(doc)
+           yield from self._process_doc(doc)

-   def process_doc(self, doc):
+   def _process_doc(self, doc):
        """Given a `doc`, flatten it out so that each JSON blob
        contains exactly one question and one answer. Logic taken from
        the reference implementation available at
......
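Note that QASPER's `training_docs` uses `yield from`, so here `_process_doc` acts as a generator that flattens one raw record into several evaluation docs rather than returning a single dict. A rough sketch of that shape with hypothetical field names (the actual flattening logic lives in the truncated body above and is not shown by this diff):

    def _process_doc(self, doc):
        # Sketch: emit one flat doc per (question, answer) pair.
        # Field names are illustrative, not QASPER's actual schema.
        for question, answers in zip(doc["questions"], doc["answers"]):
            for answer in answers:
                yield {
                    "title": doc["title"],
                    "question": question,
                    "answer": answer,
                }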
@@ -41,16 +41,16 @@ class QuAC(Task):
    def training_docs(self):
        if self._training_docs is None:
-           self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
+           self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        raise NotImplementedError("QuAC has no test docs.")

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        doc["title"] = doc['title'] + ' - ' + doc['section_title']
        return doc
......
@@ -52,12 +52,12 @@ class SATAnalogies(MultipleChoiceTask):
        return []

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        return []

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        return {
            'source': doc['source'],
            'query': doc['stem'].split(' ')[:2],
......
@@ -38,16 +38,16 @@ class SciQ(MultipleChoiceTask):
    def training_docs(self):
        if self._training_docs is None:
-           self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
+           self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
-       return map(self._convert_standard, self.dataset["validation"])
+       return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
-       return map(self._convert_standard, self.dataset["test"])
+       return map(self._process_doc, self.dataset["test"])

-   def _convert_standard(self, doc):
+   def _process_doc(self, doc):
        choices = [
            doc["distractor1"],
            doc["distractor2"],
......