Commit 98f9bac9 authored by lintangsutawika's avatar lintangsutawika
Browse files

add brier_score

parent 3533e4b9
......@@ -7,6 +7,7 @@ import sklearn.metrics
import random
import evaluate
from Levenshtein import distance
from lm_eval.api.registry import register_metric, register_aggregation
......@@ -106,6 +107,27 @@ def ter(items):
return sacrebleu.corpus_ter(preds, refs).score
@register_aggregation("brier_score")
def brier_score(items): # This is a passthrough function
gold = list(zip(*items))[0]
gold_one_hot = np.eye(max(gold)+1)[gold]
predictions = list(zip(*items))[1]
print("predictions", prediction)
print("gold_one_hot", gold_one_hot)
import sys; sys.exit()
return np.mean(np.sum((predictions - gold_one_hot)**2, axis=1))
@register_metric(
metric="brier_score",
higher_is_better=False,
output_type=["multiple_choice"],
aggregation="brier_score",
)
def brier_score_fn(items): # This is a passthrough function
return items
@register_metric(
metric="acc",
higher_is_better=True,
......@@ -139,6 +161,18 @@ def acc_mutual_info_fn(items): # This is a passthrough function
exact_match = evaluate.load("exact_match")
# @register_metric(
# metric="token_edit_distance",
# higher_is_better=False,
# output_type=["generate_until"],
# aggregation="mean",
# )
# def ted_fn(items): # This is a passthrough function
# references, predictions = items
# return distance(references, predictions)
@register_metric(
metric="exact_match",
higher_is_better=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment