Unverified Commit d693dcd2 authored by Felipe Maia Polo, committed by GitHub

Add `--samples` Argument for Fine-Grained Task Evaluation in `lm-evaluation-harness`

This feature is the first step towards efficient multi-prompt evaluation with PromptEval [1,2] (#2520)
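For illustration, a minimal sketch of the selection format introduced here: `--samples` accepts either a JSON string or a path to a JSON file mapping each task name to the doc indices to evaluate. The task names and indices below are placeholders, and the CLI invocation in the comment is only an assumed typical usage, not part of this change.

```python
# Sketch: building a samples file for --samples; task names/indices are illustrative.
import json

samples = {
    "mmlu_astronomy": [0, 3, 6],     # doc indices within the task's evaluation set
    "mmlu_anatomy": [1, 4, 7, 10],
}
with open("samples.json", "w") as f:
    json.dump(samples, f)

# The file (or the JSON string itself) can then be passed on the command line, e.g.:
#   lm_eval --model hf --model_args pretrained=<model> \
#           --tasks mmlu_astronomy,mmlu_anatomy --samples samples.json
# Note: --samples and --limit are mutually exclusive.
```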

* added option --examples

* specifying examples in dictionary

* run pre-commit - fix arg type

Signed-off-by: Mírian Silva <mirianfrsilva@ibm.com>

* fixing bug when examples==None

* fixing bug when examples==None

* limit or examples must be None in simple_evaluate.py and in evaluator.py

* run pre-commit (fix formatting)

Signed-off-by: Mírian Silva <mirianfrsilva@ibm.com>

* merge main and run pre-commit (fix formatting)

Signed-off-by: Mírian Silva <mirianfrsilva@ibm.com>

* Update __main__.py

undefined "limit" and "examples"

* update branch, fix conflicts, run pre-commit

* nits

* nits

* change 'examples' to 'samples'

---------

Signed-off-by: Mírian Silva <mirianfrsilva@ibm.com>
Co-authored-by: mirianfrsilva <mirianfrsilva@ibm.com>
Co-authored-by: Stella Biderman <stellabiderman@gmail.com>
Co-authored-by: Baber <baber@hey.com>
parent 11ac352d
@@ -4,6 +4,7 @@ import logging
import os
import sys
from functools import partial
from pathlib import Path
from typing import Union
from lm_eval import evaluator, utils
@@ -145,6 +146,14 @@ def setup_parser() -> argparse.ArgumentParser:
help="Limit the number of examples per task. "
"If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument(
"--samples",
"-E",
default=None,
type=str,
metavar="/path/to/json",
help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
)
parser.add_argument(
"--use_cache",
"-c",
@@ -360,6 +369,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
" --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.samples:
assert args.limit is None, (
"If --samples is not None, then --limit must be None."
)
if (samples := Path(args.samples)).is_file():
args.samples = json.loads(samples.read_text())
else:
args.samples = json.loads(args.samples)
if args.tasks is None:
eval_logger.error("Need to specify task to evaluate.")
@@ -419,10 +436,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
args.model_args = args.model_args + ",trust_remote_code=True"
-eval_logger.info(
-f"Selected Tasks: {task_names}"
-) if eval_logger.getEffectiveLevel() >= logging.INFO else print(
-f"Selected Tasks: {task_names}"
(
eval_logger.info(f"Selected Tasks: {task_names}")
if eval_logger.getEffectiveLevel() >= logging.INFO
else print(f"Selected Tasks: {task_names}")
)
request_caching_args = request_caching_arg_to_dict(
@@ -439,6 +456,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
device=args.device,
use_cache=args.use_cache,
limit=args.limit,
samples=args.samples,
check_integrity=args.check_integrity,
write_out=args.write_out,
log_samples=args.log_samples,
@@ -384,6 +384,7 @@ class Task(abc.ABC):
self,
*,
limit: Union[int, None] = None,
samples: Optional[List[int]] = None,
rank: int = 0,
world_size: int = 1,
cache_requests: bool = False,
@@ -436,7 +437,9 @@ class Task(abc.ABC):
limit = None
doc_id_docs = list(
-self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
self.doc_iterator(
rank=rank, limit=limit, samples=samples, world_size=world_size
)
)
num_docs = len(doc_id_docs)
@@ -684,15 +687,35 @@ class Task(abc.ABC):
)
def doc_iterator(
-self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
self,
*,
rank: int = 0,
limit: Union[int, None] = None,
world_size: int = 1,
samples: Optional[List[int]] = None,
) -> Iterator[Tuple[int, Any]]:
-limit = int(limit) if limit else None
-doc_iterator = utils.create_iterator(
-enumerate(self.eval_docs),
-rank=int(rank),
-limit=limit,
-world_size=int(world_size),
-)
if samples:
n = len(self.eval_docs)
assert all([e < n for e in samples]), (
f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
)
eval_logger.info(
f"{self.config.task}: Evaluating on {len(samples)} examples"
)
doc_iterator = utils.create_iterator(
enumerate(x for i, x in enumerate(self.eval_docs) if i in samples),
rank=int(rank),
limit=None, # limit does not matter here since we are selecting samples directly
world_size=int(world_size),
)
else:
limit = int(limit) if limit else None
doc_iterator = utils.create_iterator(
enumerate(self.eval_docs),
rank=int(rank),
limit=limit,
world_size=int(world_size),
)
return doc_iterator
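A minimal standalone sketch (not harness code; toy data assumed) of what the `samples` branch above does: the inner `enumerate` filters the docs by their original indices, while the outer `enumerate` re-numbers the selected docs starting from 0. This is why `evaluate()` below maps the local `doc_id` back through the provided index list (`doc_id_true = indices[doc_id]`), which assumes the indices are supplied in ascending order.

```python
# Illustrative only: mirrors the selection behaviour of doc_iterator(samples=...).
eval_docs = ["doc_a", "doc_b", "doc_c", "doc_d", "doc_e"]  # stand-in for self.eval_docs
samples = [1, 3, 4]                                        # original doc indices

# Inner enumerate keeps only the requested docs; outer enumerate restarts at 0.
selected = list(enumerate(x for i, x in enumerate(eval_docs) if i in samples))
print(selected)  # [(0, 'doc_b'), (1, 'doc_d'), (2, 'doc_e')]

# Local ids 0..len(samples)-1 map back to the original indices via the samples list,
# provided samples is sorted in ascending order (matching iteration order).
for local_id, doc in selected:
    original_id = samples[local_id]  # what evaluate() records as doc_id_true
    print(local_id, original_id, doc)
```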
@@ -26,10 +26,7 @@ from lm_eval.evaluator_utils import (
)
from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
-from lm_eval.tasks import (
-TaskManager,
-get_task_dict,
-)
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
handle_non_serializable,
hash_string,
@@ -60,6 +57,7 @@ def simple_evaluate(
rewrite_requests_cache: bool = False,
delete_requests_cache: bool = False,
limit: Optional[Union[int, float]] = None,
samples: Optional[dict] = None,
bootstrap_iters: int = 100000,
check_integrity: bool = False,
write_out: bool = False,
@@ -106,6 +104,8 @@ def simple_evaluate(
Deletes all the request cache if set to `True`. `None` if not desired.
:param limit: int or float, optional
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
:param samples: dictionary, optional
Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
:param bootstrap_iters:
Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
:param check_integrity: bool
@@ -148,6 +148,11 @@ def simple_evaluate(
setup_logging(verbosity=verbosity)
start_date = time.time()
if limit is not None and samples is not None:
raise ValueError(
"Either 'limit' or 'samples' must be None, but both are not None."
)
if isinstance(model_args, str) and (
"instruct" in model_args and not apply_chat_template
):
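For reference, a hedged sketch of using the new argument through the Python API (the model, `model_args`, and task names below are placeholders, not part of this diff): `samples` takes the place of `limit`, and passing both triggers the `ValueError` above.

```python
# Illustrative sketch; the model and tasks are placeholders chosen for the example.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m",
    tasks=["mmlu_astronomy", "mmlu_anatomy"],
    samples={"mmlu_astronomy": [0, 3, 6], "mmlu_anatomy": [1, 4, 7, 10]},
    limit=None,  # must remain None whenever samples is given
)
```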
@@ -334,6 +339,7 @@ def simple_evaluate(
lm=lm,
task_dict=task_dict,
limit=limit,
samples=samples,
cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters,
@@ -396,6 +402,7 @@ def evaluate(
lm: "LM",
task_dict,
limit: Optional[int] = None,
samples: Optional[dict] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
bootstrap_iters: Optional[int] = 100000,
@@ -415,6 +422,8 @@ def evaluate(
Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
:param limit: int, optional
Limit the number of examples per task (only use this for testing)
:param samples: dictionary, optional
Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
:param cache_requests: bool, optional
Speed up evaluation by caching the building of dataset requests.
:param rewrite_requests_cache: bool, optional
@@ -442,11 +451,16 @@ def evaluate(
Dictionary of results
"""
if limit is not None and samples is not None:
raise ValueError(
"Either 'limit' or 'samples' must be None, but both are not None."
)
if samples is not None:
eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}")
if apply_chat_template:
eval_logger.warning(
"Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
)
# tracks all Instances/requests a model must generate output on.
requests = defaultdict(list)
# stores the amount to pad out reqs per req. type so that
@@ -496,6 +510,9 @@ def evaluate(
limits.append(limit)
task.build_all_requests(
limit=limit,
samples=samples.get(task_output.task_name, None)
if samples is not None
else samples,
rank=lm.rank,
world_size=lm.world_size,
cache_requests=cache_requests,
@@ -579,10 +596,22 @@ def evaluate(
instances.sort(key=lambda x: x.idx)
# iterate over different filters used
for filter_key in task.instances[0].filtered_resps.keys():
indices = (
samples.get(task_output.task_name, None)
if samples is not None
else None
)
doc_iterator = task.doc_iterator(
-rank=RANK, limit=limit, world_size=WORLD_SIZE
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
samples=indices,
)
for doc_id, doc in doc_iterator:
if indices:
doc_id_true = indices[doc_id]
else:
doc_id_true = doc_id
requests = instances_by_doc_id[doc_id]
metrics = task.process_results(
doc, [req.filtered_resps[filter_key] for req in requests]
@@ -590,7 +619,7 @@
if log_samples:
target = task.doc_to_target(doc)
example = {
"doc_id": doc_id,
"doc_id": doc_id_true,
"doc": doc,
"target": target,
"arguments": [req.args for req in requests],