Commit e26dc4d3 (unverified), authored by Leo Gao and committed by GitHub
Browse files

Merge pull request #144 from jon-tow/arc-refactor

Refactor `ARC` as a `MultipleChoiceTask`
Parents: f7992789, 24ac76df
import numpy as np from lm_eval.base import MultipleChoiceTask
from lm_eval.base import rf, mean from .common import HFTask
from . common import HFTask
class ARCEasy(HFTask): class ARCEasy(HFTask, MultipleChoiceTask):
DATASET_PATH = "ai2_arc" DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Easy" DATASET_NAME = "ARC-Easy"
letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}
def __init__(self):
super().__init__()
self.data = self.__clean_data()
def __clean_data(self):
""" Resolves various edge cases in the unprocessed HF ARC dataset. """
# NOTE: Some `doc["answerKey"]`s are in numeric string format being one
# of {'1', '2', '3', '4', '5'}. We map them back to letters.
num_to_letter = {'1': 'A', '2': 'B', '3': 'C', '4': 'D', '5': 'E'}
result = {}
for split, data in self.data.items():
result[split] = []
for doc in data:
# Ensure all `answerKey`s and `label`s are in letter format.
doc["answerKey"] = num_to_letter.get(doc["answerKey"], doc["answerKey"])
doc["choices"]["label"] = [
num_to_letter.get(label, label) for label in doc["choices"]["label"]
]
result[split].append(doc)
return result
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -39,68 +15,41 @@ class ARCEasy(HFTask): ...@@ -39,68 +15,41 @@ class ARCEasy(HFTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def fewshot_description(self): def _convert_standard(self, doc):
# TODO: figure out description # NOTE: Some `doc["answerKey"]`s are in numeric string format being one
return "" # of {'1', '2', '3', '4', '5'}. We map them back to letters.
num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
def doc_to_text(self, doc): doc["answerKey"] = num_to_letter.get(doc["answerKey"], doc["answerKey"])
return "Question: " + doc['question'] + '\nAnswer:' out_doc = {
"id": doc["id"],
def doc_to_target(self, doc): "query": "Question: " + doc["question"] + "\nAnswer:",
index = self.letter_to_num[doc["answerKey"]] "choices": doc["choices"]["text"],
return " " + doc['choices']['text'][index] "gold": ["A", "B", "C", "D", "E"].index(doc["answerKey"]),
}
return out_doc
def construct_requests(self, doc, ctx): def _load_docs(self, docs):
""" Uses RequestFactory to construct Requests and returns an iterable of for record in docs:
Requests which will be sent to the LM. yield self._convert_standard(record)
:param doc: def training_docs(self):
The document as returned from training_docs, validation_docs, or test_docs. docs = super().training_docs()
:param ctx: str return self._load_docs(docs)
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
ll_choices = []
for choice in doc["choices"]["text"]:
ll_choices.append(rf.loglikelihood(ctx, " " + choice)[0])
return ll_choices
def process_results(self, doc, results): def validation_docs(self):
"""Take a single document and the LM results and evaluates, returning a docs = super().validation_docs()
dict where keys are the names of submetrics and values are the values of return self._load_docs(docs)
the metric for that one document
:param doc: def test_docs(self):
The document as returned from training_docs, validation_docs, or test_docs. docs = super().test_docs()
:param results: return self._load_docs(docs)
The results of the requests created in construct_requests.
"""
gold = self.letter_to_num[doc["answerKey"]]
pred = np.argmax(results)
return {
"acc": pred == gold
}
def aggregation(self): def fewshot_description(self):
""" # TODO: figure out description
:returns: {str: [float] -> float} return ""
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
def higher_is_better(self): def doc_to_text(self, doc):
""" return doc["query"]
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
class ARCChallenge(ARCEasy): class ARCChallenge(ARCEasy):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment