"docs/vscode:/vscode.git/clone" did not exist on "6028613226e6048c3bee1f306678fca61565f20d"
Unverified Commit d693dcd2 authored by Felipe Maia Polo's avatar Felipe Maia Polo Committed by GitHub
Browse files

Add `--samples` Argument for Fine-Grained Task Evaluation in `lm-evaluation-harness`. This feature is the first step towards efficient multi-prompt evaluation with PromptEval [1,2] (#2520)

* added option --examples

* specifying examples in dictionary

* run pre-commit - fix arg type

Signed-off-by: Mírian Silva <mirianfrsilva@ibm.com>

* fixing bug when examples==None

* fixing bug when examples==None

* require that either limit or examples is None in simple_evaluate and in evaluator.py

* run pre-commit (fix formatting)

Signed-off-by: Mírian Silva <mirianfrsilva@ibm.com>

* merge main and run pre-commit (fix formatting)

Signed-off-by: Mírian Silva <mirianfrsilva@ibm.com>

* Update __main__.py

undefined "limit" and "examples"

* update branch, fix conflicts, run pre-commit

* nits

* nits

* change 'examples' to 'samples'

---------

Signed-off-by: Mírian Silva <mirianfrsilva@ibm.com>
Co-authored-by: mirianfrsilva <mirianfrsilva@ibm.com>
Co-authored-by: Stella Biderman <stellabiderman@gmail.com>
Co-authored-by: Baber <baber@hey.com>
parent 11ac352d
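
The diff below threads the new `samples` option from the CLI parser down through `simple_evaluate` and `evaluate`. As a minimal sketch of the intended Python-API usage, based on the docstring added in this commit (the model backend and checkpoint below are illustrative placeholders, not part of the change):

```python
import lm_eval

# Evaluate only selected doc indices per task (mutually exclusive with `limit`).
# Keys are task names; values are lists of doc indices into each task's eval set.
samples = {"mmlu_astronomy": [0, 3, 6], "mmlu_anatomy": [1, 4, 7, 10]}

results = lm_eval.simple_evaluate(
    model="hf",                                      # illustrative backend choice
    model_args="pretrained=EleutherAI/pythia-160m",  # illustrative checkpoint
    tasks=list(samples.keys()),
    samples=samples,   # new argument introduced by this PR
    limit=None,        # must stay None whenever samples is given
)
```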
@@ -4,6 +4,7 @@ import logging
 import os
 import sys
 from functools import partial
+from pathlib import Path
 from typing import Union
 
 from lm_eval import evaluator, utils
@@ -145,6 +146,14 @@ def setup_parser() -> argparse.ArgumentParser:
         help="Limit the number of examples per task. "
         "If <1, limit is a percentage of the total number of examples.",
     )
+    parser.add_argument(
+        "--samples",
+        "-E",
+        default=None,
+        type=str,
+        metavar="/path/to/json",
+        help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
+    )
     parser.add_argument(
         "--use_cache",
         "-c",
@@ -360,6 +369,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             " --limit SHOULD ONLY BE USED FOR TESTING."
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )
+    if args.samples:
+        assert args.limit is None, (
+            "If --samples is not None, then --limit must be None."
+        )
+        if (samples := Path(args.samples)).is_file():
+            args.samples = json.loads(samples.read_text())
+        else:
+            args.samples = json.loads(args.samples)
 
     if args.tasks is None:
         eval_logger.error("Need to specify task to evaluate.")
@@ -419,10 +436,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
         args.model_args = args.model_args + ",trust_remote_code=True"
 
-    eval_logger.info(
-        f"Selected Tasks: {task_names}"
-    ) if eval_logger.getEffectiveLevel() >= logging.INFO else print(
-        f"Selected Tasks: {task_names}"
-    )
+    (
+        eval_logger.info(f"Selected Tasks: {task_names}")
+        if eval_logger.getEffectiveLevel() >= logging.INFO
+        else print(f"Selected Tasks: {task_names}")
+    )
 
     request_caching_args = request_caching_arg_to_dict(
@@ -439,6 +456,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
+       samples=args.samples,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
......
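
As a usage note on the `--samples` flag added above: it accepts either an inline JSON string or a path to a JSON file, mirroring the `Path(...).is_file()` branch in `cli_evaluate`. A small standalone sketch (the file name `samples.json` is just an example):

```python
import json
from pathlib import Path

# {"task_name": [doc indices], ...} -- the format expected by --samples
samples = {"mmlu_astronomy": [0, 3, 6], "mmlu_anatomy": [1, 4, 7, 10]}
Path("samples.json").write_text(json.dumps(samples))

# Mirrors the parsing above: an existing file is read and parsed,
# anything else is treated as an inline JSON string.
for arg in ["samples.json", '{"mmlu_anatomy": [1, 4]}']:
    parsed = (
        json.loads(Path(arg).read_text()) if Path(arg).is_file() else json.loads(arg)
    )
    print(parsed)
```

On the command line, either form (the path or the raw JSON string) is passed directly to `--samples`.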
@@ -384,6 +384,7 @@ class Task(abc.ABC):
         self,
         *,
         limit: Union[int, None] = None,
+        samples: Optional[List[int]] = None,
         rank: int = 0,
         world_size: int = 1,
         cache_requests: bool = False,
@@ -436,7 +437,9 @@
             limit = None
 
         doc_id_docs = list(
-            self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
+            self.doc_iterator(
+                rank=rank, limit=limit, samples=samples, world_size=world_size
+            )
         )
 
         num_docs = len(doc_id_docs)
@@ -684,15 +687,35 @@
         )
 
     def doc_iterator(
-        self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
+        self,
+        *,
+        rank: int = 0,
+        limit: Union[int, None] = None,
+        world_size: int = 1,
+        samples: Optional[List[int]] = None,
     ) -> Iterator[Tuple[int, Any]]:
-        limit = int(limit) if limit else None
-        doc_iterator = utils.create_iterator(
-            enumerate(self.eval_docs),
-            rank=int(rank),
-            limit=limit,
-            world_size=int(world_size),
-        )
+        if samples:
+            n = len(self.eval_docs)
+            assert all([e < n for e in samples]), (
+                f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
+            )
+            eval_logger.info(
+                f"{self.config.task}: Evaluating on {len(samples)} examples"
+            )
+            doc_iterator = utils.create_iterator(
+                enumerate(x for i, x in enumerate(self.eval_docs) if i in samples),
+                rank=int(rank),
+                limit=None,  # limit does not matter here since we are selecting samples directly
+                world_size=int(world_size),
+            )
+        else:
+            limit = int(limit) if limit else None
+            doc_iterator = utils.create_iterator(
+                enumerate(self.eval_docs),
+                rank=int(rank),
+                limit=limit,
+                world_size=int(world_size),
+            )
         return doc_iterator
......
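
One subtlety in the `doc_iterator` change above: when `samples` is given, the selected docs are re-enumerated from 0, so the ids yielded by the iterator are positions within the selection, not the original doc indices. A standalone illustration of that pattern (toy data, not harness code):

```python
# Toy stand-in for self.eval_docs
docs = ["doc_a", "doc_b", "doc_c", "doc_d"]
samples = [1, 3]

# Same selection pattern as doc_iterator: keep docs whose original index is in
# samples, then enumerate the survivors from 0.
selected = list(enumerate(x for i, x in enumerate(docs) if i in samples))
print(selected)  # [(0, 'doc_b'), (1, 'doc_d')]

# Recover the original indices the way evaluator.py does with doc_id_true
# (this mapping assumes the sample indices are listed in ascending order).
print([samples[new_id] for new_id, _ in selected])  # [1, 3]
```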
@@ -26,10 +26,7 @@ from lm_eval.evaluator_utils import (
 )
 from lm_eval.loggers import EvaluationTracker
 from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
-from lm_eval.tasks import (
-    TaskManager,
-    get_task_dict,
-)
+from lm_eval.tasks import TaskManager, get_task_dict
 from lm_eval.utils import (
     handle_non_serializable,
     hash_string,
@@ -60,6 +57,7 @@ def simple_evaluate(
     rewrite_requests_cache: bool = False,
     delete_requests_cache: bool = False,
     limit: Optional[Union[int, float]] = None,
+    samples: Optional[dict] = None,
     bootstrap_iters: int = 100000,
     check_integrity: bool = False,
     write_out: bool = False,
@@ -106,6 +104,8 @@ def simple_evaluate(
         Deletes all the request cache if set to `True`. `None` if not desired.
     :param limit: int or float, optional
         Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
+    :param samples: dictionary, optional
+        Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
     :param bootstrap_iters:
         Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
     :param check_integrity: bool
@@ -148,6 +148,11 @@ def simple_evaluate(
         setup_logging(verbosity=verbosity)
     start_date = time.time()
 
+    if limit is not None and samples is not None:
+        raise ValueError(
+            "Either 'limit' or 'samples' must be None, but both are not None."
+        )
+
     if isinstance(model_args, str) and (
         "instruct" in model_args and not apply_chat_template
     ):
@@ -334,6 +339,7 @@ def simple_evaluate(
         lm=lm,
         task_dict=task_dict,
         limit=limit,
+        samples=samples,
        cache_requests=cache_requests,
        rewrite_requests_cache=rewrite_requests_cache,
        bootstrap_iters=bootstrap_iters,
@@ -396,6 +402,7 @@ def evaluate(
     lm: "LM",
     task_dict,
     limit: Optional[int] = None,
+    samples: Optional[dict] = None,
     cache_requests: bool = False,
     rewrite_requests_cache: bool = False,
     bootstrap_iters: Optional[int] = 100000,
@@ -415,6 +422,8 @@ def evaluate(
         Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
     :param limit: int, optional
         Limit the number of examples per task (only use this for testing)
+    :param samples: dictionary, optional
+        Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
     :param cache_requests: bool, optional
         Speed up evaluation by caching the building of dataset requests.
     :param rewrite_requests_cache: bool, optional
@@ -442,11 +451,16 @@ def evaluate(
         Dictionary of results
     """
 
+    if limit is not None and samples is not None:
+        raise ValueError(
+            "Either 'limit' or 'samples' must be None, but both are not None."
+        )
+    if samples is not None:
+        eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}")
     if apply_chat_template:
         eval_logger.warning(
             "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
         )
 
     # tracks all Instances/requests a model must generate output on.
     requests = defaultdict(list)
     # stores the amount to pad out reqs per req. type so that
@@ -496,6 +510,9 @@ def evaluate(
         limits.append(limit)
         task.build_all_requests(
             limit=limit,
+            samples=samples.get(task_output.task_name, None)
+            if samples is not None
+            else samples,
             rank=lm.rank,
             world_size=lm.world_size,
             cache_requests=cache_requests,
@@ -579,10 +596,22 @@ def evaluate(
             instances.sort(key=lambda x: x.idx)
         # iterate over different filters used
         for filter_key in task.instances[0].filtered_resps.keys():
+            indices = (
+                samples.get(task_output.task_name, None)
+                if samples is not None
+                else None
+            )
             doc_iterator = task.doc_iterator(
-                rank=RANK, limit=limit, world_size=WORLD_SIZE
+                rank=RANK,
+                limit=limit,
+                world_size=WORLD_SIZE,
+                samples=indices,
             )
             for doc_id, doc in doc_iterator:
+                if indices:
+                    doc_id_true = indices[doc_id]
+                else:
+                    doc_id_true = doc_id
                 requests = instances_by_doc_id[doc_id]
                 metrics = task.process_results(
                     doc, [req.filtered_resps[filter_key] for req in requests]
@@ -590,7 +619,7 @@ def evaluate(
                 if log_samples:
                     target = task.doc_to_target(doc)
                     example = {
-                        "doc_id": doc_id,
+                        "doc_id": doc_id_true,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
......
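
Finally, a note on the per-task lookup used in `evaluate` above: `samples` is keyed by task name, and tasks missing from the dict fall back to `None`, i.e. all of their examples are evaluated (recall that `limit` must be `None` whenever `samples` is supplied). A tiny standalone sketch of that expression:

```python
samples = {"mmlu_astronomy": [0, 3, 6]}

for task_name in ["mmlu_astronomy", "mmlu_anatomy"]:
    # Same expression as in evaluate(): per-task indices, or None for full evaluation
    indices = samples.get(task_name, None) if samples is not None else None
    print(task_name, "->", indices)
# mmlu_astronomy -> [0, 3, 6]
# mmlu_anatomy -> None
```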