Unverified Commit 93510e3a authored by Leo Gao, committed by GitHub

Merge pull request #80 from nicholaskross/master

Started SAT eval
parents afc614fe 515e0470
@@ -3,7 +3,10 @@
import json
import random
import os
from lm_eval.base import Dataset
from lm_eval.base import Dataset, rf, mean
from tqdm import auto as tqdm_lib
from . common import simple_accuracy_metric
import numpy as np
from ..utils import sh
@@ -93,8 +96,37 @@ class SATAnalogies(Dataset):
        return text
    # TODO: Implement evaluation code
    # ***IMPORTANT***: this evaluation function needs to be written for the new framework.
    # For more info, check out the interface in base.py and the example BoolQ implementation in superglue.py.
    # Remove this comment when the evaluation code is implemented.
\ No newline at end of file

    def doc_to_target(self, doc):
        # answer_key is the correct choice's letter; prepend a space so the
        # target matches the ' a' ... ' e' continuations scored below
        return ' ' + doc['answer_key']

    def construct_requests(self, doc, ctx):
        # score each answer letter as a continuation of the context
        ll_a = rf.loglikelihood(ctx, ' a')
        ll_b = rf.loglikelihood(ctx, ' b')
        ll_c = rf.loglikelihood(ctx, ' c')
        ll_d = rf.loglikelihood(ctx, ' d')
        ll_e = rf.loglikelihood(ctx, ' e')
        return ll_a, ll_b, ll_c, ll_d, ll_e

    def process_results(self, doc, results):
        # results holds the loglikelihood of each answer choice, in order a-e
        lls = np.array(results)
        # answer_key is a letter, so map it onto the 0-4 index that argmax returns
        gold = ord(doc["answer_key"]) - ord('a')
        acc = 1. if np.argmax(lls) == gold else 0.
        return [
            {
                "submetric": "acc",
                "value": acc,
                "higher_is_better": True,
                "aggregation": mean
            }
        ]

    def evaluate(self, docs, lm):
        # not used: per-document scoring is handled by process_results above
        raise NotImplementedError()
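
For reference, a minimal standalone sketch of the scoring step above. The loglikelihood values and the `answer_key` here are hypothetical; it only illustrates how the argmax over the five per-choice scores is compared against the gold letter:

```python
import numpy as np

# hypothetical per-choice loglikelihoods, in order a-e, as construct_requests
# would request for one SAT analogy question
results = [-4.2, -1.3, -3.8, -5.0, -2.9]
doc = {"answer_key": "b"}  # hypothetical gold answer

lls = np.array(results)
gold = ord(doc["answer_key"]) - ord("a")  # map letter to 0-4 index
acc = 1.0 if np.argmax(lls) == gold else 0.0
print(acc)  # 1.0 here, since choice 'b' has the highest loglikelihood
```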