Commit 386d63ea authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed brier_score to allow multi-gpu inference

parent 4efa0b6d
...@@ -126,8 +126,7 @@ def brier_score(items): # This is a passthrough function ...@@ -126,8 +126,7 @@ def brier_score(items): # This is a passthrough function
for g, p in zip(gold_group.values(), pred_group.values()): for g, p in zip(gold_group.values(), pred_group.values()):
_p = np.array(p) _p = np.array(p)
_g = np.array(g) _g = np.array(g)
_g_one_hot = np.eye(len(_p[0]))[_g] average += np.mean(np.sum((_p - _g) ** 2, axis=1)) * len(_g)
average += np.mean(np.sum((_p - _g_one_hot) ** 2, axis=1)) * len(_g)
total_size += len(_g) total_size += len(_g)
return average / total_size return average / total_size
......
...@@ -1116,7 +1116,8 @@ class ConfigurableTask(Task): ...@@ -1116,7 +1116,8 @@ class ConfigurableTask(Task):
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}), **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}), **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**( **(
{"brier_score": (gold, prob_norm)} # {"brier_score": (gold, prob_norm)}
{"brier_score": [np.eye(len(prob_norm))[gold], prob_norm]}
if "brier_score" in use_metric if "brier_score" in use_metric
else {} else {}
), ),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment