Commit 386d63ea authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed brier_score to allow multi-gpu inference

parent 4efa0b6d
......@@ -126,8 +126,7 @@ def brier_score(items): # This is a passthrough function
for g, p in zip(gold_group.values(), pred_group.values()):
_p = np.array(p)
_g = np.array(g)
_g_one_hot = np.eye(len(_p[0]))[_g]
average += np.mean(np.sum((_p - _g_one_hot) ** 2, axis=1)) * len(_g)
average += np.mean(np.sum((_p - _g) ** 2, axis=1)) * len(_g)
total_size += len(_g)
return average / total_size
......
......@@ -1116,7 +1116,8 @@ class ConfigurableTask(Task):
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**(
{"brier_score": (gold, prob_norm)}
# {"brier_score": (gold, prob_norm)}
{"brier_score": [np.eye(len(prob_norm))[gold], prob_norm]}
if "brier_score" in use_metric
else {}
),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment