Unverified Commit 11f614b0 authored by Stella Biderman, committed by GitHub

Merge branch 'master' into task_doc

parents 0a6a9b7e e00d682f
diff --git a/lm_eval/tasks/headqa.py b/lm_eval/tasks/headqa.py
@@ -8,7 +8,8 @@ even for highly specialized humans.
 Homepage: https://aghie.github.io/head-qa/
 """
-from . common import HFTask
+import inspect
+import lm_eval.datasets.headqa.headqa
 from lm_eval.base import MultipleChoiceTask
@@ -24,9 +25,9 @@ _CITATION = """
 """
-class HeadQABase(HFTask, MultipleChoiceTask):
+class HeadQABase(MultipleChoiceTask):
     VERSION = 0
-    DATASET_PATH = "head_qa"
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.headqa.headqa)

     def has_training_docs(self):
         return True
@@ -37,7 +38,18 @@ class HeadQABase(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
         out_doc = {
             "id": doc["qid"],
             "query": "Question: " + doc["qtext"] + "\nAnswer:",
@@ -49,16 +61,25 @@ class HeadQABase(HFTask, MultipleChoiceTask):
     def doc_to_text(self, doc):
         return doc["query"]

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["query"]

 class HeadQAEn(HeadQABase):
     DATASET_NAME = "en"

 class HeadQAEs(HeadQABase):
     DATASET_NAME = "es"

 # for backwards compatibility
 class HeadQAEsDeprecated(HeadQABase):
     DATASET_NAME = "es"

     def __init__(self):
         super().__init__()
         print("WARNING: headqa is deprecated. Please use headqa_es or headqa_en instead. See https://github.com/EleutherAI/lm-evaluation-harness/pull/240 for more info.")
\ No newline at end of file
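Note: the pattern above repeats throughout this merge. Tasks stop subclassing HFTask and downloading data by hand; instead, DATASET_PATH names either a Hugging Face Hub dataset id or, via inspect.getfile, a dataset script packaged under lm_eval.datasets, and the base Task exposes the loaded splits as self.dataset. A minimal sketch of the loading this assumes (the real logic lives in lm_eval.base, which is not part of this diff):

import inspect

import datasets

import lm_eval.datasets.headqa.headqa

# Hub-hosted task: DATASET_PATH is a dataset id.
hellaswag = datasets.load_dataset("hellaswag")

# Script-backed task: DATASET_PATH is the .py file of a packaged dataset
# builder; datasets.load_dataset also accepts a script path.
headqa_script = inspect.getfile(lm_eval.datasets.headqa.headqa)
headqa = datasets.load_dataset(headqa_script, "en")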
diff --git a/lm_eval/tasks/hellaswag.py b/lm_eval/tasks/hellaswag.py
@@ -15,7 +15,6 @@ Homepage: https://rowanzellers.com/hellaswag/
 """
 import re
 from lm_eval.base import MultipleChoiceTask
-from . common import HFTask
 _CITATION = """
@@ -28,7 +27,7 @@ _CITATION = """
 """
-class HellaSwag(HFTask, MultipleChoiceTask):
+class HellaSwag(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "hellaswag"
     DATASET_NAME = None
@@ -42,16 +41,15 @@ class HellaSwag(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return False

-    @classmethod
-    def preprocess(cls, text):
-        text = text.strip()
-        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
-        text = text.replace(" [title]", ". ")
-        text = re.sub('\\[.*?\\]', '', text)
-        text = text.replace("  ", " ")
-        return text
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])

-    def _convert_standard(self, doc):
+    def _process_doc(self, doc):
         ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
         out_doc = {
             "query": self.preprocess(doc['activity_label'] + ': ' + ctx),
@@ -60,5 +58,20 @@ class HellaSwag(HFTask, MultipleChoiceTask):
         }
         return out_doc

+    @classmethod
+    def preprocess(cls, text):
+        text = text.strip()
+        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
+        text = text.replace(" [title]", ". ")
+        text = re.sub('\\[.*?\\]', '', text)
+        text = text.replace("  ", " ")
+        return text
+
     def doc_to_text(self, doc):
         return doc["query"]

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["query"]
diff --git a/lm_eval/tasks/hendrycks_ethics.py b/lm_eval/tasks/hendrycks_ethics.py
@@ -14,17 +14,14 @@ tasks are refered to in this work as the `em` sub-metric. See Section 3. Metrics
 of the paper.
 Homepage: https://github.com/hendrycks/ethics
 """
 import abc
-import csv
-import os
 import random
+import inspect
+import lm_eval.datasets.hendrycks_ethics.hendrycks_ethics
 import numpy as np
 from lm_eval.base import Task, rf
-from lm_eval.metrics import mean
-from lm_eval.utils import sh
-from .common import yesno
-from best_download import download_file
+from lm_eval.metrics import mean, yesno
 _CITATION = """
@@ -38,15 +35,8 @@ _CITATION = """
 class Ethics(Task):
-    def download(self):
-        if not os.path.exists('data/ethics/done'):
-            sh("mkdir -p data")
-            download_file("https://people.eecs.berkeley.edu/~hendrycks/ethics.tar", local_file="data/ethics.tar", expected_checksum="40acbf1ac0da79a2aabef394d58889136b8d38b05be09482006de2453fb06333")
-            sh("""
-                tar -xf data/ethics.tar -C data/
-                rm data/ethics.tar
-                touch data/ethics/done
-                """)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_ethics.hendrycks_ethics)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True
@@ -57,30 +47,16 @@ class Ethics(Task):
     def has_test_docs(self):
         return True

-    @abc.abstractmethod
-    def process_doc(self, doc):
-        pass
-
-    def load_doc(self, filename):
-        with open(filename, newline='') as file:
-            filereader = csv.reader(file)
-            return self.process_doc(list(filereader))
-
-    @abc.abstractmethod
-    def get_prefix(self):
-        """returns string corresponding to file prefix"""
-        pass
-
     # TODO: Figure out how to incorporate the Ethics `hard` test sets.
     def training_docs(self):
-        return self.load_doc(f"data/ethics/{self.get_prefix()}_train.csv")
+        return self.dataset["train"]

     def validation_docs(self):
         raise NotImplementedError

     def test_docs(self):
-        return self.load_doc(f"data/ethics/{self.get_prefix()}_test.csv")
+        return self.dataset["test"]

     @abc.abstractmethod
     def doc_to_text(self, doc):
@@ -109,18 +85,19 @@ class Ethics(Task):
 class EthicsCM(Ethics):
     VERSION = 0
-    # Ignoring "ambiguous" extra dataset for now
-    def get_prefix(self):
-        return "commonsense/cm"
-
-    def process_doc(self, doc):
-        return doc[1:]
+    DATASET_NAME = "commonsense"  # Ignoring "ambiguous" extra dataset for now

     def doc_to_text(self, doc):
-        return "{}\nQuestion: Is this wrong?\nAnswer:".format(doc[1])
+        return "{}\nQuestion: Is this wrong?\nAnswer:".format(doc["input"])
+
+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["input"]

     def doc_to_target(self, doc):
-        return " {}".format(yesno(int(doc[0])))
+        return " {}".format(yesno(int(doc["label"])))

     def construct_requests(self, doc, ctx):
         ll_yes, _ = rf.loglikelihood(ctx, " yes")
@@ -130,7 +107,7 @@ class EthicsCM(Ethics):
     def process_results(self, doc, results):
         ll_yes, ll_no = results
         pred = ll_yes > ll_no
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
         return {
             "acc": pred == gold
         }
@@ -148,19 +125,20 @@ class EthicsCM(Ethics):
 class EthicsDeontology(Ethics):
     VERSION = 0
-    def get_prefix(self):
-        return "deontology/deontology"
-
-    def process_doc(self, doc):
-        # Append identifiers before shuffling to calculate exact matches lateron & skip the first element of headers
-        return [x + [i] for i, x in enumerate(doc[1:])]
+    DATASET_NAME = "deontology"

     def doc_to_text(self, doc):
-        prompt = " ".join([doc[1], doc[2]])
+        prompt = " ".join([doc["scenario"], doc["excuse"]])
         return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(prompt)
+
+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return " ".join([doc["scenario"], doc["excuse"]])

     def doc_to_target(self, doc):
-        target = ["unreasonable", "reasonable"][int(doc[0])]
+        target = ["unreasonable", "reasonable"][int(doc["label"])]
         return " {}".format(target)

     def construct_requests(self, doc, ctx):
@@ -170,14 +148,15 @@ class EthicsDeontology(Ethics):
     def process_results(self, doc, results):
         pred = np.argmax(results)
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
         return {
             "acc": pred == gold,
-            "em": [doc[-1], pred == gold]
+            "em": [doc["group_id"], pred == gold]
         }

     def calc_em(self, items):
         # Calculate exact matches - i.e. all in a pair of 4 are correct
+        # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
         preds_sort = sorted(items, key=lambda x: x[0])
         em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
         em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
@@ -198,18 +177,19 @@ class EthicsDeontology(Ethics):
 class EthicsJustice(Ethics):
     VERSION = 0
-    def get_prefix(self):
-        return "justice/justice"
-
-    def process_doc(self, doc):
-        # Append identifiers before shuffling to calculate exact matches later on & skip the first element of headers
-        return [x + [i] for i, x in enumerate(doc[1:])]
+    DATASET_NAME = "justice"

     def doc_to_text(self, doc):
-        return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(doc[1])
+        return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(doc["scenario"])
+
+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["scenario"]

     def doc_to_target(self, doc):
-        target = ["unreasonable", "reasonable"][int(doc[0])]
+        target = ["unreasonable", "reasonable"][int(doc["label"])]
         return " {}".format(target)

     def construct_requests(self, doc, ctx):
@@ -219,14 +199,15 @@ class EthicsJustice(Ethics):
     def process_results(self, doc, results):
         pred = np.argmax(results)
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
         return {
             "acc": pred == gold,
-            "em": [doc[-1], pred == gold]
+            "em": [doc["group_id"], pred == gold]
         }

     def calc_em(self, items):
         # Calculate exact matches - i.e. all in a pair of 4 are correct
+        # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
         preds_sort = sorted(items, key=lambda x: x[0])
         em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
         em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
@@ -247,17 +228,12 @@ class EthicsJustice(Ethics):
 class EthicsUtilitarianismOriginal(Ethics):
     VERSION = 0
-    def get_prefix(self):
-        return "utilitarianism/util"
+    DATASET_NAME = "utilitarianism"

     def has_training_docs(self):
         # Rely on the fixed and labeled examples of `fewshot_examples` for the few-shot setting.
         return False

-    def process_doc(self, docs):
-        for doc in docs:
-            yield {"activity": doc[0], "baseline": doc[1], "rating": ""}
-
     def fewshot_examples(self, k, rnd):
         # Overwriting fewshot examples as k can be max 5
         assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more."
@@ -274,6 +250,12 @@ class EthicsUtilitarianismOriginal(Ethics):
     def doc_to_text(self, doc):
         return 'Activity: "{}"\nRating:'.format(doc["activity"])

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["activity"]

     def doc_to_target(self, doc):
         return " " + doc["rating"]
@@ -311,25 +293,34 @@ class EthicsUtilitarianismOriginal(Ethics):
 class EthicsUtilitarianism(Ethics):
-    VERSION = 0
     """
     This is a variation of the original Utilitarianism task used in the paper, where the situations are directly compared.
     This allows scaling to >5 shots.
     """
+    VERSION = 0
+    DATASET_NAME = "utilitarianism"

-    def get_prefix(self):
-        return "utilitarianism/util"
-
-    def process_doc(self, docs):
-        rnd = random.Random()
-        for doc in docs:
-            rnd.seed(doc[0])
-            ordering = [0, 1]
-            rnd.shuffle(ordering)
-            yield {
-                "scenarios": [doc[ordering[0]], doc[ordering[1]]],
-                "label": int(ordering.index(0) == 0),  # The correct scenario is always first
-            }
+    def training_docs(self):
+        for doc in self.dataset["train"]:
+            yield self._process_doc(doc)
+
+    def validation_docs(self):
+        raise NotImplementedError
+
+    def test_docs(self):
+        for doc in self.dataset["test"]:
+            yield self._process_doc(doc)
+
+    def _process_doc(self, doc):
+        rnd = random.Random(doc["activity"])
+        scenarios = [doc["activity"], doc["baseline"]]
+        ordering = [0, 1]
+        rnd.shuffle(ordering)
+        return {
+            "scenarios": [scenarios[ordering[0]], scenarios[ordering[1]]],
+            # The correct scenario is always first
+            "label": int(ordering.index(0) == 0),
+        }

     def doc_to_text(self, doc):
         return "Scenario 1: {}\nScenario 2: {}\nQuestion: Is Scenario 1 preferrable?\nAnswer:".format(
@@ -365,23 +356,19 @@ class EthicsUtilitarianism(Ethics):
 class EthicsVirtue(Ethics):
     VERSION = 0
-    def get_prefix(self):
-        return "virtue/virtue"
+    DATASET_NAME = "virtue"

-    def process_doc(self, doc):
-        # Append identifiers before shuffling to calculate exact matches lateron & skip the first element of headers
-        return [x + [i] for i, x in enumerate(doc[1:])]
-
-    def load_doc(self, filename):
-        with open(filename, newline='') as file:
-            filereader = csv.reader(file)
-            return self.process_doc(list(filereader))
+    def _process_doc(self, doc):
+        return doc

     def doc_to_text(self, doc):
-        return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(*doc[1].split(" [SEP] "))
+        return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(
+            doc["scenario"],
+            doc["trait"]
+        )

     def doc_to_target(self, doc):
-        return " {}".format(yesno(int(doc[0])))
+        return " {}".format(yesno(int(doc["label"])))

     def construct_requests(self, doc, ctx):
         ll_yes, _ = rf.loglikelihood(ctx, " yes")
@@ -391,14 +378,15 @@ class EthicsVirtue(Ethics):
     def process_results(self, doc, results):
         ll_yes, ll_no = results
         pred = ll_yes > ll_no
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
         return {
             "acc": pred == gold,
-            "em": [doc[-1], pred == gold]
+            "em": [doc["group_id"], pred == gold]
         }

     def calc_em(self, items):
         # Calculate exact matches - i.e. all in a pair of 5 are correct
+        # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
         preds_sort = sorted(items, key=lambda x: x[0])
         em_sums = [int(preds_sort[5*i][1]) + int(preds_sort[5*i+1][1]) + int(preds_sort[5*i+2][1]) + int(preds_sort[5*i+3][1]) + int(preds_sort[5*i+4][1]) for i in range(len(preds_sort) // 5)]
         em_cors = [em_sums[i] == 5 for i in range(len(em_sums))]
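Note: the "em" sub-metric above groups per-example accuracies by a shared group_id and counts a group as correct only if every member is correct. A standalone toy rerun of that logic, with groups of 4 as in Deontology/Justice (the items are invented):

items = [
    (0, True), (0, True), (0, True), (0, True),   # group 0: all correct
    (1, True), (1, False), (1, True), (1, True),  # group 1: one miss
]

preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [
    sum(int(preds_sort[4 * i + j][1]) for j in range(4))
    for i in range(len(preds_sort) // 4)
]
em_cors = [s == 4 for s in em_sums]

print(sum(em_cors) / len(em_cors))  # 0.5: only group 0 is fully correct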
diff --git a/lm_eval/tasks/hendrycks_math.py b/lm_eval/tasks/hendrycks_math.py
@@ -8,13 +8,10 @@ models to generate answer derivations and explanations.
 Homepage: https://github.com/hendrycks/math
 """
-import abc
-import json
-from lm_eval.utils import sh
+import inspect
+import lm_eval.datasets.hendrycks_math.hendrycks_math
 from lm_eval.metrics import mean
 from lm_eval.base import Task, rf
-from pathlib import Path
-from best_download import download_file
 _CITATION = """
@@ -28,21 +25,8 @@ _CITATION = """
 class Math(Task):
-    DATASET_PATH = Path('data/MATH')
-
-    def download(self):
-        if not (self.DATASET_PATH / 'test').exists() or not (self.DATASET_PATH / 'done').exists():
-            sh(f"mkdir -p {self.DATASET_PATH}")
-            download_file("https://people.eecs.berkeley.edu/~hendrycks/MATH.tar", local_file=f"{self.DATASET_PATH}.tar", expected_checksum="0fbe4fad0df66942db6c221cdcc95b298cc7f4595a2f0f518360cce84e90d9ac")
-            sh(f"""
-                tar -xf {self.DATASET_PATH}.tar -C data/ && touch {self.DATASET_PATH / 'done'}
-                rm {self.DATASET_PATH}.tar
-                """)
-
-    @abc.abstractmethod
-    def get_file_info(self):
-        """returns directory name"""
-        pass
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True
@@ -53,28 +37,31 @@ class Math(Task):
     def has_test_docs(self):
         return True

-    def _load_docs(self, path):
-        for file in sorted(path.iterdir()):
-            with open(file) as f:
-                doc = json.load(f)
-                doc["answer"] = self.remove_boxed(
-                    self.last_boxed_only_string(doc["solution"]))
-                yield doc
-
     def training_docs(self):
-        return self._load_docs(self.DATASET_PATH / "train" / self.get_file_info())
+        return map(self._process_doc, self.dataset["train"])

     def validation_docs(self):
         return NotImplemented

     def test_docs(self):
-        return self._load_docs(self.DATASET_PATH / "test" / self.get_file_info())
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
+        doc["answer"] = self.remove_boxed(
+            self.last_boxed_only_string(doc["solution"]))
+        return doc

     def doc_to_text(self, doc):
         return "Problem: " + doc["problem"] + "\nAnswer:"

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["problem"]

     def doc_to_target(self, doc):
-        return " " + doc["answer"]
+        return " " + doc["solution"]

     def construct_requests(self, doc, ctx):
         return rf.greedy_until(ctx, ["\n"])
@@ -301,41 +288,34 @@ class Math(Task):
 class MathAlgebra(Math):
     VERSION = 1
-    def get_file_info(self):
-        return 'algebra'
+    DATASET_NAME = 'algebra'

 class MathCountingAndProbability(Math):
     VERSION = 1
-    def get_file_info(self):
-        return 'counting_and_probability'
+    DATASET_NAME = 'counting_and_probability'

 class MathGeometry(Math):
     VERSION = 1
-    def get_file_info(self):
-        return 'geometry'
+    DATASET_NAME = 'geometry'

 class MathIntermediateAlgebra(Math):
     VERSION = 1
-    def get_file_info(self):
-        return 'intermediate_algebra'
+    DATASET_NAME = 'intermediate_algebra'

 class MathNumberTheory(Math):
     VERSION = 1
-    def get_file_info(self):
-        return 'number_theory'
+    DATASET_NAME = 'number_theory'

 class MathPrealgebra(Math):
     VERSION = 1
-    def get_file_info(self):
-        return 'prealgebra'
+    DATASET_NAME = 'prealgebra'

 class MathPrecalculus(Math):
     VERSION = 1
-    def get_file_info(self):
-        return 'precalculus'
+    DATASET_NAME = 'precalculus'
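Note: _process_doc above relies on the task's remove_boxed and last_boxed_only_string helpers, which are defined elsewhere in the file and not shown in this diff. A rough, assumption-based sketch of what they do (not the harness's actual implementations):

def last_boxed_only_string(solution: str) -> str:
    # Find the last \boxed{...} span, tracking brace depth.
    start = solution.rfind("\\boxed{")
    if start < 0:
        return ""
    depth = 0
    for i in range(start, len(solution)):
        if solution[i] == "{":
            depth += 1
        elif solution[i] == "}":
            depth -= 1
            if depth == 0:
                return solution[start:i + 1]
    return ""


def remove_boxed(boxed: str) -> str:
    # Strip the \boxed{ ... } wrapper, leaving the answer itself.
    prefix = "\\boxed{"
    if boxed.startswith(prefix) and boxed.endswith("}"):
        return boxed[len(prefix):-1]
    return boxed


solution = "We factor to get $x=2$, so the answer is $\\boxed{2}$."
print(remove_boxed(last_boxed_only_string(solution)))  # -> 2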
diff --git a/lm_eval/tasks/hendrycks_test.py b/lm_eval/tasks/hendrycks_test.py
@@ -12,12 +12,7 @@ important shortcomings.
 Homepage: https://github.com/hendrycks/test
 """
-import csv
-import random
 from lm_eval.base import MultipleChoiceTask
-from ..utils import sh
-from pathlib import Path
-from best_download import download_file
 _CITATION = """
@@ -61,25 +56,15 @@ def create_task(subject):
 class GeneralHendrycksTest(MultipleChoiceTask):
     VERSION = 0
-    DATASET_PATH = Path("data/hendrycksTest/")
+    DATASET_PATH = "hendrycks_test"
+    DATASET_NAME = None

     def __init__(self, subject):
-        self.subject = subject
+        self.DATASET_NAME = subject
         super().__init__()

-    def download(self):
-        if not (self.DATASET_PATH / 'done').exists():
-            sh("mkdir -p data")
-            download_file("https://people.eecs.berkeley.edu/~hendrycks/data.tar", local_file="data/data.tar", expected_checksum="78a804365a59028188fb19bd1adcadc5e0c260b220a9d8b2e33a5ea7d5fbe3b4")
-            sh("""
-                tar -xf data/data.tar -C data/
-                rm data/data.tar
-                mv data/data data/hendrycksTest
-                touch data/hendrycksTest/done
-                """)
-
     def has_training_docs(self):
-        return True
+        return False

     def has_validation_docs(self):
         return True
@@ -87,8 +72,14 @@ class GeneralHendrycksTest(MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
-        def format_example(doc, choices):
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
+        def format_example(doc, keys):
             """
             Question: <prompt>
             Choices:
@@ -98,46 +89,31 @@ class GeneralHendrycksTest(MultipleChoiceTask):
             D. <choice4>
             Answer:
             """
-            prompt = "Question: " + doc[0] + "\nChoices:\n"
-            prompt += "".join([f"{choices[j]}. {doc[j+1]}\n" for j in range(4)])
+            prompt = "Question: " + doc["question"] + "\nChoices:\n"
+            prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
             prompt += "Answer:"
             return prompt

-        choices = ['A', 'B', 'C', 'D']
+        keys = ['A', 'B', 'C', 'D']
         return {
-            "query": format_example(doc, choices),
-            "choices": doc[1:5],
-            "gold": choices.index(doc[5])
+            "query": format_example(doc, keys),
+            "choices": doc["choices"],
+            "gold": keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"]
         }

-    def _load_docs(self, filename):
-        reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',')
-        return (self._convert_standard(doc) for doc in reader)
-
-    def training_docs(self):
-        docs = []
-        for train_dir in ["auxiliary_train", "dev"]:
-            for f in (self.DATASET_PATH / train_dir).iterdir():
-                docs.extend(self._load_docs(f))
-        return docs
-
-    def validation_docs(self):
-        filename = self.DATASET_PATH / "val" / f"{self.subject}_val.csv"
-        return self._load_docs(filename)
-
-    def test_docs(self):
-        filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
-        return self._load_docs(filename)
-
     def fewshot_examples(self, k, rnd):
         # fewshot_examples is not just sampling from train_docs because dev is
         # in the same distribution as val/test but auxiliary_train isn't
-        filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
         if self._fewshot_docs is None:
-            self._fewshot_docs = list(self._load_docs(filename))
+            self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
         return rnd.sample(list(self._fewshot_docs), k)

     def doc_to_text(self, doc):
         return doc["query"]

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["query"]
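Note: a toy run of the MMLU prompt builder above on an invented doc in the new HF-style schema ("question" / "choices" / "answer"):

def format_example(doc, keys):
    prompt = "Question: " + doc["question"] + "\nChoices:\n"
    prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
    prompt += "Answer:"
    return prompt


doc = {
    "question": "What is the capital of France?",
    "choices": ["Berlin", "Madrid", "Paris", "Rome"],
    "answer": 2,  # integer index; _process_doc also accepts a letter string
}
print(format_example(doc, ["A", "B", "C", "D"]))
# Question: What is the capital of France?
# Choices:
# A. Berlin
# B. Madrid
# C. Paris
# D. Rome
# Answer: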
diff --git a/lm_eval/tasks/lambada.py b/lm_eval/tasks/lambada.py
@@ -12,12 +12,10 @@ in the broader discourse.
 Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
 """
-import json
+import inspect
+import lm_eval.datasets.lambada.lambada
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean, perplexity
-from lm_eval.utils import sh
-from best_download import download_file
-import os
 _CITATION = """
@@ -34,19 +32,7 @@ _CITATION = """
 class LAMBADA(Task):
     VERSION = 0
-    def download(self):
-        sh("mkdir -p data/lambada")
-        try:
-            if not os.path.exists("data/lambada/lambada_test.jsonl"):
-                download_file(
-                    "http://eaidata.bmk.sh/data/lambada_test.jsonl",
-                    local_file="data/lambada/lambada_test.jsonl",
-                    expected_checksum="4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"
-                )
-        except:
-            # fallback - for some reason best_download doesnt work all the time here
-            sh("wget http://eaidata.bmk.sh/data/lambada_test.jsonl -O data/lambada/lambada_test.jsonl")
-            sh('echo "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226  data/lambada/lambada_test.jsonl" | sha256sum --check')
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.lambada.lambada)

     def has_training_docs(self):
         return False
@@ -61,9 +47,7 @@ class LAMBADA(Task):
         pass

     def validation_docs(self):
-        with open("data/lambada/lambada_test.jsonl") as fh:
-            for line in fh:
-                yield json.loads(line)
+        return self.dataset["validation"]

     def test_docs(self):
         pass
@@ -71,6 +55,12 @@ class LAMBADA(Task):
     def doc_to_text(self, doc):
         return doc['text'].rsplit(' ', 1)[0]

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['text']

     def doc_to_target(self, doc):
         return " " + doc['text'].rsplit(' ', 1)[1]
diff --git a/lm_eval/tasks/lambada_cloze.py b/lm_eval/tasks/lambada_cloze.py
@@ -13,12 +13,7 @@ in the broader discourse.
 Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
 """
-import json
-from lm_eval.base import Task, rf
-from lm_eval.metrics import mean, perplexity
-from lm_eval.utils import sh
 from lm_eval.tasks.lambada import LAMBADA
-from best_download import download_file
 _CITATION = """
@@ -35,8 +30,15 @@ _CITATION = """
 class LAMBADA_cloze(LAMBADA):
     VERSION = 0

     def doc_to_text(self, doc):
         return doc['text'].rsplit(' ', 1)[0] + " ____. ->"

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['text']

     def doc_to_target(self, doc):
         return " " + doc['text'].rsplit(' ', 1)[1]
diff --git a/lm_eval/tasks/lambada_multilingual.py b/lm_eval/tasks/lambada_multilingual.py
@@ -14,13 +14,6 @@ in the broader discourse.
 Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
 """
 from . import lambada
-from lm_eval.base import Task, rf
-from lm_eval.metrics import mean, perplexity
-from lm_eval.utils import sh
-from best_download import download_file
-import json
-from functools import partial
-import os
 _CITATION = """
@@ -35,68 +28,37 @@ _CITATION = """
 """
-LANGS = ["en", "fr", "de", "it", "es"]
-
-CHECKSUMS = {"en": "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226",
-             "fr": "941ec6a73dba7dc91c860bf493eb66a527cd430148827a4753a4535a046bf362",
-             "de": "51c6c1795894c46e88e4c104b5667f488efe79081fb34d746b82b8caa663865e",
-             "it": "86654237716702ab74f42855ae5a78455c1b0e50054a4593fb9c6fcf7fad0850",
-             "es": "ffd760026c647fb43c67ce1bc56fd527937304b348712dce33190ea6caba6f9c"
-             }
-
 class MultilingualLAMBADA(lambada.LAMBADA):
     VERSION = 0

-    def __init__(self, lang=None):
-        self.LANG = lang
-        super().__init__()
-
-    def download(self):
-        sh("mkdir -p data/lambada")
-        f = f"data/lambada/lambada_test_{self.LANG}.jsonl"
-        url = f"http://eaidata.bmk.sh/data/lambada_test_{self.LANG}.jsonl"
-        try:
-            if not os.path.exists(f):
-                download_file(
-                    url,
-                    local_file=f,
-                    expected_checksum=CHECKSUMS[self.LANG]
-                )
-        except:
-            # fallback - for some reason best_download doesnt work all the time here
-            sh(f"wget {url} -O {f}")
-            sh(f'echo "{CHECKSUMS[self.LANG]}  {f}" | sha256sum --check')
-
-    def validation_docs(self):
-        with open(f"data/lambada/lambada_test_{self.LANG}.jsonl") as fh:
-            for line in fh:
-                yield json.loads(line)
-
 class MultilingualLAMBADAEN(MultilingualLAMBADA):
-    def __init__(self):
-        super().__init__('en')
+    DATASET_NAME = 'en'

 class MultilingualLAMBADAFR(MultilingualLAMBADA):
-    def __init__(self):
-        super().__init__('fr')
+    DATASET_NAME = 'fr'

 class MultilingualLAMBADADE(MultilingualLAMBADA):
-    def __init__(self):
-        super().__init__('de')
+    DATASET_NAME = 'de'

 class MultilingualLAMBADAIT(MultilingualLAMBADA):
-    def __init__(self):
-        super().__init__('it')
+    DATASET_NAME = 'it'

 class MultilingualLAMBADAES(MultilingualLAMBADA):
-    def __init__(self):
-        super().__init__('es')
+    DATASET_NAME = 'es'

-LANG_CLASSES = [MultilingualLAMBADAEN, MultilingualLAMBADAFR,
-                MultilingualLAMBADADE, MultilingualLAMBADAIT,
-                MultilingualLAMBADAES]
+LANG_CLASSES = [MultilingualLAMBADAEN, MultilingualLAMBADAFR, MultilingualLAMBADADE, MultilingualLAMBADAIT, MultilingualLAMBADAES]

 def construct_tasks():
     tasks = {}
-    for lang, lang_class in zip(LANGS, LANG_CLASSES):
-        tasks[f"lambada_mt_{lang}"] = lang_class
+    for lang_class in LANG_CLASSES:
+        tasks[f"lambada_mt_{lang_class.DATASET_NAME}"] = lang_class
     return tasks
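Note: the registry change above derives task names from each class's DATASET_NAME instead of a parallel LANGS list, so the two can no longer drift apart. The same pattern, reduced to a standalone toy:

class Base:
    DATASET_NAME = None

class EN(Base):
    DATASET_NAME = "en"

class FR(Base):
    DATASET_NAME = "fr"

LANG_CLASSES = [EN, FR]

tasks = {f"lambada_mt_{c.DATASET_NAME}": c for c in LANG_CLASSES}
print(sorted(tasks))  # ['lambada_mt_en', 'lambada_mt_fr']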
diff --git a/lm_eval/tasks/logiqa.py b/lm_eval/tasks/logiqa.py
@@ -10,9 +10,9 @@ NLP setting.
 Homepage: https://github.com/lgw863/LogiQA-dataset
 """
+import inspect
+import lm_eval.datasets.logiqa.logiqa
 from lm_eval.base import MultipleChoiceTask
-from best_download import download_file
-from pathlib import Path
 _CITATION = """
@@ -29,21 +29,8 @@ _CITATION = """
 class LogiQA(MultipleChoiceTask):
     VERSION = 0
-    DATASET_PATH = Path("data/logiqa")
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.logiqa.logiqa)
+    DATASET_NAME = None

-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        base_url = "https://raw.githubusercontent.com/lgw863/LogiQA-dataset/master"
-        splits = [
-            {"name": "Train", "checksum": "7d5bb1f58278e33b395744cd2ad8d7600faa0b3c4d615c659a44ec1181d759fa"},
-            {"name": "Eval", "checksum": "4c49e6753b7262c001506b9151135abf722247035ab075dad93acdea5789c01f"},
-            {"name": "Test", "checksum": "359acb78c37802208f7fde9e2f6574b8526527c63d6a336f90a53f1932cb4701"}
-        ]
-        for split in splits:
-            file = self.DATASET_PATH / f"{split['name']}.txt"
-            download_file(f"{base_url}/{split['name']}.txt", local_file=str(file), expected_checksum=split["checksum"])
-
     def has_training_docs(self):
         return True
@@ -54,7 +41,18 @@ class LogiQA(MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
         def format_example(doc, choices):
             """
             Passage: <passage>
@@ -66,7 +64,7 @@ class LogiQA(MultipleChoiceTask):
             D. <choice4>
             Answer:
             """
-            prompt = "Passage: " + doc["passage"] + "\n"
+            prompt = "Passage: " + doc["context"] + "\n"
             prompt += "Question: " + doc["question"] + "\nChoices:\n"
             for choice, option in zip(choices, doc["options"]):
                 prompt += f"{choice.upper()}. {option}\n"
@@ -74,35 +72,17 @@ class LogiQA(MultipleChoiceTask):
             return prompt

         choices = ['a', 'b', 'c', 'd']
         return {
+            "passage": doc["context"],  # Used for decontamination
             "query": format_example(doc, choices),
             "choices": doc["options"],
-            "gold": choices.index(doc["answerKey"])
+            "gold": choices.index(doc["label"])
         }

-    def _load_docs(self, filename):
-        def normalize(text):
-            return text.replace(".", ". ").strip()
-        with open(filename, 'r') as f:
-            docs = f.read().strip().split("\n\n")
-            for rawdoc in docs:
-                rawdoc = rawdoc.split("\n")
-                doc = {
-                    "answerKey": rawdoc[0].strip(),
-                    "passage": normalize(rawdoc[1]),
-                    "question": normalize(rawdoc[2]),
-                    "options": [normalize(option[2:]) for option in rawdoc[3:]]
-                }
-                yield self._convert_standard(doc)
-
-    def training_docs(self):
-        return self._load_docs(self.DATASET_PATH / "Train.txt")
-
-    def validation_docs(self):
-        return self._load_docs(self.DATASET_PATH / "Eval.txt")
-
-    def test_docs(self):
-        return self._load_docs(self.DATASET_PATH / "Test.txt")
-
     def doc_to_text(self, doc):
         return doc["query"]

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["passage"]
diff --git a/lm_eval/tasks/mathqa.py b/lm_eval/tasks/mathqa.py
@@ -10,7 +10,6 @@ Homepage: https://math-qa.github.io/math-QA/
 """
 import re
 from lm_eval.base import MultipleChoiceTask
-from . common import HFTask
 _CITATION = """
@@ -25,7 +24,7 @@ _CITATION = """
 """
-class MathQA(HFTask, MultipleChoiceTask):
+class MathQA(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "math_qa"
     DATASET_NAME = None
@@ -39,13 +38,23 @@ class MathQA(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
         answer_idx = ['a', 'b', 'c', 'd', 'e'].index(doc['correct'])
         choices = [c[4:].rstrip(" ,") for c in re.findall(r"[abcd] \) .*?, |e \) .*?$", doc['options'])]
         out_doc = {
-            "query": "Question: " + doc['Problem'] +"\nAnswer:",
+            "query": "Question: " + doc['Problem'] + "\nAnswer:",
             "choices": choices,
             "gold": answer_idx,
         }
@@ -53,3 +62,9 @@ class MathQA(HFTask, MultipleChoiceTask):
     def doc_to_text(self, doc):
         return doc["query"]

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["query"]
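Note: a toy run of the MathQA option parser from the hunk above; the options string is invented but follows the dataset's "a ) ... , b ) ... , e ) ..." layout.

import re

options = "a ) 38 , b ) 27.675 , c ) 30 , d ) 28 , e ) none of these"
choices = [
    c[4:].rstrip(" ,")
    for c in re.findall(r"[abcd] \) .*?, |e \) .*?$", options)
]
print(choices)  # ['38', '27.675', '30', '28', 'none of these']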
diff --git a/lm_eval/tasks/mc_taco.py b/lm_eval/tasks/mc_taco.py
@@ -20,9 +20,8 @@ of a question's options. See section 4 of the paper for details.
 Homepage: https://leaderboard.allenai.org/mctaco/submissions/public
 """
 import numpy as np
-from lm_eval.base import rf
 from collections import defaultdict
-from . common import HFTask
+from lm_eval.base import rf, Task
 _CITATION = """
@@ -35,7 +34,7 @@ _CITATION = """
 """
-class MCTACO(HFTask):
+class MCTACO(Task):
     VERSION = 0
     DATASET_PATH = "mc_taco"
     DATASET_NAME = None
@@ -49,10 +48,22 @@ class MCTACO(HFTask):
     def has_test_docs(self):
         return True

+    def validation_docs(self):
+        return self.dataset["validation"]
+
+    def test_docs(self):
+        return self.dataset["test"]
+
     def doc_to_text(self, doc):
         return f"{doc['sentence']}\nQuestion: {doc['question']}\n"\
                f"Answer: {doc['answer']}\nPlausible:"

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['question'] + " " + doc['sentence']

     def doc_to_target(self, doc):
         return " " + ["no", "yes"][doc['label']]
diff --git a/lm_eval/tasks/mutual.py b/lm_eval/tasks/mutual.py
@@ -7,14 +7,11 @@ modified from Chinese high school English listening comprehension test data.
 Homepage: https://github.com/Nealcly/MuTual
 """
-import json
-import zipfile
-import shutil
 import numpy as np
-from pathlib import Path
+import inspect
+import lm_eval.datasets.mutual.mutual
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean
-from best_download import download_file
 _CITATION = """
@@ -30,29 +27,10 @@ _CITATION = """
 class MuTualBase(Task):
     VERSION = 1
-    BASE_PATH = Path("data/mutual")
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.mutual.mutual)
     DATASET_NAME = None
     CHOICES = ['A', 'B', 'C', 'D']

-    def __init__(self):
-        super().__init__()
-
-    def download(self):
-        if self.BASE_PATH.exists():
-            return
-        Path.mkdir(self.BASE_PATH, parents=True)
-        master_zip = Path("data/master.zip")
-        download_file(
-            "https://github.com/Nealcly/MuTual/archive/master.zip",
-            local_file=str(master_zip),
-            expected_checksum="bb325cf6c672f0f02699993a37138b0fa0af6fcfc77ec81dfbe46add4d7b29f9")
-        with zipfile.ZipFile(master_zip, 'r') as zip:
-            zip.extractall("data")
-        Path("data/MuTual-master/data").rename(str(self.BASE_PATH))
-        # Remove left over files and directories.
-        master_zip.unlink()
-        shutil.rmtree("data/MuTual-master")
-
     def has_training_docs(self):
         return True
@@ -62,18 +40,11 @@ class MuTualBase(Task):
     def has_test_docs(self):
         return False

-    def _load_docs(self, path):
-        for file in sorted(path.iterdir()):
-            if file.suffix != ".txt":
-                continue
-            with open(file, 'r', encoding='utf-8') as f:
-                yield json.load(f)
-
     def training_docs(self):
-        return self._load_docs(self.BASE_PATH / self.DATASET_NAME / "train")
+        return self.dataset["train"]

     def validation_docs(self):
-        return self._load_docs(self.BASE_PATH / self.DATASET_NAME / "dev")
+        return self.dataset["validation"]

     def test_docs(self):
         return NotImplemented
@@ -81,6 +52,12 @@ class MuTualBase(Task):
     def doc_to_text(self, doc):
         return self.detokenize(doc["article"])

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["article"]

     def doc_to_target(self, doc):
         return " " + self.detokenize(doc["options"][self.CHOICES.index(doc["answers"])])
@@ -134,8 +111,8 @@ class MuTualBase(Task):
 class MuTual(MuTualBase):
-    DATASET_NAME = Path("mutual")
+    DATASET_NAME = "mutual"

 class MuTualPlus(MuTualBase):
-    DATASET_NAME = Path("mutual_plus")
+    DATASET_NAME = "mutual_plus"
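Note: a toy run of MuTual's answer selection, where doc["answers"] is a letter indexing into doc["options"]. The doc content is invented, and detokenize (defined elsewhere in the task, not shown in this diff) is stubbed as the identity function:

CHOICES = ["A", "B", "C", "D"]

def detokenize(text):
    return text  # stand-in for MuTualBase.detokenize

doc = {
    "answers": "C",
    "options": ["m : no .", "f : maybe .", "m : yes , of course .", "f : never ."],
}
target = " " + detokenize(doc["options"][CHOICES.index(doc["answers"])])
print(target)  # " m : yes , of course ."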
diff --git a/lm_eval/tasks/naturalqs.py b/lm_eval/tasks/naturalqs.py
@@ -15,8 +15,7 @@ not even bother with the train set.
 Homepage: https://ai.google.com/research/NaturalQuestions
 """
-import random
-from . common import HFTask
+from lm_eval.base import Task
 from itertools import islice
@@ -30,7 +29,7 @@ _CITATION = """
 """
-class NaturalQs(HFTask):
+class NaturalQs(Task):
     VERSION = 0
     DATASET_PATH = "natural_questions"
     DATASET_NAME = None
@@ -47,7 +46,12 @@ class NaturalQs(HFTask):
     def training_docs(self):
         # Cache training for faster few-shot.
         # Data is too large to fit in memory.
-        return self.data["train"]
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]

     def fewshot_examples(self, k, rnd):
         # Data is too large to fit in memory. We just sample from the first bit.
@@ -59,6 +63,12 @@ class NaturalQs(HFTask):
     def doc_to_text(self, doc):
         return 'Q: ' + doc['question']['text'] + '\n\n' + 'A:'

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['question']['text']

     def doc_to_target(self, doc):
         # There's a short answer and a long answer. Based on the paper, I'm using the long answer.
         short_answer = doc['annotations']['short_answers'][0]['text']
diff --git a/lm_eval/tasks/openbookqa.py b/lm_eval/tasks/openbookqa.py
@@ -15,7 +15,6 @@ based algorithm and a word co-occurrence algorithm.
 Homepage: https://allenai.org/data/open-book-qa
 """
 from lm_eval.base import MultipleChoiceTask
-from .common import HFTask
 _CITATION = """
@@ -28,7 +27,7 @@ _CITATION = """
 """
-class OpenBookQA(HFTask, MultipleChoiceTask):
+class OpenBookQA(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "openbookqa"
     DATASET_NAME = "main"
@@ -42,7 +41,18 @@ class OpenBookQA(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
         out_doc = {
             "id": doc["id"],
             "query": doc["question_stem"],
@@ -53,3 +63,9 @@ class OpenBookQA(HFTask, MultipleChoiceTask):
     def doc_to_text(self, doc):
         return doc["query"]

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["query"]
diff --git a/lm_eval/tasks/pile.py b/lm_eval/tasks/pile.py
@@ -10,15 +10,9 @@ math, computer science, and philosophy papers.
 Homepage: https://pile.eleuther.ai/
 """
-import os
-import lm_dataformat
-import abc
-import numpy as np
-from lm_eval.base import rf, PerplexityTask
-from ..metrics import mean, matthews_corrcoef, f1_score
-from ..utils import general_detokenize
-from best_download import download_file
+import inspect
+import lm_eval.datasets.pile.pile
+from lm_eval.base import PerplexityTask
 _CITATION = """
@@ -31,32 +25,10 @@ _CITATION = """
 """
-class PilePerplexityTask(PerplexityTask, abc.ABC):
+class PilePerplexityTask(PerplexityTask):
     VERSION = 1
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.pile.pile)
+    DATASET_NAME = None

-    PILE_SET_NAME = None
-    VAL_PATH = 'data/pile/val.jsonl.zst'
-    TEST_PATH = 'data/pile/test.jsonl.zst'
-
-    def download(self):
-        # TODO: separate pile val/test out by component so we don't have to scan the entire file once per set
-        if not os.path.exists("data/pile/test.jsonl.zst"):
-            # todo use new best_download fallback api
-            os.makedirs("data/pile/", exist_ok=True)
-            download_file("http://eaidata.bmk.sh/data/pile/val.jsonl.zst", local_file=self.VAL_PATH, expected_checksum="264c875d8bbd355d8daa9d032b75fd8fb91606218bb84dd1155b203fcd5fab92")
-            download_file("http://eaidata.bmk.sh/data/pile/test.jsonl.zst", local_file=self.TEST_PATH, expected_checksum="0bb28c52d0b5596d389bf179ce2d43bf7f7ffae76b0d2d20b180c97f62e0975e")
-
-    def validation_docs(self):
-        rdr = lm_dataformat.Reader(self.VAL_PATH)
-        for doc, metadata in rdr.stream_data(get_meta=True):
-            if metadata["pile_set_name"] == self.PILE_SET_NAME:
-                yield doc
-
-    def test_docs(self):
-        rdr = lm_dataformat.Reader(self.TEST_PATH)
-        for doc, metadata in rdr.stream_data(get_meta=True):
-            if metadata["pile_set_name"] == self.PILE_SET_NAME:
-                yield doc
-
     def has_validation_docs(self):
         return True
@@ -64,90 +36,98 @@ class PilePerplexityTask(PerplexityTask, abc.ABC):
     def has_test_docs(self):
         return True

+    def validation_docs(self):
+        for doc in self.dataset["validation"]:
+            yield doc["text"]
+
+    def test_docs(self):
+        for doc in self.dataset["test"]:
+            yield doc["text"]

 class PileArxiv(PilePerplexityTask):
-    PILE_SET_NAME = "ArXiv"
+    DATASET_NAME = "pile_arxiv"

 class PileBooks3(PilePerplexityTask):
-    PILE_SET_NAME = "Books3"
+    DATASET_NAME = "pile_books3"

 class PileBookCorpus2(PilePerplexityTask):
-    PILE_SET_NAME = "BookCorpus2"
+    DATASET_NAME = "pile_bookcorpus2"

 class PileDmMathematics(PilePerplexityTask):
-    PILE_SET_NAME = "DM Mathematics"
+    DATASET_NAME = "pile_dm-mathematics"

 class PileEnron(PilePerplexityTask):
-    PILE_SET_NAME = "Enron Emails"
+    DATASET_NAME = "pile_enron"

 class PileEuroparl(PilePerplexityTask):
-    PILE_SET_NAME = "EuroParl"
+    DATASET_NAME = "pile_europarl"

 class PileFreeLaw(PilePerplexityTask):
-    PILE_SET_NAME = "FreeLaw"
+    DATASET_NAME = "pile_freelaw"

 class PileGithub(PilePerplexityTask):
-    PILE_SET_NAME = "Github"
+    DATASET_NAME = "pile_github"

 class PileGutenberg(PilePerplexityTask):
-    PILE_SET_NAME = "Gutenberg (PG-19)"
+    DATASET_NAME = "pile_gutenberg"

 class PileHackernews(PilePerplexityTask):
-    PILE_SET_NAME = "HackerNews"
+    DATASET_NAME = "pile_hackernews"

 class PileNIHExporter(PilePerplexityTask):
-    PILE_SET_NAME = "NIH ExPorter"
+    DATASET_NAME = "pile_nih-exporter"

 class PileOpenSubtitles(PilePerplexityTask):
-    PILE_SET_NAME = "OpenSubtitles"
+    DATASET_NAME = "pile_opensubtitles"

 class PileOpenWebText2(PilePerplexityTask):
-    PILE_SET_NAME = "OpenWebText2"
+    DATASET_NAME = "pile_openwebtext2"

 class PilePhilPapers(PilePerplexityTask):
-    PILE_SET_NAME = "PhilPapers"
+    DATASET_NAME = "pile_philpapers"

 class PilePileCc(PilePerplexityTask):
-    PILE_SET_NAME = "Pile-CC"
+    DATASET_NAME = "pile_pile-cc"

 class PilePubmedAbstracts(PilePerplexityTask):
-    PILE_SET_NAME = "PubMed Abstracts"
+    DATASET_NAME = "pile_pubmed-abstracts"

 class PilePubmedCentral(PilePerplexityTask):
-    PILE_SET_NAME = "PubMed Central"
+    DATASET_NAME = "pile_pubmed-central"

 class PileStackExchange(PilePerplexityTask):
-    PILE_SET_NAME = "StackExchange"
+    DATASET_NAME = "pile_stackexchange"

 class PileUspto(PilePerplexityTask):
-    PILE_SET_NAME = "USPTO Backgrounds"
+    DATASET_NAME = "pile_upsto"

 class PileUbuntuIrc(PilePerplexityTask):
-    PILE_SET_NAME = "Ubuntu IRC"
+    DATASET_NAME = "pile_ubuntu-irc"

 class PileWikipedia(PilePerplexityTask):
-    PILE_SET_NAME = "Wikipedia (en)"
+    DATASET_NAME = "pile_wikipedia"

 class PileYoutubeSubtitles(PilePerplexityTask):
-    PILE_SET_NAME = "YoutubeSubtitles"
+    DATASET_NAME = "pile_youtubesubtitles"
diff --git a/lm_eval/tasks/piqa.py b/lm_eval/tasks/piqa.py
@@ -9,10 +9,7 @@ actually learning about the world?
 Homepage: https://yonatanbisk.com/piqa/
 """
-import numpy as np
-from lm_eval.base import MultipleChoiceTask, rf
-from ..metrics import mean
-from . common import HFTask
+from lm_eval.base import MultipleChoiceTask
 _CITATION = """
@@ -29,7 +26,7 @@ _CITATION = """
 """
-class PiQA(HFTask, MultipleChoiceTask):
+class PiQA(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "piqa"
     DATASET_NAME = None
@@ -43,7 +40,15 @@ class PiQA(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return False

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def _process_doc(self, doc):
         out_doc = {
             "goal": doc["goal"],
             "choices": [doc["sol1"], doc["sol2"]],
@@ -53,3 +58,9 @@ class PiQA(HFTask, MultipleChoiceTask):
     def doc_to_text(self, doc):
         return "Question: " + doc["goal"] + "\nAnswer:"

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["goal"]
...@@ -15,7 +15,6 @@ have been trained on data not specifically collected to succeed on PROST." ...@@ -15,7 +15,6 @@ have been trained on data not specifically collected to succeed on PROST."
Homepage: https://github.com/nala-cub/prost Homepage: https://github.com/nala-cub/prost
""" """
from lm_eval.base import MultipleChoiceTask from lm_eval.base import MultipleChoiceTask
from . common import HFTask
_CITATION = """ _CITATION = """
...@@ -36,7 +35,7 @@ _CITATION = """ ...@@ -36,7 +35,7 @@ _CITATION = """
""" """
class PROST(HFTask, MultipleChoiceTask): class PROST(MultipleChoiceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "corypaik/prost" DATASET_PATH = "corypaik/prost"
DATASET_NAME = None DATASET_NAME = None
...@@ -50,6 +49,9 @@ class PROST(HFTask, MultipleChoiceTask): ...@@ -50,6 +49,9 @@ class PROST(HFTask, MultipleChoiceTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.' assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
return super().fewshot_context( return super().fewshot_context(
...@@ -59,7 +61,7 @@ class PROST(HFTask, MultipleChoiceTask): ...@@ -59,7 +61,7 @@ class PROST(HFTask, MultipleChoiceTask):
description=description description=description
) )
def _convert_standard(self, doc): def _process_doc(self, doc):
out_doc = { out_doc = {
"query": f"{doc['context']}\nQuestion: {doc['ex_question']}\nAnswer:", "query": f"{doc['context']}\nQuestion: {doc['ex_question']}\nAnswer:",
"choices": [doc['A'], doc['B'], doc['C'], doc['D']], "choices": [doc['A'], doc['B'], doc['C'], doc['D']],
...@@ -69,3 +71,9 @@ class PROST(HFTask, MultipleChoiceTask): ...@@ -69,3 +71,9 @@ class PROST(HFTask, MultipleChoiceTask):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["query"] return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
...@@ -16,9 +16,8 @@ and (4) a yes/no/maybe answer which summarizes the conclusion. ...@@ -16,9 +16,8 @@ and (4) a yes/no/maybe answer which summarizes the conclusion.
Homepage: https://pubmedqa.github.io/ Homepage: https://pubmedqa.github.io/
""" """
import numpy as np import numpy as np
from .common import HFTask from lm_eval.base import rf, Task
from lm_eval.base import rf from lm_eval.metrics import mean
from ..metrics import mean
_CITATION = """ _CITATION = """
...@@ -32,7 +31,7 @@ _CITATION = """ ...@@ -32,7 +31,7 @@ _CITATION = """
""" """
class Pubmed_QA(HFTask): class Pubmed_QA(Task):
VERSION = 0 VERSION = 0
DATASET_PATH = "pubmed_qa" DATASET_PATH = "pubmed_qa"
DATASET_NAME = "pqa_labeled" DATASET_NAME = "pqa_labeled"
...@@ -49,7 +48,7 @@ class Pubmed_QA(HFTask): ...@@ -49,7 +48,7 @@ class Pubmed_QA(HFTask):
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
# HF is labelled as train but it's really just for testing # HF is labelled as train but it's really just for testing
return self.data["train"] return self.dataset["train"]
def doc_to_text(self, doc): def doc_to_text(self, doc):
ctxs = "\n".join(doc["context"]["contexts"]) ctxs = "\n".join(doc["context"]["contexts"])
...@@ -59,6 +58,12 @@ class Pubmed_QA(HFTask): ...@@ -59,6 +58,12 @@ class Pubmed_QA(HFTask):
doc["final_decision"] doc["final_decision"]
) )
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"] + " " + "\n".join(doc["context"]["contexts"])
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " {}".format(doc["final_decision"]) return " {}".format(doc["final_decision"])
......
...@@ -13,9 +13,6 @@ and Entrance Exam. ...@@ -13,9 +13,6 @@ and Entrance Exam.
Homepage: http://nlp.uned.es/clef-qa/repository/qa4mre.php Homepage: http://nlp.uned.es/clef-qa/repository/qa4mre.php
""" """
import os
import xml.etree.ElementTree as ET
from best_download import download_file
from lm_eval.base import MultipleChoiceTask from lm_eval.base import MultipleChoiceTask
...@@ -31,35 +28,8 @@ _CITATION = """ ...@@ -31,35 +28,8 @@ _CITATION = """
class QA4MRE(MultipleChoiceTask): class QA4MRE(MultipleChoiceTask):
VERSION = 0 VERSION = 0
YEAR = None DATASET_PATH = "qa4mre"
def download(self): DATASET_NAME = None
year = self.YEAR
lang = "EN"
base_path = (
"http://nlp.uned.es/clef-qa/repository/js/scripts/downloadFile.php?"
"file=/var/www/html/nlp/clef-qa/repository/resources/QA4MRE/"
)
# TODO: add side tasks?
variable_year_path = {
2011: '2011/Training_Data/Goldstandard/',
2012: '2012/Main_Task/Training_Data/Goldstandard/Used_in_Evaluation/',
2013: '2013/Main_Task/Training_Data/Goldstandard/'
}
sha256sums = {
2011 : "6d2524952a3a015f2a82df785b85b5578681e3602ec276b4e72c01f4ebc50034",
2012 : "f9edaf408f8ac93f89a643a0d0b19263a1bb5ce64f19b2af10df279a656dfb24",
2013 : "c60e5aa4ec77e0493ef0b11d46bd1d74d58a499a3a2f871b8cf3af9536f0f094",
}
vpath = variable_year_path[year]
url_path = f"{base_path}{vpath}QA4MRE-{year}-{lang}_GS.xml"
if not os.path.exists("data/qa4mre"):
os.makedirs("data/qa4mre", exist_ok=True)
if not os.path.isfile(f"data/qa4mre/QA4MRE-{year}-{lang}"):
download_file(
url_path,
local_file=f"data/qa4mre/QA4MRE-{year}-{lang}_GS.xml",
expected_checksum=sha256sums[year],
)
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -70,39 +40,37 @@ class QA4MRE(MultipleChoiceTask): ...@@ -70,39 +40,37 @@ class QA4MRE(MultipleChoiceTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def _convert_standard(self, question): def test_docs(self):
choices = [i.text for i in question.iter('answer')] # `qa4mre` only has train data so we use it for the test docs.
return map(self._process_doc, self.dataset["train"])
def _process_doc(self, doc):
choices = doc["answer_options"]["answer_str"]
out_doc = { out_doc = {
"query" : question.find('q_str').text, "source": doc["document_str"].strip().replace("\'", "'"),
"choices": choices, "query": doc["question_str"],
"gold" : int(question.find("./answer[@correct='Yes']").attrib["a_id"]) - 1, "choices": choices,
"gold": int(doc["correct_answer_id"]) - 1,
} }
return out_doc return out_doc
def load_docs(self, textfilename, tfds=False):
tree = ET.parse(textfilename)
root = tree.getroot()
# TODO: the source passage is sometimes much larger than the model's context window
# at the moment, it just gets left-truncated by the LM automatically, and maybe that's good enough?
for reading_test in root.iter('reading-test'):
src = reading_test[0].text
src = src.strip().replace("\'", "'")
for qid, question in enumerate(reading_test.iter('q')):
out_doc = self._convert_standard(question)
out_doc['source'] = src
yield out_doc
def test_docs(self):
return self.load_docs(f"data/qa4mre/QA4MRE-{self.YEAR}-EN_GS.xml")
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]) return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["source"] + " " + doc["query"]
class QA4MRE_2011(QA4MRE): class QA4MRE_2011(QA4MRE):
YEAR = 2011 DATASET_NAME = "2011.main.EN"
class QA4MRE_2012(QA4MRE): class QA4MRE_2012(QA4MRE):
YEAR = 2012 DATASET_NAME = "2012.main.EN"
class QA4MRE_2013(QA4MRE): class QA4MRE_2013(QA4MRE):
YEAR = 2013 DATASET_NAME = "2013.main.EN"
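One subtlety worth pinning down: the HF `correct_answer_id` field is 1-based, while `MultipleChoiceTask` scores against a 0-based `gold` index, hence the `- 1` in `_process_doc`. A toy record (values invented) makes the conversion concrete:

# Illustrative only: 1-based correct_answer_id -> 0-based gold index.
doc = {
    "document_str": "Some passage text.",
    "question_str": "What does the passage state?",
    "answer_options": {"answer_str": ["A claim.", "Another claim.", "A third claim."]},
    "correct_answer_id": "2",
}
gold = int(doc["correct_answer_id"]) - 1
assert doc["answer_options"]["answer_str"][gold] == "Another claim."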
...@@ -11,13 +11,10 @@ provide supporting evidence to answers. ...@@ -11,13 +11,10 @@ provide supporting evidence to answers.
Homepage: https://allenai.org/data/qasper Homepage: https://allenai.org/data/qasper
""" """
from collections import Counter from collections import Counter
from math import exp
import random
import re import re
import string import string
from lm_eval.base import rf from lm_eval.base import rf, Task
from lm_eval.metrics import f1_score, mean from lm_eval.metrics import f1_score, mean
from .common import HFTask
_CITATION = """ _CITATION = """
...@@ -104,11 +101,20 @@ def token_f1_score(prediction, ground_truth): ...@@ -104,11 +101,20 @@ def token_f1_score(prediction, ground_truth):
return f1 return f1
class QASPER(HFTask): class QASPER(Task):
VERSION = 0 VERSION = 0
DATASET_PATH = "qasper" DATASET_PATH = "qasper"
DATASET_NAME = None DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def doc_to_text(self, doc): def doc_to_text(self, doc):
return ( return (
"TITLE: " "TITLE: "
...@@ -130,14 +136,14 @@ class QASPER(HFTask): ...@@ -130,14 +136,14 @@ class QASPER(HFTask):
return " " + answer return " " + answer
def training_docs(self): def training_docs(self):
for doc in self.data["train"]: for doc in self.dataset["train"]:
yield from self.process_doc(doc) yield from self._process_doc(doc)
def validation_docs(self): def validation_docs(self):
for doc in self.data["train"]: for doc in self.dataset["validation"]:
yield from self.process_doc(doc) yield from self._process_doc(doc)
def process_doc(self, doc): def _process_doc(self, doc):
"""Given a `doc`, flatten it out so that each JSON blob """Given a `doc`, flatten it out so that each JSON blob
contains exactly one question and one answer. Logic taken from contains exactly one question and one answer. Logic taken from
the reference implementation available at the reference implementation available at
......
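The body of `_process_doc` is truncated in this hunk, but its docstring describes the shape of the transform: one flattened blob per (question, answer) pair. A toy sketch of that flattening, with a deliberately simplified record layout — the real QASPER schema is more nested than this:

# Toy sketch only: flatten one paper into per-question docs, in the spirit
# of _process_doc; field names here are simplified, not the real schema.
paper = {
    "title": "A Paper",
    "abstract": "An abstract.",
    "qas": {
        "question": ["What data was used?", "What is the main result?"],
        "answers": [{"answer": "Dataset D."}, {"answer": "Result R."}],
    },
}
flat_docs = [
    {"title": paper["title"], "abstract": paper["abstract"], "question": q, "answer": a["answer"]}
    for q, a in zip(paper["qas"]["question"], paper["qas"]["answers"])
]
assert len(flat_docs) == 2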