Unverified Commit 11f614b0 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge branch 'master' into task_doc

parents 0a6a9b7e e00d682f
......@@ -8,7 +8,8 @@ even for highly specialized humans.
Homepage: https://aghie.github.io/head-qa/
"""
from . common import HFTask
import inspect
import lm_eval.datasets.headqa.headqa
from lm_eval.base import MultipleChoiceTask
......@@ -24,9 +25,9 @@ _CITATION = """
"""
class HeadQABase(HFTask, MultipleChoiceTask):
class HeadQABase(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "head_qa"
DATASET_PATH = inspect.getfile(lm_eval.datasets.headqa.headqa)
def has_training_docs(self):
return True
......@@ -37,7 +38,18 @@ class HeadQABase(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return True
def _convert_standard(self, doc):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
out_doc = {
"id": doc["qid"],
"query": "Question: " + doc["qtext"] + "\nAnswer:",
......@@ -49,16 +61,25 @@ class HeadQABase(HFTask, MultipleChoiceTask):
def doc_to_text(self, doc):
return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
class HeadQAEn(HeadQABase):
DATASET_NAME = "en"
class HeadQAEs(HeadQABase):
DATASET_NAME = "es"
# for backwards compatibility
class HeadQAEsDeprecated(HeadQABase):
DATASET_NAME = "es"
def __init__(self):
super().__init__()
print("WARNING: headqa is deprecated. Please use headqa_es or headqa_en instead. See https://github.com/EleutherAI/lm-evaluation-harness/pull/240 for more info.")
\ No newline at end of file
print("WARNING: headqa is deprecated. Please use headqa_es or headqa_en instead. See https://github.com/EleutherAI/lm-evaluation-harness/pull/240 for more info.")
......@@ -15,7 +15,6 @@ Homepage: https://rowanzellers.com/hellaswag/
"""
import re
from lm_eval.base import MultipleChoiceTask
from . common import HFTask
_CITATION = """
......@@ -28,7 +27,7 @@ _CITATION = """
"""
class HellaSwag(HFTask, MultipleChoiceTask):
class HellaSwag(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "hellaswag"
DATASET_NAME = None
......@@ -42,16 +41,15 @@ class HellaSwag(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return False
@classmethod
def preprocess(cls, text):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub('\\[.*?\\]', '', text)
text = text.replace(" ", " ")
return text
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
out_doc = {
"query": self.preprocess(doc['activity_label'] + ': ' + ctx),
......@@ -60,5 +58,20 @@ class HellaSwag(HFTask, MultipleChoiceTask):
}
return out_doc
@classmethod
def preprocess(cls, text):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub('\\[.*?\\]', '', text)
text = text.replace(" ", " ")
return text
def doc_to_text(self, doc):
return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
This diff is collapsed.
......@@ -8,13 +8,10 @@ models to generate answer derivations and explanations.
Homepage: https://github.com/hendrycks/math
"""
import abc
import json
from lm_eval.utils import sh
import inspect
import lm_eval.datasets.hendrycks_math.hendrycks_math
from lm_eval.metrics import mean
from lm_eval.base import Task, rf
from pathlib import Path
from best_download import download_file
_CITATION = """
......@@ -28,21 +25,8 @@ _CITATION = """
class Math(Task):
DATASET_PATH = Path('data/MATH')
def download(self):
if not (self.DATASET_PATH / 'test').exists() or not (self.DATASET_PATH / 'done').exists():
sh(f"mkdir -p {self.DATASET_PATH}")
download_file("https://people.eecs.berkeley.edu/~hendrycks/MATH.tar", local_file=f"{self.DATASET_PATH}.tar", expected_checksum="0fbe4fad0df66942db6c221cdcc95b298cc7f4595a2f0f518360cce84e90d9ac")
sh(f"""
tar -xf {self.DATASET_PATH}.tar -C data/ && touch {self.DATASET_PATH / 'done'}
rm {self.DATASET_PATH}.tar
""")
@abc.abstractmethod
def get_file_info(self):
"""returns directory name"""
pass
DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -53,28 +37,31 @@ class Math(Task):
def has_test_docs(self):
return True
def _load_docs(self, path):
for file in sorted(path.iterdir()):
with open(file) as f:
doc = json.load(f)
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
yield doc
def training_docs(self):
return self._load_docs(self.DATASET_PATH / "train" / self.get_file_info())
return map(self._process_doc, self.dataset["train"])
def validation_docs(self):
return NotImplemented
def test_docs(self):
return self._load_docs(self.DATASET_PATH / "test" / self.get_file_info())
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
return doc
def doc_to_text(self, doc):
return "Problem: " + doc["problem"] + "\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["problem"]
def doc_to_target(self, doc):
return " " + doc["answer"]
return " " + doc["solution"]
def construct_requests(self, doc, ctx):
return rf.greedy_until(ctx, ["\n"])
......@@ -301,41 +288,34 @@ class Math(Task):
class MathAlgebra(Math):
VERSION = 1
def get_file_info(self):
return 'algebra'
DATASET_NAME = 'algebra'
class MathCountingAndProbability(Math):
VERSION = 1
def get_file_info(self):
return 'counting_and_probability'
DATASET_NAME = 'counting_and_probability'
class MathGeometry(Math):
VERSION = 1
def get_file_info(self):
return 'geometry'
DATASET_NAME = 'geometry'
class MathIntermediateAlgebra(Math):
VERSION = 1
def get_file_info(self):
return 'intermediate_algebra'
DATASET_NAME = 'intermediate_algebra'
class MathNumberTheory(Math):
VERSION = 1
def get_file_info(self):
return 'number_theory'
DATASET_NAME = 'number_theory'
class MathPrealgebra(Math):
VERSION = 1
def get_file_info(self):
return 'prealgebra'
DATASET_NAME = 'prealgebra'
class MathPrecalculus(Math):
VERSION = 1
def get_file_info(self):
return 'precalculus'
DATASET_NAME = 'precalculus'
......@@ -12,12 +12,7 @@ important shortcomings.
Homepage: https://github.com/hendrycks/test
"""
import csv
import random
from lm_eval.base import MultipleChoiceTask
from ..utils import sh
from pathlib import Path
from best_download import download_file
_CITATION = """
......@@ -61,25 +56,15 @@ def create_task(subject):
class GeneralHendrycksTest(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = Path("data/hendrycksTest/")
DATASET_PATH = "hendrycks_test"
DATASET_NAME = None
def __init__(self, subject):
self.subject = subject
self.DATASET_NAME = subject
super().__init__()
def download(self):
if not (self.DATASET_PATH / 'done').exists():
sh("mkdir -p data")
download_file("https://people.eecs.berkeley.edu/~hendrycks/data.tar", local_file="data/data.tar", expected_checksum="78a804365a59028188fb19bd1adcadc5e0c260b220a9d8b2e33a5ea7d5fbe3b4")
sh("""
tar -xf data/data.tar -C data/
rm data/data.tar
mv data/data data/hendrycksTest
touch data/hendrycksTest/done
""")
def has_training_docs(self):
return True
return False
def has_validation_docs(self):
return True
......@@ -87,8 +72,14 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def has_test_docs(self):
return True
def _convert_standard(self, doc):
def format_example(doc, choices):
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
def format_example(doc, keys):
"""
Question: <prompt>
Choices:
......@@ -98,46 +89,31 @@ class GeneralHendrycksTest(MultipleChoiceTask):
D. <choice4>
Answer:
"""
prompt = "Question: " + doc[0] + "\nChoices:\n"
prompt += "".join([f"{choices[j]}. {doc[j+1]}\n" for j in range(4)])
prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
prompt += "Answer:"
return prompt
choices = ['A', 'B', 'C', 'D']
keys = ['A', 'B', 'C', 'D']
return {
"query": format_example(doc, choices),
"choices": doc[1:5],
"gold": choices.index(doc[5])
"query": format_example(doc, keys),
"choices": doc["choices"],
"gold": keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"]
}
def _load_docs(self, filename):
reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',')
return (self._convert_standard(doc) for doc in reader)
def training_docs(self):
docs = []
for train_dir in ["auxiliary_train", "dev"]:
for f in (self.DATASET_PATH / train_dir).iterdir():
docs.extend(self._load_docs(f))
return docs
def validation_docs(self):
filename = self.DATASET_PATH / "val" / f"{self.subject}_val.csv"
return self._load_docs(filename)
def test_docs(self):
filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
return self._load_docs(filename)
def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't
filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
if self._fewshot_docs is None:
self._fewshot_docs = list(self._load_docs(filename))
self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
return rnd.sample(list(self._fewshot_docs), k)
def doc_to_text(self, doc):
return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
This diff is collapsed.
......@@ -13,12 +13,7 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
import json
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from lm_eval.tasks.lambada import LAMBADA
from best_download import download_file
_CITATION = """
......@@ -35,8 +30,15 @@ _CITATION = """
class LAMBADA_cloze(LAMBADA):
VERSION = 0
def doc_to_text(self, doc):
return doc['text'].rsplit(' ', 1)[0] + " ____. ->"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc['text']
def doc_to_target(self, doc):
return " " + doc['text'].rsplit(' ', 1)[1]
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment