"vscode:/vscode.git/clone" did not exist on "3fd38de11d10cbd6f97c9a77fe86432b1e082e12"
Commit 8335e43a authored by Baber
Browse files

add batch_size to `get_sample_size`

parent fb963f0f
......@@ -436,7 +436,7 @@ def evaluate(
for task_output in eval_tasks:
task: Task = task_output.task
limit = get_sample_size(task, limit)
limit = get_sample_size(task, limit, getattr(lm, "batch_size", None))
task.build_all_requests(
limit=limit,
rank=lm.rank,
......
......@@ -4,6 +4,8 @@ import pathlib
import sys
from typing import List, Optional, Tuple, Union
from pandas.core.dtypes.inference import is_float
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.metrics import (
aggregate_subtask_metrics,
......@@ -200,8 +202,16 @@ def print_writeout(task) -> None:
eval_logger.info(f"Request: {str(inst)}")
def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
def get_sample_size(
task, limit: Optional[int], batch_size: Optional[int]
) -> Union[int, None]:
if limit is not None:
if batch_size is not None and is_float(limit) and limit == 1.0:
eval_logger.warning(
"Limit is 1.0, adjusting the sample size to be a multiple of the batch size"
)
return (len(task.eval_docs) // batch_size) * batch_size
elif limit is not None:
limit = (
int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment