Unverified Commit 4462e415 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Make piqa inherit MultipleChoiceTask

parent 2522c27f
import numpy as np
from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask
from . common import MultipleChoiceTask, HFTask
class PiQA(HFTask, MultipleChoiceTask):
    """PIQA (Physical Interaction QA) task, exposed as a multiple-choice task.

    The diff left both the old and the new ``class`` statements in place;
    only the new header (with MultipleChoiceTask in the bases) is kept.
    """

    # Identifiers used by the Hugging Face `datasets` loader.
    DATASET_PATH = "piqa"
    DATASET_NAME = None
......@@ -21,32 +21,29 @@ class PiQA(HFTask):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc):
    """Format a doc into the prompt string shown to the model."""
    prompt = f"Question: {doc['goal']}\nAnswer:"
    return prompt
def _convert_standard(self, doc):
out_doc = {
"goal": doc["goal"],
"choices": [doc["sol1"], doc["sol2"]],
"gold": doc["label"],
}
return out_doc
def doc_to_target(self, doc):
    """Return the gold completion, prefixed with a single space."""
    candidates = (doc["sol1"], doc["sol2"])
    return " " + candidates[doc["label"]]
def _load_docs(self, docs):
for record in docs:
yield self._convert_standard(record)
def construct_requests(self, doc, ctx):
    """Build a loglikelihood request for each of the two candidate solutions."""
    requests = [
        rf.loglikelihood(ctx, " " + doc[key])[0]
        for key in ("sol1", "sol2")
    ]
    return tuple(requests)
def training_docs(self):
    """Yield training docs converted to the standard multiple-choice format."""
    return self._load_docs(super().training_docs())
def process_results(self, doc, results):
    """Score the per-choice loglikelihoods against the gold label.

    The diff interleaved this method with ``validation_docs``; both are
    reconstructed here as separate, valid methods.

    `results` holds the loglikelihood of each candidate solution, in order
    (sol1, sol2). 'acc' takes the raw argmax; 'acc_norm' first divides each
    loglikelihood by its completion's character length.
    """
    completion_len = np.array([float(len(doc["sol1"])), float(len(doc["sol2"]))])
    return {
        'acc': np.argmax(results) == doc["label"],
        'acc_norm': np.argmax(results / completion_len) == doc["label"]
    }

def validation_docs(self):
    """Yield validation docs converted to the standard multiple-choice format."""
    return self._load_docs(super().validation_docs())
def test_docs(self):
    """Yield test docs converted to the standard multiple-choice format."""
    return self._load_docs(super().test_docs())
def aggregation(self):
    """Map each metric name emitted by process_results to its aggregation fn.

    process_results emits both 'acc' and 'acc_norm'; the original dict
    only listed 'acc', so aggregating 'acc_norm' would fail with a
    KeyError. Both metrics are simple means.
    """
    return {
        'acc': mean,
        'acc_norm': mean,
    }
def higher_is_better(self):
    """Direction of improvement for each metric from process_results.

    'acc_norm' was missing even though process_results emits it; both
    accuracies improve upward.
    """
    return {
        'acc': True,
        'acc_norm': True,
    }
def doc_to_text(self, doc):
    """Render the question prompt for a standardized doc."""
    return "Question: {}\nAnswer:".format(doc["goal"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment