"vscode:/vscode.git/clone" did not exist on "56a5442a51241f9f8e44d05db839c5f2095645cf"
Commit 2a1d7d87 authored by Leo Gao

Merge branch 'master' of github.com:EleutherAI/lm_evaluation_harness

parents b1f7284e a55a5c52
@@ -47,8 +47,8 @@ TASK_REGISTRY = {
     "piqa": piqa.PiQA,
     #"triviaqa": triviaqa.TriviaQA,
-    # "arc_easy": arc.ARCEasy,  # not implemented yet
-    # "arc_challenge": arc.ARCChallenge,  # not implemented yet
+    "arc_easy": arc.ARCEasy,
+    "arc_challenge": arc.ARCChallenge,
     # "quac": quac.QuAC,  # not implemented yet
     "hellaswag": hellaswag.HellaSwag,  # not implemented yet
     # "openbookqa": openbookqa.OpenBookQA,  # not implemented yet
...
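With the ARC entries uncommented, the tasks can be selected by their registry keys. A minimal sketch of how a name-to-class registry like this is typically consumed (the `get_task` helper below is an assumption written for illustration, not necessarily the harness's real entry point); the ARC tasks themselves are implemented in the diff that follows:

    def get_task(task_name):
        # Look up the task class by its registry key and instantiate it.
        # NOTE: `get_task` is a hypothetical helper for illustration only.
        return TASK_REGISTRY[task_name]()

    arc_easy = get_task("arc_easy")            # -> arc.ARCEasy instance
    arc_challenge = get_task("arc_challenge")  # -> arc.ARCChallenge instance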
+import numpy as np
+from lm_eval.base import rf, mean
 from . common import HFTask


 class ARCEasy(HFTask):
     DATASET_PATH = "ai2_arc"
     DATASET_NAME = "ARC-Easy"

+    letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}
+
+    def __init__(self):
+        super().__init__()
+        self.data = self.__clean_data()
+
+    def __clean_data(self):
+        """ Resolves various edge cases in the unprocessed HF ARC dataset. """
+        # NOTE: Some `doc["answerKey"]`s are in numeric string format being one
+        # of {'1', '2', '3', '4', '5'}. We map them back to letters.
+        num_to_letter = {'1': 'A', '2': 'B', '3': 'C', '4': 'D', '5': 'E'}
+        result = {}
+        for split, data in self.data.items():
+            result[split] = []
+            for doc in data:
+                # Ensure all `answerKey`s and `label`s are in letter format.
+                doc["answerKey"] = num_to_letter.get(doc["answerKey"], doc["answerKey"])
+                doc["choices"]["label"] = [
+                    num_to_letter.get(label, label) for label in doc["choices"]["label"]
+                ]
+                result[split].append(doc)
+        return result
+
     def has_training_docs(self):
         return True
@@ -21,7 +47,8 @@ class ARCEasy(HFTask):
         return "Question: " + doc['question'] + '\nAnswer:'

     def doc_to_target(self, doc):
-        return " " + doc['choices']['text'][doc['choices']['label'].index(doc['answerKey'])]
+        index = self.letter_to_num[doc["answerKey"]]
+        return " " + doc['choices']['text'][index]

     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
@@ -34,9 +61,11 @@ class ARCEasy(HFTask):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll_choices = []
+        for choice in doc["choices"]["text"]:
+            ll_choices.append(rf.loglikelihood(ctx, " " + choice)[0])
+        return ll_choices

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
@@ -47,8 +76,11 @@ class ARCEasy(HFTask):
         :param results:
             The results of the requests created in construct_requests.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        gold = self.letter_to_num[doc["answerKey"]]
+        pred = np.argmax(results)
+        return {
+            "acc": pred == gold
+        }

     def aggregation(self):
         """
@@ -56,8 +88,9 @@ class ARCEasy(HFTask):
         A dictionary where keys are the names of submetrics and values are
         functions that aggregate a list of metrics
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean
+        }

     def higher_is_better(self):
         """
@@ -65,8 +98,10 @@ class ARCEasy(HFTask):
         A dictionary where keys are the names of submetrics and values are
         whether a higher value of the submetric is better
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": True
+        }


 class ARCChallenge(ARCEasy):
     DATASET_PATH = "ai2_arc"
...
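The ARC implementation above scores each question by requesting the log-likelihood of every answer choice given the prompt, then picking the highest-scoring choice. A self-contained sketch of that scoring step, with made-up numbers (in the harness the values come from the loglikelihood requests issued in construct_requests):

    import numpy as np

    letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}

    # Hypothetical per-choice log-likelihoods returned by the model.
    results = [-4.2, -1.3, -3.8, -5.0]
    answer_key = "B"

    gold = letter_to_num[answer_key]  # 1
    pred = np.argmax(results)         # 1, since choice "B" has the highest log-likelihood
    acc = (pred == gold)              # correct; per-doc accuracies are then averaged with `mean`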
@@ -32,7 +32,7 @@ class Arithmetic(Task):
         self._docs = [self.load_doc(json.loads(line)) for line in jsons]

     def has_training_docs(self):
-        return True
+        return False

     def has_validation_docs(self):
         return True
@@ -41,10 +41,10 @@ class Arithmetic(Task):
         return False

     def training_docs(self):
-        return self._docs
+        return NotImplemented

     def validation_docs(self):
-        return self._docs[:100]
+        return self._docs

     def test_docs(self):
         return NotImplemented
...
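The arithmetic change above stops exposing the JSONL documents as a training split and serves them all as validation docs instead. A sketch of how a caller might respect the split flags (assumed caller-side code, not the harness's actual evaluator):

    def fewshot_source(task):
        # Hypothetical caller-side logic: prefer the training split for few-shot
        # examples, but fall back to validation when a task (like Arithmetic
        # after this change) exposes no training docs.
        if task.has_training_docs():
            return list(task.training_docs())
        return list(task.validation_docs())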