"llm/vscode:/vscode.git/clone" did not exist on "acfa2b94220882ffe8246a56e35a6cd1db38e6a2"
Commit c8e77519 authored by Jonathan Tow's avatar Jonathan Tow
Browse files

Fix `PIQA` prompt

parent 01630657
import json
import random
from lm_eval.base import Task, rf, mean
from ..utils import sh
import os
import numpy as np
from lm_eval.base import rf, mean
from . common import HFTask
class PiQA(Task):
def download(self):
if not os.path.exists('data/piqa'):
#TODO: use best_download
sh("""
mkdir -p data/piqa
wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl
wget https://yonatanbisk.com/piqa/data/train-labels.lst -O data/piqa/piqa-train-labels.lst
wget https://yonatanbisk.com/piqa/data/valid.jsonl -O data/piqa/piqa-valid.jsonl
wget https://yonatanbisk.com/piqa/data/valid-labels.lst -O data/piqa/piqa-valid-labels.lst
wget https://yonatanbisk.com/piqa/data/tests.jsonl -O data/piqa/piqa-test.jsonl
""")
class PiQA(HFTask):
DATASET_PATH = "piqa"
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -26,44 +16,25 @@ class PiQA(Task):
def has_test_docs(self):
return False
def load_docs(self, textfilename, labelfilename):
if labelfilename != None:
return zip([json.loads(entry) for entry in list(open(textfilename,'r'))],list(map(lambda x: x.strip(), open(labelfilename, 'r'))))
else:
return [json.loads(entry) for entry in list(open(textfilename,'r'))]
def training_docs(self):
return self.load_docs('data/piqa/piqa-train.jsonl', 'data/piqa/piqa-train-labels.lst')
def validation_docs(self):
return self.load_docs('data/piqa/piqa-valid.jsonl', 'data/piqa/piqa-valid-labels.lst')
#def test_docs(self):
# return self.load_docs('data/piqa/piqa-test.jsonl', None)
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc):
return doc[0]['goal'] + "\n"
return doc["goal"] + "\n"
def doc_to_target(self, doc):
#TODO: check if oa uses newline
rightanswer = int(doc[1]) + 1
return ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
solutions = [doc["sol1"], doc["sol2"]]
return solutions[doc["label"]]
def construct_requests(self, doc, ctx):
ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1'])
ll_2, _ = rf.loglikelihood(ctx, doc[0]['sol2'])
ll_1, _ = rf.loglikelihood(ctx, doc['sol1'])
ll_2, _ = rf.loglikelihood(ctx, doc['sol2'])
return ll_1, ll_2
def process_results(self, doc, results):
ll_1, ll_2 = results
def process_results(self, doc, results):
return {
'acc': (ll_1 > ll_2) == (int(doc[1]) == 0)
'acc': np.argmax(results) == doc["label"]
}
def aggregation(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment