Unverified Commit b043b050 authored by Hailey Schoelkopf, committed by GitHub

Fix for bootstrap_iters = 0 case (#1715) (#1789)

* add handling for bootstrap_iters=0 case

* add more detail to docstring

* run precommit
parent 7d747ea9
@@ -429,7 +429,11 @@ def bootstrap_stderr(f, xs, iters):
     return sample_stddev(res)
 
 
-def stderr_for_metric(metric, bootstrap_iters):
+def stderr_for_metric(metric, bootstrap_iters: int):
+    if bootstrap_iters <= 0:
+        # return no function (don't compute stderr) if bootstrap iters = 0
+        return None
+
     bootstrappable = [
         median,
         matthews_corrcoef,
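For context on what `bootstrap_iters` controls: the stderr of a metric is estimated by resampling the per-example scores with replacement and measuring how much the metric varies across those resamples. A minimal, self-contained sketch of that idea (illustrative only; `bootstrap_stderr_sketch` is not the harness's implementation):

```python
import random
import statistics
from typing import Callable, Sequence


def bootstrap_stderr_sketch(metric_fn: Callable, xs: Sequence[float], iters: int):
    """Resample xs with replacement `iters` times and return the standard
    deviation of the metric over those resamples. As in the patched
    stderr_for_metric, a non-positive iteration count disables the estimate."""
    if iters <= 0:
        return None
    resamples = [metric_fn(random.choices(xs, k=len(xs))) for _ in range(iters)]
    return statistics.stdev(resamples)


scores = [1, 0, 1, 1, 0, 1, 0, 1]  # per-example 0/1 accuracies
print(bootstrap_stderr_sketch(statistics.mean, scores, iters=1000))  # roughly 0.17
print(bootstrap_stderr_sketch(statistics.mean, scores, iters=0))     # None
```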
@@ -92,7 +92,7 @@ def simple_evaluate(
     :param limit: int or float, optional
         Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
     :param bootstrap_iters:
-        Number of iterations for bootstrap statistics
+        Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0 for no stderr calculations to be performed.
     :param check_integrity: bool
         Whether to run the relevant part of the test suite for the tasks
     :param write_out: bool
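To make the new knob concrete, here is a hypothetical call through `evaluator.simple_evaluate`, the same entry point the updated test below uses; the model and task choices are only examples:

```python
from lm_eval import evaluator

# bootstrap_iters=0 skips bootstrap stderr estimation entirely; any
# "*_stderr" entries in the results are then reported as "N/A".
results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
    tasks=["arc_easy"],
    limit=10,
    bootstrap_iters=0,
)
print(results["results"]["arc_easy"])
```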
@@ -328,7 +328,7 @@ def evaluate(
     :param limit: int, optional
         Limit the number of examples per task (only use this for testing)
     :param bootstrap_iters:
-        Number of iterations for bootstrap statistics
+        Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 to skip all stderr calculations.
     :param write_out: bool
         If True, write out an example document and model input for checking task integrity
     :param log_samples: bool
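The lower-level `evaluator.evaluate` path accepts the same argument, as the updated test below exercises. A hedged sketch of calling it directly; the `HFLM`, `TaskManager`, and `get_task_dict` helpers come from the surrounding codebase rather than this diff, so treat their exact signatures as assumptions:

```python
from lm_eval import evaluator, tasks
from lm_eval.models.huggingface import HFLM

# Assumed helpers: HFLM, TaskManager, get_task_dict (compare the test changes below).
lm = HFLM(pretrained="EleutherAI/pythia-160m")
task_dict = tasks.get_task_dict(["arc_easy"], tasks.TaskManager())

results = evaluator.evaluate(
    lm=lm,
    task_dict=task_dict,
    limit=10,
    bootstrap_iters=0,  # skip stderr estimation, per the docstring above
)
```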
@@ -97,7 +97,7 @@ class TaskOutput:
             metric_key = f"{metric},{filter_key}"
             self.agg_metrics[metric_key] = agg_fn(items)
             self.sample_len = len(items)  # TODO: same sample size for each metric?
-            if bootstrap_iters:
+            if isinstance(bootstrap_iters, int):
                 stderr_fn = metrics.stderr_for_metric(
                     metric=agg_fn,
                     bootstrap_iters=min(bootstrap_iters, 100)
@@ -107,6 +107,10 @@ class TaskOutput:
                 self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
                     stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
                 )
+            else:
+                raise ValueError(
+                    f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
+                )
 
     def __repr__(self):
         return (
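Putting the two changes together: an integer `bootstrap_iters` of 0 makes `stderr_for_metric` return `None`, so the stderr cell falls through to "N/A"; a positive integer computes the bootstrap as before; anything that is not an integer now fails loudly. A toy stand-in for that control flow (not the `TaskOutput` code itself):

```python
def resolve_stderr(bootstrap_iters, items):
    """Toy reproduction of the branch added above, with a stand-in stderr value."""
    if isinstance(bootstrap_iters, int):
        # stderr_for_metric returns None for <= 0, a callable otherwise
        stderr_fn = None if bootstrap_iters <= 0 else (lambda xs: 0.01)
        return stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
    raise ValueError(
        f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. "
        "Set to 0 to turn off stderr calculations."
    )


print(resolve_stderr(0, [1, 0, 1]))       # "N/A": stderr skipped entirely
print(resolve_stderr(100000, [1, 0, 1]))  # 0.01: stderr computed (stand-in value)
try:
    resolve_stderr("100", [1, 0, 1])
except ValueError as err:
    print(err)                            # non-integer values now fail fast
```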
@@ -14,31 +14,33 @@ from lm_eval import tasks
 
 
 @pytest.mark.parametrize(
-    "task_name,limit,model,model_args",
+    "task_name,limit,model,model_args,bootstrap_iters",
     [
         (
             ["arc_easy"],
             10,
             "hf",
             "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
+            0,
         ),
         (
             ["mmlu_abstract_algebra"],
             None,
             "hf",
             "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
+            10000,
         ),
     ],
 )
-def test_evaluator(task_name: List[str], limit: int, model: str, model_args: str):
-    # task_name = task_name
-    # limit = 10
+def test_evaluator(
+    task_name: List[str], limit: int, model: str, model_args: str, bootstrap_iters: int
+):
     e1 = evaluator.simple_evaluate(
         model=model,
         tasks=task_name,
         limit=limit,
         model_args=model_args,
+        bootstrap_iters=bootstrap_iters,
     )
     assert e1 is not None
@@ -57,6 +59,7 @@ def test_evaluator(task_name: List[str], limit: int, model: str, model_args: str
         lm=lm,
         task_dict=task_dict,
         limit=limit,
+        bootstrap_iters=bootstrap_iters,
     )
     assert e2 is not None