Commit 57c751fa authored by nicholaskross
Browse files

got a start on the SAT eval

parent 6803e647
...@@ -4,6 +4,9 @@ import json ...@@ -4,6 +4,9 @@ import json
import random import random
import os import os
from lm_eval.base import Dataset from lm_eval.base import Dataset
from tqdm import auto as tqdm_lib
from . common import simple_accuracy_metric
import numpy as np
from ..utils import sh from ..utils import sh
...@@ -94,5 +97,20 @@ class SATAnalogies(Dataset): ...@@ -94,5 +97,20 @@ class SATAnalogies(Dataset):
return text return text
def evaluate(self, docs, lm):
    """Return the accuracy of ``lm`` at picking each doc's answer choice.

    For every doc, a one-shot context is built with
    ``self.fewshot_context`` and each entry of ``doc["choices"]`` (prefixed
    with a single space) is scored with ``lm.loglikelihood``. The position
    of the highest-scoring choice is the prediction, and the predictions
    are compared against ``doc["answer_key"]`` by
    ``simple_accuracy_metric``.

    Args:
        docs: Doc dicts providing ``"choices"`` and ``"answer_key"``.
            Iterated twice, so it must be re-iterable (e.g. a list,
            not a one-shot generator).
        lm: Model object exposing ``loglikelihood(context, continuation)``.

    Returns:
        Whatever ``simple_accuracy_metric(preds=..., golds=...)`` returns.
    """
    # NOTE(review): predictions are positions within doc["choices"], so this
    # presumably expects doc["answer_key"] to be that same integer index
    # (not a letter such as "a") -- confirm against the doc loader.
    golds = [doc["answer_key"] for doc in docs]
    preds = []
    for doc in tqdm_lib.tqdm(docs):
        # The few-shot settings are hard-coded because evaluate() receives
        # neither num_fewshot nor provide_description; revisit if
        # Dataset.evaluate() is ever given those arguments.
        ctx = self.fewshot_context(
            doc=doc,
            num_fewshot=1,
            provide_description=None,
        )
        # A leading space separates the continuation from the context.
        # NOTE(review): assumes each choice is a str and that
        # lm.loglikelihood returns one comparable score per call -- confirm.
        scores = np.array(
            [lm.loglikelihood(ctx, " " + choice) for choice in doc["choices"]]
        )
        preds.append(np.argmax(scores))
    return simple_accuracy_metric(preds=preds, golds=golds)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment