Unverified Commit 620d6a15 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Fix: task weighting by subtask size ; update Pooled Stderr formula slightly (#1427)

* fix weight_by_size condition

* add tests, update stderr formula slightly

* apply pre-commit
parent bfbd0325
......@@ -4,11 +4,11 @@ import random
from collections.abc import Iterable
from typing import List
import evaluate
import numpy as np
import sacrebleu
import sklearn.metrics
import evaluate
from lm_eval.api.registry import register_aggregation, register_metric
......@@ -436,13 +436,14 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
assert len(stderrs) == len(sizes)
# formula source: https://en.wikipedia.org/wiki/Pooled_variance
# this empirically matches running `stderr_for_metric` on all instances
# and: https://stats.stackexchange.com/a/4841331
# this empirically seems to match running `stderr_for_metric` on all instances
# from the subtasks concatenated with each other.
pooled_sample_var = (
sum([(size - 1) * stderr**2 for size, stderr in zip(sizes, stderrs)])
sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)])
) / (sum(sizes) - len(sizes))
return np.sqrt(pooled_sample_var)
return np.sqrt(pooled_sample_var / sum(sizes))
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
......@@ -481,7 +482,7 @@ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
# A helper function that is used to aggregate
# subtask scores cross-task.
# TODO: does not hold for non-mean aggregations
if weight_by_size:
if not weight_by_size:
sizes = [1] * len(sizes)
assert len(metrics) == len(sizes)
......
......@@ -2,7 +2,6 @@ import logging
from typing import Callable, Dict
import evaluate
from lm_eval.api.model import LM
......
......@@ -10,7 +10,7 @@ def main() -> None:
# Removed hy and sk subdataset because the original dataset is broken
# I created this PR https://huggingface.co/datasets/alexandrainst/m_mmlu/discussions/3
# on the dataset for the authors, in case it will be accepted the filter can be removed
keys_without_hy_sk = list(filter(lambda k: ('hy' not in k and 'sk' not in k),
keys_without_hy_sk = list(filter(lambda k: ('hy' not in k and 'sk' not in k),
datasets.get_dataset_infos(dataset_path).keys()))
for task in tqdm():
......
......@@ -51,4 +51,4 @@ def process_results_mc2(doc, results):
p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
p_true = p_true / (sum(p_true) + sum(p_false))
return {"acc": sum(p_true)}
\ No newline at end of file
return {"acc": sum(p_true)}
import itertools
import numpy as np
import pytest
from lm_eval.api.metrics import (
aggregate_subtask_metrics,
mean,
pooled_sample_stderr,
stderr_for_metric,
)
from lm_eval.utils import (
Collator,
get_rolling_token_windows,
......@@ -299,3 +308,39 @@ class TestCollator:
# check indices
reordered_output = loglikelihoods.get_original(output)
assert reordered_output == [x[1] for x in loglikelihood_samples]
def test_aggregate_mean():
    """Verify that `weight_by_size` switches between the unweighted and the
    size-weighted mean of subtask metrics."""
    scores = [0.3, 0.2, 0.4]
    subtask_sizes = [20, 40, 100]
    # weight_by_size=False: plain arithmetic mean of the three scores
    unweighted = aggregate_subtask_metrics(
        scores, subtask_sizes, weight_by_size=False
    )
    assert unweighted == 0.3
    # weight_by_size=True: each score contributes in proportion to its
    # subtask's sample count -> (0.3*20 + 0.2*40 + 0.4*100) / 160 = 0.3375
    weighted = aggregate_subtask_metrics(scores, subtask_sizes, weight_by_size=True)
    assert weighted == 0.3375
@pytest.mark.parametrize(
    "samples",
    [
        [40 * [1.0] + 60 * [0.0], 30 * [1.0] + 30 * [0.0], 20 * [1.0] + 60 * [0.0]],
        [35 * [1.0] + 65 * [0.0], 20 * [1.0] + 20 * [0.0]],
    ],
)
def test_aggregate_stderrs(samples):
    """Pooling per-subtask bootstrap stderrs (size-weighted) should closely
    approximate the bootstrap stderr of all subtask samples concatenated."""
    bootstrap_stderr = stderr_for_metric(metric=mean, bootstrap_iters=100000)
    per_task_stderrs = []
    per_task_sizes = []
    for subtask in samples:
        per_task_stderrs.append(bootstrap_stderr(subtask))
        per_task_sizes.append(len(subtask))
    pooled = pooled_sample_stderr(per_task_stderrs, per_task_sizes)
    # reference: bootstrap stderr over the flattened sample set
    all_samples = [value for subtask in samples for value in subtask]
    reference = bootstrap_stderr(all_samples)
    assert np.allclose(pooled, reference, atol=1.0e-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment