Unverified Commit 94cc1850 authored by Hailey Schoelkopf, committed by GitHub

Use Pooled rather than Combined Variance for calculating stderr of task groupings (#1390)

* update formula for stderr aggregation

* hack: see what happens when using stderr_for_metric bootstrapping on a group

* undo bootstrap_for_stderr test

* factor out variance-aggregation formulas into api.metrics

* fix failing tests

* remove stray print

* update comment

* further detail in comment

* add back initialize_tasks() call

* fix format
parent df01adf6
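For intuition, here is a minimal standalone sketch (toy numbers; not part of the diff below) of the pooled-variance aggregation this commit switches to:

import numpy as np

# per-subtask standard errors of the mean, and per-subtask sample sizes
stderrs = [0.02, 0.03, 0.025]
sizes = [500, 250, 1000]

# pooled sample variance (https://en.wikipedia.org/wiki/Pooled_variance):
# each subtask's variance is weighted by its degrees of freedom (n_i - 1)
pooled_var = sum(
    (n - 1) * se**2 for n, se in zip(sizes, stderrs)
) / (sum(sizes) - len(sizes))

print(np.sqrt(pooled_var))  # group-level stderr, ~0.0245 for these numbers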
@@ -2,6 +2,7 @@ import logging
import math
import random
from collections.abc import Iterable
from typing import List
import evaluate
import numpy as np
@@ -425,3 +426,64 @@ def stderr_for_metric(metric, bootstrap_iters):
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
return stderr.get(metric, None)
def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
    # Used to aggregate bootstrapped stderrs across subtasks in a group,
    # when we are weighting by the size of each subtask.
    assert len(stderrs) == len(sizes)

    # formula source: https://en.wikipedia.org/wiki/Pooled_variance
    # this empirically matches running `stderr_for_metric` on all instances
    # from the subtasks concatenated with each other.
    pooled_sample_var = sum(
        (size - 1) * stderr**2 for size, stderr in zip(sizes, stderrs)
    ) / (sum(sizes) - len(sizes))

    return np.sqrt(pooled_sample_var)
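# Illustrative usage with toy numbers:
#   pooled_sample_stderr([0.02, 0.03], [500, 250]) -> ~0.024
# i.e. one group-level stderr, with each subtask's variance weighted by its
# degrees of freedom (n_i - 1).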
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
    assert (
        metrics is not None
    ), "Need to pass a list of each subtask's metric for this stderr aggregation"
    assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)

    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
    # This formula depends on sample means.
    # It is disabled by default because it seems to give erroneously huge stderrs for
    # groupings of tasks and does not seem to match up with bootstrap-calculated
    # stderrs for groups.
    ### don't use this unless a statistician has told you it's the right thing to do ###

    # accumulators: we'll aggregate pairwise N - 1 times
    variance = stderrs[0] ** 2
    curr_size = sizes[0]
    curr_score = metrics[0]

    for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
        # the pairwise combination formula uses the two pre-merge means,
        # so update the variance before folding `score` into `curr_score`
        variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
            curr_size + size - 1
        ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
            curr_score - score
        ) ** 2

        curr_score = ((curr_score * curr_size) + (score * size)) / (
            curr_size + size
        )  # NOTE: this assumes our aggregation fn is "mean"

        # the merged sample now contains curr_size + size elements;
        # carry that forward into the next pairwise merge
        curr_size += size

    return np.sqrt(variance)
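# Illustrative comparison with the same toy numbers as above, plus subtask means:
#   combined_sample_stderr([0.02, 0.03], [500, 250], metrics=[0.70, 0.40]) -> ~0.14
# The between-subtask term n*m*(x_bar - y_bar)**2 / ((n + m) * (n + m - 1))
# dominates whenever subtask means differ, which is why this formula can report
# much larger stderrs than `pooled_sample_stderr` (~0.024 here) or a bootstrap
# over the concatenated instances.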
def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
    # A helper function that is used to aggregate
    # subtask scores across tasks within a group.
    # TODO: does not hold for non-mean aggregations
    if not weight_by_size:
        sizes = [1] * len(sizes)

    assert len(metrics) == len(sizes)

    return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
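Taken together, the new helpers can be exercised as below (a standalone sketch with toy values, assuming they are importable from `lm_eval.api.metrics` as this diff intends):

from lm_eval.api.metrics import aggregate_subtask_metrics, pooled_sample_stderr

metrics, stderrs, sizes = [0.70, 0.40], [0.02, 0.03], [500, 250]

group_score = aggregate_subtask_metrics(metrics, sizes)  # size-weighted mean: 0.60
group_stderr = pooled_sample_stderr(stderrs, sizes)  # pooled stderr: ~0.024

# with weight_by_size=False, every subtask counts equally: plain mean 0.55
unweighted = aggregate_subtask_metrics(metrics, sizes, weight_by_size=False)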
@@ -515,65 +515,35 @@ def evaluate(
results[task_name][metric + "_stderr" + "," + key] = "N/A"
if bool(results):
for group, task_list in reversed(task_hierarchy.items()):
if task_list == []:
# TODO: No samples when bypass
total_size = results[group].get("samples", 999)
else:
total_size = 0
for task in task_list:
metrics = results[task].copy()
if "alias" in metrics:
metrics.pop("alias")
current_size = metrics.pop("samples")
all_stderr = []
for metric in [
key for key in metrics.keys() if "_stderr" not in key
]:
stderr = "_stderr,".join(metric.split(","))
stderr_score = results[task][stderr]
if stderr_score == "N/A":
var_score = "N/A"
else:
var_score = stderr_score**2
all_stderr.append(stderr)
metric_score = results[task][metric]
if metric in results[group]:
results[group][metric] = (
results[group][metric] * total_size
+ metric_score * current_size
) / (total_size + current_size)
# $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
if var_score == "N/A" or results[group][stderr] == "N/A":
results[group][stderr] = "N/A"
else:
results[group][stderr] = (
(total_size - 1) * results[group][stderr]
+ (current_size - 1) * var_score
) / (
total_size + current_size - 1
) + total_size * current_size / (
(total_size + current_size)
* (total_size + current_size - 1)
) * (
results[group][metric] - metric_score
) ** 2
else:
results[group][metric] = metric_score
results[group][stderr] = var_score
total_size += current_size
for stderr in all_stderr:
results[group][stderr] = np.sqrt(results[group][stderr])
results[group]["samples"] = total_size
for group, task_list in task_hierarchy.items():
if len(task_list) == 0:
# task_hierarchy entries are either
# `group_name: [subtask1, subtask2, ...]`
# or `task_name: []`.
# we only want to operate on groups here.
continue
                for metric in [
                    key
                    for key in results[task_list[0]].keys()
                    if "_stderr" not in key and key not in ["alias", "samples"]
                ]:  # TODO: what if tasks don't all share the same metrics
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [results[task][metric] for task in task_list] # TODO: copy?
stderrs = [results[task][stderr] for task in task_list]
sizes = [results[task]["samples"] for task in task_list]
# compute group's pooled metric and stderr
results[group][metric] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group][stderr] = "N/A"
else:
results[group][stderr] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
results[group]["samples"] = sum(sizes)
def print_tasks(task_hierarchy, results, tab=0):
results_agg = collections.defaultdict(dict)
......