Unverified Commit 11f614b0 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge branch 'master' into task_doc

parents 0a6a9b7e e00d682f
...@@ -8,7 +8,8 @@ even for highly specialized humans. ...@@ -8,7 +8,8 @@ even for highly specialized humans.
Homepage: https://aghie.github.io/head-qa/ Homepage: https://aghie.github.io/head-qa/
""" """
from . common import HFTask import inspect
import lm_eval.datasets.headqa.headqa
from lm_eval.base import MultipleChoiceTask from lm_eval.base import MultipleChoiceTask
...@@ -24,9 +25,9 @@ _CITATION = """ ...@@ -24,9 +25,9 @@ _CITATION = """
""" """
class HeadQABase(HFTask, MultipleChoiceTask): class HeadQABase(MultipleChoiceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "head_qa" DATASET_PATH = inspect.getfile(lm_eval.datasets.headqa.headqa)
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -37,7 +38,18 @@ class HeadQABase(HFTask, MultipleChoiceTask): ...@@ -37,7 +38,18 @@ class HeadQABase(HFTask, MultipleChoiceTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def _convert_standard(self, doc): def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
out_doc = { out_doc = {
"id": doc["qid"], "id": doc["qid"],
"query": "Question: " + doc["qtext"] + "\nAnswer:", "query": "Question: " + doc["qtext"] + "\nAnswer:",
...@@ -49,16 +61,25 @@ class HeadQABase(HFTask, MultipleChoiceTask): ...@@ -49,16 +61,25 @@ class HeadQABase(HFTask, MultipleChoiceTask):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["query"] return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
class HeadQAEn(HeadQABase): class HeadQAEn(HeadQABase):
DATASET_NAME = "en" DATASET_NAME = "en"
class HeadQAEs(HeadQABase): class HeadQAEs(HeadQABase):
DATASET_NAME = "es" DATASET_NAME = "es"
# for backwards compatibility # for backwards compatibility
class HeadQAEsDeprecated(HeadQABase): class HeadQAEsDeprecated(HeadQABase):
DATASET_NAME = "es" DATASET_NAME = "es"
def __init__(self): def __init__(self):
super().__init__() super().__init__()
print("WARNING: headqa is deprecated. Please use headqa_es or headqa_en instead. See https://github.com/EleutherAI/lm-evaluation-harness/pull/240 for more info.") print("WARNING: headqa is deprecated. Please use headqa_es or headqa_en instead. See https://github.com/EleutherAI/lm-evaluation-harness/pull/240 for more info.")
\ No newline at end of file
...@@ -15,7 +15,6 @@ Homepage: https://rowanzellers.com/hellaswag/ ...@@ -15,7 +15,6 @@ Homepage: https://rowanzellers.com/hellaswag/
""" """
import re import re
from lm_eval.base import MultipleChoiceTask from lm_eval.base import MultipleChoiceTask
from . common import HFTask
_CITATION = """ _CITATION = """
...@@ -28,7 +27,7 @@ _CITATION = """ ...@@ -28,7 +27,7 @@ _CITATION = """
""" """
class HellaSwag(HFTask, MultipleChoiceTask): class HellaSwag(MultipleChoiceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "hellaswag" DATASET_PATH = "hellaswag"
DATASET_NAME = None DATASET_NAME = None
...@@ -42,16 +41,15 @@ class HellaSwag(HFTask, MultipleChoiceTask): ...@@ -42,16 +41,15 @@ class HellaSwag(HFTask, MultipleChoiceTask):
def has_test_docs(self): def has_test_docs(self):
return False return False
@classmethod def training_docs(self):
def preprocess(cls, text): if self._training_docs is None:
text = text.strip() self._training_docs = list(map(self._process_doc, self.dataset["train"]))
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. return self._training_docs
text = text.replace(" [title]", ". ")
text = re.sub('\\[.*?\\]', '', text) def validation_docs(self):
text = text.replace(" ", " ") return map(self._process_doc, self.dataset["validation"])
return text
def _convert_standard(self, doc): def _process_doc(self, doc):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
out_doc = { out_doc = {
"query": self.preprocess(doc['activity_label'] + ': ' + ctx), "query": self.preprocess(doc['activity_label'] + ': ' + ctx),
...@@ -60,5 +58,20 @@ class HellaSwag(HFTask, MultipleChoiceTask): ...@@ -60,5 +58,20 @@ class HellaSwag(HFTask, MultipleChoiceTask):
} }
return out_doc return out_doc
@classmethod
def preprocess(cls, text):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub('\\[.*?\\]', '', text)
text = text.replace(" ", " ")
return text
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["query"] return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
This diff is collapsed.
...@@ -8,13 +8,10 @@ models to generate answer derivations and explanations. ...@@ -8,13 +8,10 @@ models to generate answer derivations and explanations.
Homepage: https://github.com/hendrycks/math Homepage: https://github.com/hendrycks/math
""" """
import abc import inspect
import json import lm_eval.datasets.hendrycks_math.hendrycks_math
from lm_eval.utils import sh
from lm_eval.metrics import mean from lm_eval.metrics import mean
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from pathlib import Path
from best_download import download_file
_CITATION = """ _CITATION = """
...@@ -28,21 +25,8 @@ _CITATION = """ ...@@ -28,21 +25,8 @@ _CITATION = """
class Math(Task): class Math(Task):
DATASET_PATH = Path('data/MATH') DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
DATASET_NAME = None
def download(self):
if not (self.DATASET_PATH / 'test').exists() or not (self.DATASET_PATH / 'done').exists():
sh(f"mkdir -p {self.DATASET_PATH}")
download_file("https://people.eecs.berkeley.edu/~hendrycks/MATH.tar", local_file=f"{self.DATASET_PATH}.tar", expected_checksum="0fbe4fad0df66942db6c221cdcc95b298cc7f4595a2f0f518360cce84e90d9ac")
sh(f"""
tar -xf {self.DATASET_PATH}.tar -C data/ && touch {self.DATASET_PATH / 'done'}
rm {self.DATASET_PATH}.tar
""")
@abc.abstractmethod
def get_file_info(self):
"""returns directory name"""
pass
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -53,28 +37,31 @@ class Math(Task): ...@@ -53,28 +37,31 @@ class Math(Task):
def has_test_docs(self): def has_test_docs(self):
return True return True
def _load_docs(self, path):
for file in sorted(path.iterdir()):
with open(file) as f:
doc = json.load(f)
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
yield doc
def training_docs(self): def training_docs(self):
return self._load_docs(self.DATASET_PATH / "train" / self.get_file_info()) return map(self._process_doc, self.dataset["train"])
def validation_docs(self): def validation_docs(self):
return NotImplemented return NotImplemented
def test_docs(self): def test_docs(self):
return self._load_docs(self.DATASET_PATH / "test" / self.get_file_info()) return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
return doc
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Problem: " + doc["problem"] + "\nAnswer:" return "Problem: " + doc["problem"] + "\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["problem"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc["answer"] return " " + doc["solution"]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
return rf.greedy_until(ctx, ["\n"]) return rf.greedy_until(ctx, ["\n"])
...@@ -301,41 +288,34 @@ class Math(Task): ...@@ -301,41 +288,34 @@ class Math(Task):
class MathAlgebra(Math): class MathAlgebra(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'algebra'
return 'algebra'
class MathCountingAndProbability(Math): class MathCountingAndProbability(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'counting_and_probability'
return 'counting_and_probability'
class MathGeometry(Math): class MathGeometry(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'geometry'
return 'geometry'
class MathIntermediateAlgebra(Math): class MathIntermediateAlgebra(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'intermediate_algebra'
return 'intermediate_algebra'
class MathNumberTheory(Math): class MathNumberTheory(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'number_theory'
return 'number_theory'
class MathPrealgebra(Math): class MathPrealgebra(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'prealgebra'
return 'prealgebra'
class MathPrecalculus(Math): class MathPrecalculus(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'precalculus'
return 'precalculus'
...@@ -12,12 +12,7 @@ important shortcomings. ...@@ -12,12 +12,7 @@ important shortcomings.
Homepage: https://github.com/hendrycks/test Homepage: https://github.com/hendrycks/test
""" """
import csv
import random
from lm_eval.base import MultipleChoiceTask from lm_eval.base import MultipleChoiceTask
from ..utils import sh
from pathlib import Path
from best_download import download_file
_CITATION = """ _CITATION = """
...@@ -61,25 +56,15 @@ def create_task(subject): ...@@ -61,25 +56,15 @@ def create_task(subject):
class GeneralHendrycksTest(MultipleChoiceTask): class GeneralHendrycksTest(MultipleChoiceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = Path("data/hendrycksTest/") DATASET_PATH = "hendrycks_test"
DATASET_NAME = None
def __init__(self, subject): def __init__(self, subject):
self.subject = subject self.DATASET_NAME = subject
super().__init__() super().__init__()
def download(self):
if not (self.DATASET_PATH / 'done').exists():
sh("mkdir -p data")
download_file("https://people.eecs.berkeley.edu/~hendrycks/data.tar", local_file="data/data.tar", expected_checksum="78a804365a59028188fb19bd1adcadc5e0c260b220a9d8b2e33a5ea7d5fbe3b4")
sh("""
tar -xf data/data.tar -C data/
rm data/data.tar
mv data/data data/hendrycksTest
touch data/hendrycksTest/done
""")
def has_training_docs(self): def has_training_docs(self):
return True return False
def has_validation_docs(self): def has_validation_docs(self):
return True return True
...@@ -87,8 +72,14 @@ class GeneralHendrycksTest(MultipleChoiceTask): ...@@ -87,8 +72,14 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def _convert_standard(self, doc): def validation_docs(self):
def format_example(doc, choices): return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
def format_example(doc, keys):
""" """
Question: <prompt> Question: <prompt>
Choices: Choices:
...@@ -98,46 +89,31 @@ class GeneralHendrycksTest(MultipleChoiceTask): ...@@ -98,46 +89,31 @@ class GeneralHendrycksTest(MultipleChoiceTask):
D. <choice4> D. <choice4>
Answer: Answer:
""" """
prompt = "Question: " + doc[0] + "\nChoices:\n" prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join([f"{choices[j]}. {doc[j+1]}\n" for j in range(4)]) prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
prompt += "Answer:" prompt += "Answer:"
return prompt return prompt
choices = ['A', 'B', 'C', 'D'] keys = ['A', 'B', 'C', 'D']
return { return {
"query": format_example(doc, choices), "query": format_example(doc, keys),
"choices": doc[1:5], "choices": doc["choices"],
"gold": choices.index(doc[5]) "gold": keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"]
} }
def _load_docs(self, filename):
reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',')
return (self._convert_standard(doc) for doc in reader)
def training_docs(self):
docs = []
for train_dir in ["auxiliary_train", "dev"]:
for f in (self.DATASET_PATH / train_dir).iterdir():
docs.extend(self._load_docs(f))
return docs
def validation_docs(self):
filename = self.DATASET_PATH / "val" / f"{self.subject}_val.csv"
return self._load_docs(filename)
def test_docs(self):
filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
return self._load_docs(filename)
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is # fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't # in the same distribution as val/test but auxiliary_train isn't
filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
if self._fewshot_docs is None: if self._fewshot_docs is None:
self._fewshot_docs = list(self._load_docs(filename)) self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
return rnd.sample(list(self._fewshot_docs), k) return rnd.sample(list(self._fewshot_docs), k)
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["query"] return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
This diff is collapsed.
...@@ -13,12 +13,7 @@ in the broader discourse. ...@@ -13,12 +13,7 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
""" """
import json
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from lm_eval.tasks.lambada import LAMBADA from lm_eval.tasks.lambada import LAMBADA
from best_download import download_file
_CITATION = """ _CITATION = """
...@@ -35,8 +30,15 @@ _CITATION = """ ...@@ -35,8 +30,15 @@ _CITATION = """
class LAMBADA_cloze(LAMBADA): class LAMBADA_cloze(LAMBADA):
VERSION = 0 VERSION = 0
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc['text'].rsplit(' ', 1)[0] + " ____. ->" return doc['text'].rsplit(' ', 1)[0] + " ____. ->"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc['text']
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc['text'].rsplit(' ', 1)[1] return " " + doc['text'].rsplit(' ', 1)[1]
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment