Unverified Commit 3e2e6d82 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #497 from ollmer/mmlu_fix

MMLU task fix
parents 2b70a7c9 5ba0c5f9
......@@ -14,7 +14,6 @@ Homepage: https://github.com/hendrycks/test
"""
from lm_eval.base import MultipleChoiceTask
_CITATION = """
@article{hendryckstest2021,
title={Measuring Massive Multitask Language Understanding},
......@@ -103,8 +102,8 @@ def create_task(subject):
class GeneralHendrycksTest(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "hendrycks_test"
VERSION = 1
DATASET_PATH = "cais/mmlu"
DATASET_NAME = None
def __init__(self, subject):
......@@ -112,7 +111,7 @@ class GeneralHendrycksTest(MultipleChoiceTask):
super().__init__()
def has_training_docs(self):
return False
return True
def has_validation_docs(self):
return True
......@@ -126,41 +125,50 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _format_subject(self, subject):
words = subject.split("_")
return " ".join(words)
def fewshot_context(self, doc, num_fewshot, **kwargs):
subject = self.DATASET_NAME
description = f"The following are multiple choice questions (with answers) about {self._format_subject(subject)}."
kwargs["description"] = description
return super().fewshot_context(doc=doc, num_fewshot=num_fewshot, **kwargs)
def _process_doc(self, doc):
def format_example(doc, keys):
"""
Question: <prompt>
Choices:
<prompt>
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
Answer:
"""
prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join(
question = doc["question"].strip()
choices = "".join(
[f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])]
)
prompt += "Answer:"
prompt = f"{question}\n{choices}Answer:"
return prompt
keys = ["A", "B", "C", "D"]
return {
"query": format_example(doc, keys),
"choices": doc["choices"],
"gold": keys.index(doc["answer"])
if isinstance(doc["answer"], str)
else doc["answer"],
"choices": keys,
"gold": doc["answer"],
}
def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't
if self._fewshot_docs is None:
self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
return rnd.sample(list(self._fewshot_docs), k)
# use the unchanged order of the dev set without sampling,
# just as in the original code https://github.com/hendrycks/test/blob/master/evaluate.py#L28
return self._fewshot_docs[:k]
def doc_to_text(self, doc):
return doc["query"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment