"vscode:/vscode.git/clone" did not exist on "46c847c4ade869e84d7b5849748c51cff31d0bda"
Unverified commit b043b050, authored by Hailey Schoelkopf and committed by GitHub

Fix for bootstrap_iters = 0 case (#1715) (#1789)

* add handling for bootstrap_iters=0 case

* add more detail to docstring

* run precommit
parent 7d747ea9
...
@@ -429,7 +429,11 @@ def bootstrap_stderr(f, xs, iters):
     return sample_stddev(res)
 
 
-def stderr_for_metric(metric, bootstrap_iters):
+def stderr_for_metric(metric, bootstrap_iters: int):
+    if bootstrap_iters <= 0:
+        # return no function (don't compute stderr) if bootstrap iters = 0
+        return None
+
     bootstrappable = [
         median,
         matthews_corrcoef,
...
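For orientation, the function guarded here estimates a metric's standard error by bootstrap resampling. A minimal sketch of the idea (simplified; not the harness's exact implementation, which also parallelizes the resampling):

import random
from statistics import stdev

def bootstrap_stderr_sketch(f, xs, iters):
    # Resample xs with replacement `iters` times and take the standard
    # deviation of the statistic f across the resamples. Assumes iters >= 2;
    # with the fix above, iters <= 0 never reaches this code path because
    # stderr_for_metric returns None instead of a function.
    res = [f(random.choices(xs, k=len(xs))) for _ in range(iters)]
    return stdev(res)

Callers that receive None from stderr_for_metric simply skip the stderr computation, which is what makes bootstrap_iters=0 a clean "off switch".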
...
@@ -92,7 +92,7 @@ def simple_evaluate(
     :param limit: int or float, optional
         Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
     :param bootstrap_iters:
-        Number of iterations for bootstrap statistics
+        Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0 to skip all stderr calculations.
     :param check_integrity: bool
         Whether to run the relevant part of the test suite for the tasks
     :param write_out: bool
@@ -328,7 +328,7 @@ def evaluate(
     :param limit: int, optional
         Limit the number of examples per task (only use this for testing)
     :param bootstrap_iters:
-        Number of iterations for bootstrap statistics
+        Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 to skip all stderr calculations.
     :param write_out: bool
         If True, write out an example document and model input for checking task integrity
     :param log_samples: bool
...
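With the updated docstrings, disabling stderrs is a single-argument change. A usage sketch, reusing the model and task arguments that appear in the tests below:

from lm_eval import evaluator

# bootstrap_iters=0 now cleanly disables bootstrap stderr computation;
# each "<metric>_stderr" entry in the results is reported as "N/A".
results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
    tasks=["arc_easy"],
    limit=10,
    bootstrap_iters=0,
)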
...
@@ -97,7 +97,7 @@ class TaskOutput:
                 metric_key = f"{metric},{filter_key}"
                 self.agg_metrics[metric_key] = agg_fn(items)
                 self.sample_len = len(items)  # TODO: same sample size for each metric?
-                if bootstrap_iters:
+                if isinstance(bootstrap_iters, int):
                     stderr_fn = metrics.stderr_for_metric(
                         metric=agg_fn,
                         bootstrap_iters=min(bootstrap_iters, 100)
@@ -107,6 +107,10 @@ class TaskOutput:
                     self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
                         stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
                     )
+                else:
+                    raise ValueError(
+                        f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
+                    )
 
     def __repr__(self):
         return (
...
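Taken together with the metrics.py change, the guard maps each bootstrap_iters value to one of three outcomes. A standalone sketch of that dispatch (resolve_stderr_behavior is a hypothetical helper for illustration, not part of the harness):

def resolve_stderr_behavior(bootstrap_iters):
    # Mirrors the combined effect of the hunks above:
    #   positive int        -> bootstrap stderr is computed
    #   0 (or negative)     -> stderr_for_metric returns None, reported as "N/A"
    #   non-int (e.g. None) -> explicit ValueError rather than a silent skip
    if isinstance(bootstrap_iters, int):
        return "computed" if bootstrap_iters > 0 else "N/A"
    raise ValueError(
        f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. "
        "Set to 0 to turn off stderr calculations."
    )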
...
@@ -14,31 +14,33 @@ from lm_eval import tasks
 @pytest.mark.parametrize(
-    "task_name,limit,model,model_args",
+    "task_name,limit,model,model_args,bootstrap_iters",
     [
         (
             ["arc_easy"],
             10,
             "hf",
             "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
+            0,
         ),
         (
             ["mmlu_abstract_algebra"],
             None,
             "hf",
             "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
+            10000,
         ),
     ],
 )
-def test_evaluator(task_name: List[str], limit: int, model: str, model_args: str):
-    # task_name = task_name
-    # limit = 10
+def test_evaluator(
+    task_name: List[str], limit: int, model: str, model_args: str, bootstrap_iters: int
+):
     e1 = evaluator.simple_evaluate(
         model=model,
         tasks=task_name,
         limit=limit,
         model_args=model_args,
+        bootstrap_iters=bootstrap_iters,
     )
     assert e1 is not None
@@ -57,6 +59,7 @@ def test_evaluator(task_name: List[str], limit: int, model: str, model_args: str
         lm=lm,
         task_dict=task_dict,
         limit=limit,
+        bootstrap_iters=bootstrap_iters,
     )
     assert e2 is not None
...