Unverified Commit 71d0289d authored by Neel Gupta, committed by GitHub

[FIX] Initial code to disable multi-proc for stderr (#3106)



* [FIX] Initial code to disable multi-proc for stderr

* add docs; align no-mp bootstrap with mp

---------
Co-authored-by: Baber <baber@hey.com>
parent ff41a856
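The diff below gates the multiprocessing pool in `bootstrap_stderr` behind the `DISABLE_MULTIPROC` environment variable and adds a single-process fallback. A minimal usage sketch (not part of the diff; the scores and iteration count are purely illustrative):

import os

# Any non-empty value routes bootstrap_stderr through the new
# single-process fallback instead of a multiprocessing.Pool.
os.environ["DISABLE_MULTIPROC"] = "1"

from lm_eval.api.metrics import bootstrap_stderr, mean

scores = [0.0, 1.0, 1.0, 0.0, 1.0]  # illustrative per-sample scores
print(bootstrap_stderr(mean, scores, iters=1000))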
 import logging
 import math
+import os
 import random
 import re
 import string
 from collections.abc import Iterable
-from typing import List
+from typing import Callable, List, Optional, Sequence, TypeVar
 
 import numpy as np
 import sacrebleu
@@ -12,6 +13,8 @@ import sacrebleu
 from lm_eval.api.registry import register_aggregation, register_metric
 
+T = TypeVar("T")
+
 eval_logger = logging.getLogger(__name__)
@@ -287,7 +290,7 @@ def pop_stddev(arr):
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
 
-def sample_stddev(arr):
+def sample_stddev(arr: Sequence[T]) -> float:
     mu = mean(arr)
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
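(Aside, not part of the diff: the two estimators above differ only in Bessel's correction; a quick check with made-up values.)

import math

# Illustrative values: mean = 2.5, squared deviations sum to 5.0
vals = [1.0, 2.0, 3.0, 4.0]
mu = sum(vals) / len(vals)
ss = sum((x - mu) ** 2 for x in vals)
print(math.sqrt(ss / len(vals)))        # pop_stddev:    sqrt(5/4) ≈ 1.118 (divides by n)
print(math.sqrt(ss / (len(vals) - 1)))  # sample_stddev: sqrt(5/3) ≈ 1.291 (divides by n - 1)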
@@ -449,11 +452,16 @@ def _sacreformat(refs, preds):
 
 class _bootstrap_internal:
-    def __init__(self, f, n) -> None:
+    """
+    Pool worker: `(i, xs)` → `n` bootstrap replicates
+    of `f(xs)` using an RNG seeded with `i`.
+    """
+
+    def __init__(self, f: Callable[[Sequence[T]], float], n: int) -> None:
         self.f = f
         self.n = n
 
-    def __call__(self, v):
+    def __call__(self, v: tuple[int, Sequence[T]]) -> list[float]:
         i, xs = v
         rnd = random.Random()
         rnd.seed(i)
@@ -463,36 +471,81 @@ class _bootstrap_internal:
         return res
 
 
-def bootstrap_stderr(f, xs, iters):
-    import multiprocessing as mp
-
-    pool = mp.Pool(mp.cpu_count())
-
-    # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
-    # equivalent to stderr calculated without Bessel's correction in the stddev.
-    # Unfortunately, I haven't been able to figure out what the right correction is
-    # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
-    # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
-    # Thankfully, shouldn't matter because our samples are pretty big usually anyways
-    res = []
-    chunk_size = min(1000, iters)
-
-    from tqdm import tqdm
-
-    print("bootstrapping for stddev:", f.__name__)
-    for bootstrap in tqdm(
-        pool.imap(
-            _bootstrap_internal(f, chunk_size),
-            [(i, xs) for i in range(iters // chunk_size)],
-        ),
-        total=iters // chunk_size,
-    ):
-        # sample w replacement
-        res.extend(bootstrap)
-
-    pool.close()
+def _bootstrap_internal_no_mp(
+    f: Callable[[Sequence[T]], float], xs: Sequence[T], iters: int
+) -> list[float]:
+    """
+    Single-process fallback: compute `iters` bootstrap replicates
+    of the statistic `f(xs)`, chunked (≤ 1000 draws).
+    """
+    res = []
+    chunk_size = min(1000, iters)
+
+    from tqdm import tqdm
+
+    print(f"bootstrapping for stddev: {f.__name__}")
+
+    # A single loop replaces the multiprocessing pool.
+    for i in tqdm(range(iters // chunk_size)):
+        rnd = random.Random(i)
+        for _ in range(chunk_size):
+            res.append(f(rnd.choices(xs, k=len(xs))))
+
+    return res
+
+
+def bootstrap_stderr(
+    f: Callable[[Sequence[T]], float], xs: Sequence[T], iters: int
+) -> float:
+    """
+    Bootstrap estimate of the standard error of the statistic `f(xs)`
+    using up to `iters` resamples, chunked (≤ 1000 draws).
+    Executes in parallel unless the env var `DISABLE_MULTIPROC` is set.
+    """
+    if not os.getenv("DISABLE_MULTIPROC"):
+        import multiprocessing as mp
+
+        pool = mp.Pool(mp.cpu_count())
+
+        # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
+        # equivalent to stderr calculated without Bessel's correction in the stddev.
+        # Unfortunately, I haven't been able to figure out what the right correction is
+        # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
+        # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
+        # Thankfully, shouldn't matter because our samples are pretty big usually anyways
+        res = []
+        chunk_size = min(1000, iters)
+
+        from tqdm import tqdm
+
+        print("bootstrapping for stddev:", f.__name__)
+        for bootstrap in tqdm(
+            pool.imap(
+                _bootstrap_internal(f, chunk_size),
+                [(i, xs) for i in range(iters // chunk_size)],
+            ),
+            total=iters // chunk_size,
+        ):
+            # sample w replacement
+            res.extend(bootstrap)
+
+        pool.close()
+    else:
+        res = _bootstrap_internal_no_mp(f, xs, iters)
 
     return sample_stddev(res)
 
 
-def stderr_for_metric(metric, bootstrap_iters: int):
+def stderr_for_metric(
+    metric: Callable[[Sequence[T]], float], bootstrap_iters: int
+) -> Optional[Callable[[Sequence[T]], float]]:
+    """
+    Return a function that estimates the standard error of `metric(xs)`.
+
+    * If `bootstrap_iters > 0` and the metric is in the pre-approved
+      bootstrappable list, use `bootstrap_stderr` with that many draws.
+    * If the metric has a closed-form SE (e.g. `mean`, `acc_all`), use it.
+    * Otherwise, return `None`.
+    """
     if bootstrap_iters <= 0:
         # return no function (don't compute stderr) if bootstrap iters = 0
         return None
...
+import unittest.mock as mock
+
+from lm_eval.api.metrics import _bootstrap_internal_no_mp, mean
 from lm_eval.api.task import ConfigurableTask, TaskConfig
@@ -149,8 +152,34 @@ def test_acc_mutual_info_without_metric():
     assert result_dict["acc"] == 1.0
 
 
+def test_bootstrap_internal_no_mp():
+    """Test basic functionality of _bootstrap_internal_no_mp"""
+    data = [1, 2, 3, 4, 5]
+
+    # Mock tqdm to avoid progress bar output during testing
+    with mock.patch("tqdm.tqdm") as mock_tqdm:
+        mock_tqdm.return_value = range(1)  # Single chunk
+
+        # Mock print to avoid output during testing
+        with mock.patch("builtins.print"):
+            result = _bootstrap_internal_no_mp(mean, data, 100)
+
+    # Should return 100 bootstrap replicates
+    assert len(result) == 100
+
+    # All results should be numbers (means)
+    assert all(isinstance(x, (int, float)) for x in result)
+
+    # Bootstrap means should be close to original mean
+    bootstrap_mean = mean(result)
+    original_mean = mean(data)
+    assert abs(bootstrap_mean - original_mean) < 0.5  # Should be reasonably close
+
+
 if __name__ == "__main__":
     test_acc_mutual_info_slicing()
     test_acc_mutual_info_different_predictions()
     test_acc_mutual_info_without_metric()
+    test_bootstrap_internal_no_mp()
     print("All tests passed!")