Commit 39c8c87f authored by Jon Tow's avatar Jon Tow
Browse files

Adopt `_process_doc` method

parent d79a4389
......@@ -386,8 +386,26 @@ class Task(abc.ABC):
""" Downloads and returns the task dataset.
Override this method to download the dataset from a custom API.
:param load_dataset_kwargs: Extra kwargs to pass to `datasets.load_dataset`
if needed.
:param data_dir: str
Stores the path to a local folder containing the `Task`'s data files.
Use this to specify the path to manually downloaded data (usually when
the dataset is not publicly accessible).
:param cache_dir: str
The directory to read/write the `Task` dataset. This follows the
HuggingFace `datasets` API with the default cache directory located at:
`~/.cache/huggingface/datasets`
NOTE: You can change the cache location globally for a given process
by setting the shell environment variable, `HF_DATASETS_CACHE`,
to another directory:
`export HF_DATASETS_CACHE="/path/to/another/directory"`
:param download_mode: datasets.DownloadMode
How to treat pre-existing `Task` downloads and data.
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
Reuse download and reuse dataset.
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
Reuse download with fresh dataset.
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
"""
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
......@@ -433,6 +451,17 @@ class Task(abc.ABC):
"""
return []
def _process_doc(self, doc):
"""
Override this to process (detokenize, strip, replace, etc.) individual
documents. This can be used in a map over documents of a data split.
E.g. `map(self._process_doc, self.validation_docs)`
:return: dict
The processed version of the specified `doc`.
"""
return doc
def fewshot_examples(self, k, rnd):
if self._training_docs is None:
self._training_docs = list(self.training_docs())
......
......@@ -42,16 +42,16 @@ class ARCEasy(MultipleChoiceTask):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
# NOTE: Some `doc["answerKey"]`s are in numeric string format being one
# of {'1', '2', '3', '4', '5'}. We map them back to letters.
num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
......
......@@ -9,7 +9,6 @@ appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/
"""
import json
import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa
......
......@@ -53,13 +53,13 @@ class DROP(Task):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
return {
"id": doc["query_id"],
"passage": doc["passage"],
......
......@@ -40,16 +40,16 @@ class HeadQABase(MultipleChoiceTask):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
out_doc = {
"id": doc["qid"],
"query": "Question: " + doc["qtext"] + "\nAnswer:",
......
......@@ -43,13 +43,13 @@ class HellaSwag(MultipleChoiceTask):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
out_doc = {
"query": self.preprocess(doc['activity_label'] + ': ' + ctx),
......
......@@ -279,7 +279,7 @@ class EthicsUtilitarianism(Ethics):
def training_docs(self):
rnd = random.Random()
for doc in self.dataset["train"]:
yield self.process_doc(doc, rnd)
yield self._process_doc(doc, rnd)
def validation_docs(self):
raise NotImplementedError
......@@ -287,9 +287,9 @@ class EthicsUtilitarianism(Ethics):
def test_docs(self):
rnd = random.Random()
for doc in self.dataset["test"]:
yield self.process_doc(doc, rnd)
yield self._process_doc(doc, rnd)
def process_doc(self, doc, rnd):
def _process_doc(self, doc, rnd):
rnd.seed(doc["activity"])
scenarios = [doc["activity"], doc["baseline"]]
ordering = [0, 1]
......@@ -336,7 +336,7 @@ class EthicsVirtue(Ethics):
VERSION = 0
DATASET_NAME = "virtue"
def process_doc(self, doc):
def _process_doc(self, doc):
return doc
def doc_to_text(self, doc):
......
......@@ -73,12 +73,12 @@ class GeneralHendrycksTest(MultipleChoiceTask):
return True
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
def format_example(doc, keys):
"""
Question: <prompt>
......@@ -105,7 +105,7 @@ class GeneralHendrycksTest(MultipleChoiceTask):
# in the same distribution as val/test but auxiliary_train isn't
if self._fewshot_docs is None:
self._fewshot_docs = list(map(self._convert_standard, self.dataset["dev"]))
self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
return rnd.sample(list(self._fewshot_docs), k)
......
......@@ -42,15 +42,17 @@ class LogiQA(MultipleChoiceTask):
return True
def training_docs(self):
return map(self._convert_standard, self.dataset["train"])
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
def format_example(doc, choices):
"""
Passage: <passage>
......
......@@ -40,16 +40,16 @@ class MathQA(MultipleChoiceTask):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
answer_idx = ['a', 'b', 'c', 'd', 'e'].index(doc['correct'])
choices = [c[4:].rstrip(" ,") for c in re.findall(r"[abcd] \) .*?, |e \) .*?$", doc['options'])]
......
......@@ -43,16 +43,16 @@ class OpenBookQA(MultipleChoiceTask):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
out_doc = {
"id": doc["id"],
"query": doc["question_stem"],
......
......@@ -42,13 +42,13 @@ class PiQA(MultipleChoiceTask):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
out_doc = {
"goal": doc["goal"],
"choices": [doc["sol1"], doc["sol2"]],
......
......@@ -50,7 +50,7 @@ class PROST(MultipleChoiceTask):
return True
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
......@@ -61,7 +61,7 @@ class PROST(MultipleChoiceTask):
description=description
)
def _convert_standard(self, doc):
def _process_doc(self, doc):
out_doc = {
"query": f"{doc['context']}\nQuestion: {doc['ex_question']}\nAnswer:",
"choices": [doc['A'], doc['B'], doc['C'], doc['D']],
......
......@@ -42,9 +42,9 @@ class QA4MRE(MultipleChoiceTask):
def test_docs(self):
# `qa4mre` only has train data so we use it for the test docs.
return map(self._convert_standard, self.dataset["train"])
return map(self._process_doc, self.dataset["train"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
choices = doc["answer_options"]["answer_str"]
out_doc = {
"source": doc["document_str"].strip().replace("\'", "'"),
......
......@@ -137,13 +137,13 @@ class QASPER(Task):
def training_docs(self):
for doc in self.dataset["train"]:
yield from self.process_doc(doc)
yield from self._process_doc(doc)
def validation_docs(self):
for doc in self.dataset["validation"]:
yield from self.process_doc(doc)
yield from self._process_doc(doc)
def process_doc(self, doc):
def _process_doc(self, doc):
"""Given a `doc`, flatten it out so that each JSON blob
contains exactly one question and one answer. Logic taken from
the reference implementation available at
......
......@@ -41,16 +41,16 @@ class QuAC(Task):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
raise NotImplementedError("QuAC has no test docs.")
def _convert_standard(self, doc):
def _process_doc(self, doc):
doc["title"] = doc['title'] + ' - ' + doc['section_title']
return doc
......
......@@ -52,12 +52,12 @@ class SATAnalogies(MultipleChoiceTask):
return []
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return []
def _convert_standard(self, doc):
def _process_doc(self, doc):
return {
'source': doc['source'],
'query': doc['stem'].split(' ')[:2],
......
......@@ -38,16 +38,16 @@ class SciQ(MultipleChoiceTask):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
choices = [
doc["distractor1"],
doc["distractor2"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment