"vscode:/vscode.git/clone" did not exist on "d6bfe2a937b1438f9785a77823350e83a6408831"
Commit 50a8ddfb authored by Herbie Bradley
Browse files

Allow sample logging of calibration data

parent a6bd7126
......@@ -130,7 +130,6 @@ def ece_fn(items): # This is a passthrough function
This consists of the average absolute difference between the fraction of
model predictions which are correct and the mean of the model's normalized
probability for those predictions (after binning), for multiple choice questions.
Lower is better.
Paper: https://arxiv.org/abs/2207.05221
"""
......
......@@ -966,14 +966,14 @@ class ConfigurableTask(Task):
if "ece" in use_metric:
# Convert lls from log-probabilities to normalized probabilities
norm_probs: np.ndarray = np.exp(lls - sp.logsumexp(lls))
calib_scores: np.ndarray = np.zeros(len(choices))
norm_probs: list[float] = np.exp(lls - sp.logsumexp(lls)).tolist()
calib_scores: list[float] = [0.0] * len(choices)
if isinstance(gold, list):
for g in gold:
calib_scores[g] = 1.0
else:
calib_scores[gold] = 1.0
calibration_probs: dict[str, np.ndarray] = {
calibration_probs: dict[str, list[float]] = {
"probs": norm_probs,
"scores": calib_scores,
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment