Commit 57c751fa authored by nicholaskross's avatar nicholaskross
Browse files

got a start on the SAT eval

parent 6803e647
......@@ -4,6 +4,9 @@ import json
import random
import os
from lm_eval.base import Dataset
from tqdm import auto as tqdm_lib
from . common import simple_accuracy_metric
import numpy as np
from ..utils import sh
......@@ -94,5 +97,20 @@ class SATAnalogies(Dataset):
return text
def evaluate(self, docs, lm):
# TODO: Write evaluation function
raise NotImplementedError()
golds = [doc["answer_key"] for doc in docs]
preds = []
for doc in tqdm_lib.tqdm(docs):
ctx = self.fewshot_context(
doc=doc,
num_fewshot=1,
provide_description=None,
# unless Dataset evaluate()s should get num_fewshot/ provide_description
)
probs_before_numpy = []
for choice in doc["choices"]:
this_choice = " " + choice
probs_before_numpy.append(lm.loglikelihood(ctx, this_choice))
probs = np.array(probs_before_numpy)
preds.append(np.argmax(probs))
return simple_accuracy_metric(preds=preds, golds=golds)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment