Commit 4acb339e authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed Brier score to accommodate samples with different numbers of choices

parent 835cc40e
import math import math
from collections.abc import Iterable from collections.abc import Iterable
from collections import defaultdict
import numpy as np import numpy as np
import sacrebleu import sacrebleu
import sklearn.metrics import sklearn.metrics
...@@ -111,13 +111,26 @@ def ter(items): ...@@ -111,13 +111,26 @@ def ter(items):
@register_aggregation("brier_score") @register_aggregation("brier_score")
def brier_score(items): # This is a passthrough function def brier_score(items): # This is a passthrough function
gold, predictions = list(zip(*items))
print(type(predictions)) # Certain datasets like arc_easy can have a different number of choices.
predictions = np.array(predictions) golds, predictions = list(zip(*items))
print(predictions.shape)
gold = np.array(gold) pred_group = defaultdict(list)
gold_one_hot = np.eye(len(predictions[0]))[gold] gold_group = defaultdict(list)
return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1)) for gold, pred in zip(golds, predictions):
pred_group[len(pred)].append(pred)
gold_group[len(pred)].append(gold)
total_size = 0
average = 0
for g, p in zip(gold_group.values(), pred_group.values()):
_p = np.array(p)
_g = np.array(g)
_g_one_hot = np.eye(len(_p[0]))[_g]
average += np.mean(np.sum((_p - _g_one_hot) ** 2, axis=1))*len(_g)
total_size += len(_g)
return average/total_size
@register_metric( @register_metric(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment