"vscode:/vscode.git/clone" did not exist on "6af99c553d5d2b54ec6f5235890cf827fa0d1420"
Commit 48c6bd65 authored by Oleh Shliazhko
Browse files

fix mmlu task, set updated dataset name and make the prompt identical to the original eval code

parent d1451679
......@@ -14,7 +14,6 @@ Homepage: https://github.com/hendrycks/test
"""
from lm_eval.base import MultipleChoiceTask
_CITATION = """
@article{hendryckstest2021,
title={Measuring Massive Multitask Language Understanding},
......@@ -104,7 +103,7 @@ def create_task(subject):
class GeneralHendrycksTest(MultipleChoiceTask):
VERSION = 0
# Dataset identifier. The commit replaces the earlier "hendrycks_test" value
# with "cais/mmlu"; only the updated assignment is kept, since the later of
# two assignments is the one that takes effect.
DATASET_PATH = "cais/mmlu"
# Left as None here; the constructor receives a `subject` argument
# (NOTE(review): presumably it sets DATASET_NAME — that line is outside this excerpt).
DATASET_NAME = None
def __init__(self, subject):
......@@ -112,7 +111,7 @@ class GeneralHendrycksTest(MultipleChoiceTask):
super().__init__()
def has_training_docs(self):
    """Return True: the task reports that training documents are available.

    The earlier `return False` line shown alongside this one in the diff is
    dropped; it was superseded by `return True`.
    """
    return True
def has_validation_docs(self):
    """Return True: the task reports that validation documents are available."""
    return True
......@@ -126,41 +125,45 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def fewshot_context(self, doc, num_fewshot, **kwargs):
    """Build the few-shot context with a subject-specific description.

    Args:
        doc: The document the context is built for; forwarded unchanged.
        num_fewshot: Number of few-shot examples; forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded to the parent
            implementation. Any "description" entry is overwritten.

    Returns:
        Whatever the parent class's ``fewshot_context`` returns.
    """
    subject = self.DATASET_NAME
    description = f"The following are multiple choice questions (with answers) about {subject}."
    # The task always supplies its own header, replacing a caller-provided one.
    kwargs["description"] = description
    return super().fewshot_context(doc=doc, num_fewshot=num_fewshot, **kwargs)
def _process_doc(self, doc):
    """Convert a raw dataset record into the multiple-choice format.

    Args:
        doc: Mapping with a "question" string, a "choices" sequence and an
            "answer" entry.

    Returns:
        dict with three keys:
            "query": the question, one "<letter>. <choice>" line per choice,
                then a trailing "Answer:".
            "choices": the answer letters ``["A", "B", "C", "D"]``.
            "gold": ``doc["answer"]`` passed through unchanged
                (NOTE(review): presumably an integer index into the
                letters — confirm against the dataset schema).
    """

    def format_example(doc, keys):
        """
        <prompt>
        A. <choice1>
        B. <choice2>
        C. <choice3>
        D. <choice4>
        Answer:
        """
        question = doc["question"]
        # zip() pairs letters with choices and stops at the shorter sequence.
        choices = "".join(
            f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])
        )
        return f"{question}\n{choices}Answer:"

    keys = ["A", "B", "C", "D"]
    return {
        "query": format_example(doc, keys),
        "choices": keys,
        "gold": doc["answer"],
    }
def fewshot_examples(self, k, rnd):
    """Return the first ``k`` processed examples of the "dev" split.

    Args:
        k: Number of few-shot examples to return.
        rnd: Random source; kept in the signature for callers but not used,
            because examples are taken in dataset order rather than sampled.

    Returns:
        A list with at most ``k`` processed dev documents.
    """
    # fewshot_examples is not just sampling from train_docs because dev is
    # in the same distribution as val/test but auxiliary_train isn't
    if self._fewshot_docs is None:
        # Process the dev split once and cache it on the instance.
        self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
    # Deterministic prefix instead of rnd.sample(...): the commit's stated
    # goal is a prompt identical to the original eval code.
    return self._fewshot_docs[:k]
def doc_to_text(self, doc):
    """Return the prompt text stored under the document's "query" key."""
    return doc["query"]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment