Commit b5d16d61 authored by artemorloff

enable evaluation from yaml config file

parent d693dcd2
model: vllm
model_args:
  pretrained: Qwen/Qwen2.5-0.5B-Instruct
  dtype: bfloat16
  tensor_parallel_size: 1
tasks: hellaswag,gsm8k
batch_size: 1
trust_remote_code: true
log_samples: true
output_path: ./test
apply_chat_template: true
fewshot_as_multiturn: true
limit: 5
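Each key in this sample YAML mirrors one of the CLI flags defined below; running `lm-eval --config path/to/config.yaml` loads the file as the baseline configuration, and anything passed explicitly on the command line still takes precedence (see `non_default_update` in lm_eval/utils.py further down). A minimal sketch of reading such a file directly, assuming it is saved under the hypothetical name eval_config.yaml:

from lm_eval.utils import load_yaml_config

# Parse the YAML into a plain dict whose keys match the argparse destinations below.
cfg = load_yaml_config(yaml_path="eval_config.yaml")  # hypothetical file name
print(cfg["model"])       # vllm
print(cfg["model_args"])  # {'pretrained': 'Qwen/Qwen2.5-0.5B-Instruct', 'dtype': 'bfloat16', ...}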
lm_eval/__main__.py
@@ -12,9 +12,13 @@ from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import (
TrackExplicitAction,
TrackExplicitStoreTrue,
handle_non_serializable,
load_yaml_config,
make_table,
simple_parse_args_string,
non_default_update,
parse_namespace,
)
@@ -83,13 +87,28 @@ def check_argument_types(parser: argparse.ArgumentParser):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument(
"--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
"--config",
"-C",
default=None,
type=str,
metavar="DIR/file.yaml",
action=TrackExplicitAction,
help="Path to config with all arguments for `lm-eval`",
)
parser.add_argument(
"--model",
"-m",
type=str,
default="hf",
action=TrackExplicitAction,
help="Name of model e.g. `hf`",
)
parser.add_argument(
"--tasks",
"-t",
default=None,
type=str,
action=TrackExplicitAction,
metavar="task1,task2",
help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
)
@@ -97,6 +116,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--model_args",
"-a",
default="",
action=TrackExplicitAction,
type=try_parse_json,
help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
)
@@ -105,6 +125,7 @@ def setup_parser() -> argparse.ArgumentParser:
"-f",
type=int,
default=None,
action=TrackExplicitAction,
metavar="N",
help="Number of examples in few-shot context",
)
@@ -112,6 +133,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--batch_size",
"-b",
type=str,
action=TrackExplicitAction,
default=1,
metavar="auto|auto:N|N",
help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
@@ -120,6 +142,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--max_batch_size",
type=int,
default=None,
action=TrackExplicitAction,
metavar="N",
help="Maximal batch size to try with --batch_size auto.",
)
@@ -127,6 +150,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--device",
type=str,
default=None,
action=TrackExplicitAction,
help="Device to use (e.g. cuda, cuda:0, cpu).",
)
parser.add_argument(
@@ -134,6 +158,7 @@ def setup_parser() -> argparse.ArgumentParser:
"-o",
default=None,
type=str,
action=TrackExplicitAction,
metavar="DIR|DIR/file.json",
help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
)
@@ -142,6 +167,7 @@ def setup_parser() -> argparse.ArgumentParser:
"-L",
type=float,
default=None,
action=TrackExplicitAction,
metavar="N|0<N<1",
help="Limit the number of examples per task. "
"If <1, limit is a percentage of the total number of examples.",
@@ -151,6 +177,7 @@ def setup_parser() -> argparse.ArgumentParser:
"-E",
default=None,
type=str,
action=TrackExplicitAction,
metavar="/path/to/json",
help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
)
@@ -158,6 +185,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--use_cache",
"-c",
type=str,
action=TrackExplicitAction,
default=None,
metavar="DIR",
help="A path to a sqlite db file for caching model responses. `None` if not caching.",
@@ -166,25 +194,26 @@ def setup_parser() -> argparse.ArgumentParser:
"--cache_requests",
type=str,
default=None,
action=TrackExplicitAction,
choices=["true", "refresh", "delete"],
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
)
parser.add_argument(
"--check_integrity",
action="store_true",
action=TrackExplicitStoreTrue,
help="Whether to run the relevant part of the test suite for the tasks.",
)
parser.add_argument(
"--write_out",
"-w",
action="store_true",
action=TrackExplicitStoreTrue,
default=False,
help="Prints the prompt for the first few documents.",
)
parser.add_argument(
"--log_samples",
"-s",
action="store_true",
action=TrackExplicitStoreTrue,
default=False,
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
)
@@ -192,12 +221,14 @@ def setup_parser() -> argparse.ArgumentParser:
"--system_instruction",
type=str,
default=None,
action=TrackExplicitAction,
help="System instruction to be used in the prompt",
)
parser.add_argument(
"--apply_chat_template",
type=str,
nargs="?",
action=TrackExplicitAction,
const=True,
default=False,
help=(
@@ -209,13 +240,13 @@ def setup_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--fewshot_as_multiturn",
action="store_true",
action=TrackExplicitStoreTrue,
default=False,
help="If True, uses the fewshot as a multi-turn conversation",
)
parser.add_argument(
"--show_config",
action="store_true",
action=TrackExplicitStoreTrue,
default=False,
help="If True, shows the the full config of all tasks at the end of the evaluation.",
)
@@ -223,6 +254,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--include_path",
type=str,
default=None,
action=TrackExplicitAction,
metavar="DIR",
help="Additional path to include if there are external tasks to include.",
)
@@ -230,6 +262,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--gen_kwargs",
type=try_parse_json,
default=None,
action=TrackExplicitAction,
help=(
"Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
""" e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
@@ -240,6 +273,7 @@ def setup_parser() -> argparse.ArgumentParser:
"-v",
type=str.upper,
default=None,
action=TrackExplicitAction,
metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
)
@@ -247,24 +281,27 @@ def setup_parser() -> argparse.ArgumentParser:
"--wandb_args",
type=str,
default="",
action=TrackExplicitAction,
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
)
parser.add_argument(
"--wandb_config_args",
type=str,
default="",
action=TrackExplicitAction,
help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
)
parser.add_argument(
"--hf_hub_log_args",
type=str,
default="",
action=TrackExplicitAction,
help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
)
parser.add_argument(
"--predict_only",
"-x",
action="store_true",
action=TrackExplicitStoreTrue,
default=False,
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
)
@@ -272,6 +309,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--seed",
type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
action=TrackExplicitAction,
default=default_seed_string, # for backward compatibility
help=(
"Set seed for python's random, numpy, torch, and fewshot sampling.\n"
@@ -286,18 +324,19 @@ def setup_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
action=TrackExplicitStoreTrue,
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)
parser.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
action=TrackExplicitStoreTrue,
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
)
parser.add_argument(
"--metadata",
type=json.loads,
default=None,
action=TrackExplicitAction,
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
)
return parser
@@ -314,96 +353,96 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
parser = setup_parser()
args = parse_eval_args(parser)
# get namespace from console, including config arg
config, non_default_args = parse_namespace(args)
# if config is passed, load it
if config.get("config", False):
local_config = load_yaml_config(yaml_path=config["config"])
config = non_default_update(config, local_config, non_default_args)
if args.wandb_args:
wandb_args_dict = simple_parse_args_string(args.wandb_args)
wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
wandb_logger = WandbLogger(config["wandb_args"], config["wandb_config_args"])
utils.setup_logging(args.verbosity)
utils.setup_logging(config["verbosity"])
# utils.setup_logging(args.verbosity)
eval_logger = logging.getLogger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# update the evaluation tracker args with the output path and the HF token
if args.output_path:
args.hf_hub_log_args += f",output_path={args.output_path}"
if config["output_path"]:
config.setdefault("hf_hub_log_args", {})["output_path"] = config["output_path"]
if os.environ.get("HF_TOKEN", None):
args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
config.setdefault("hf_hub_log_args", {})["token"] = os.environ.get("HF_TOKEN")
evaluation_tracker_args = config["hf_hub_log_args"]
evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
if args.predict_only:
args.log_samples = True
if (args.log_samples or args.predict_only) and not args.output_path:
if config["predict_only"]:
config["log_samples"] = True
if (config["log_samples"] or config["predict_only"]) and not config["output_path"]:
raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
if args.fewshot_as_multiturn and args.apply_chat_template is False:
if config["fewshot_as_multiturn"] and config["apply_chat_template"] is False:
raise ValueError(
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
)
if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}")
metadata = (
simple_parse_args_string(args.model_args)
if isinstance(args.model_args, str)
else args.model_args
if isinstance(args.model_args, dict)
else {}
) | (
args.metadata
if isinstance(args.metadata, dict)
else simple_parse_args_string(args.metadata)
)
if config["include_path"] is not None:
eval_logger.info(f"Including path: {config['include_path']}")
metadata = (config["model_args"]) | (config["metadata"])
task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
task_manager = TaskManager(include_path=config["include_path"], metadata=metadata)
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
if "push_samples_to_hub" in evaluation_tracker_args and not config["log_samples"]:
eval_logger.warning(
"Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
)
if args.limit:
if config["limit"]:
eval_logger.warning(
" --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.samples:
assert args.limit is None, (
if config["samples"]:
assert config["limit"] is None, (
"If --samples is not None, then --limit must be None."
)
if (samples := Path(args.samples)).is_file():
args.samples = json.loads(samples.read_text())
if (samples := Path(config["samples"])).is_file():
config["samples"] = json.loads(samples.read_text())
else:
args.samples = json.loads(args.samples)
config["samples"] = json.loads(config["samples"])
if args.tasks is None:
if config["tasks"] is None:
eval_logger.error("Need to specify task to evaluate.")
sys.exit()
elif args.tasks == "list":
elif config["tasks"] == "list":
print(task_manager.list_all_tasks())
sys.exit()
elif args.tasks == "list_groups":
elif config["tasks"] == "list_groups":
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
sys.exit()
elif args.tasks == "list_tags":
elif config["tasks"] == "list_tags":
print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
sys.exit()
elif args.tasks == "list_subtasks":
elif config["tasks"] == "list_subtasks":
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
sys.exit()
else:
if os.path.isdir(args.tasks):
if os.path.isdir(config["tasks"]):
import glob
task_names = []
yaml_path = os.path.join(args.tasks, "*.yaml")
yaml_path = os.path.join(config["tasks"], "*.yaml")
for yaml_file in glob.glob(yaml_path):
task_config = utils.load_yaml_config(yaml_file)  # separate name; do not shadow the merged CLI/YAML config dict
task_names.append(task_config)
else:
task_list = args.tasks.split(",")
task_list = config["tasks"].split(",")
task_names = task_manager.match_tasks(task_list)
for task in [task for task in task_list if task not in task_names]:
if os.path.isfile(task):
@@ -424,7 +463,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
)
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
if args.trust_remote_code:
if config["trust_remote_code"]:
eval_logger.info(
"Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
)
@@ -435,7 +474,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
args.model_args = args.model_args + ",trust_remote_code=True"
config.setdefault("model_args", {})["trust_remote_code"] = True
(
eval_logger.info(f"Selected Tasks: {task_names}")
if eval_logger.getEffectiveLevel() >= logging.INFO
@@ -443,56 +482,56 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
)
request_caching_args = request_caching_arg_to_dict(
cache_requests=args.cache_requests
cache_requests=config["cache_requests"]
)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
model=config["model"],
model_args=config["model_args"],
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
max_batch_size=args.max_batch_size,
device=args.device,
use_cache=args.use_cache,
limit=args.limit,
samples=args.samples,
check_integrity=args.check_integrity,
write_out=args.write_out,
log_samples=args.log_samples,
num_fewshot=config["num_fewshot"],
batch_size=config["batch_size"],
max_batch_size=config["max_batch_size"],
device=config["device"],
use_cache=config["use_cache"],
limit=config["limit"],
samples=config["samples"],
check_integrity=config["check_integrity"],
write_out=config["write_out"],
log_samples=config["log_samples"],
evaluation_tracker=evaluation_tracker,
system_instruction=args.system_instruction,
apply_chat_template=args.apply_chat_template,
fewshot_as_multiturn=args.fewshot_as_multiturn,
gen_kwargs=args.gen_kwargs,
system_instruction=config["system_instruction"],
apply_chat_template=config["apply_chat_template"],
fewshot_as_multiturn=config["fewshot_as_multiturn"],
gen_kwargs=config["gen_kwargs"],
task_manager=task_manager,
predict_only=args.predict_only,
random_seed=args.seed[0],
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
predict_only=config["predict_only"],
random_seed=config["seed"][0],
numpy_random_seed=config["seed"][1],
torch_random_seed=config["seed"][2],
fewshot_random_seed=config["seed"][3],
confirm_run_unsafe_code=config["confirm_run_unsafe_code"],
metadata=metadata,
**request_caching_args,
)
if results is not None:
if args.log_samples:
if config["log_samples"]:
samples = results.pop("samples")
dumped = json.dumps(
results, indent=2, default=handle_non_serializable, ensure_ascii=False
)
if args.show_config:
if config["show_config"]:
print(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
# Add W&B logging
if args.wandb_args:
if config["wandb_args"]:
try:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
if args.log_samples:
if config["log_samples"]:
wandb_logger.log_eval_samples(samples)
except Exception as e:
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
@@ -501,8 +540,8 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
results=results, samples=samples if args.log_samples else None
)
if args.log_samples:
for task_name, config in results["configs"].items():
if config["log_samples"]:
for task_name, _ in results["configs"].items():
evaluation_tracker.save_results_samples(
task_name=task_name, samples=samples[task_name]
)
@@ -514,14 +553,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
evaluation_tracker.recreate_metadata_card()
print(
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
f"{config['model']} ({config['model_args']}), gen_kwargs: ({config['gen_kwargs']}), limit: {config['limit']}, num_fewshot: {config['num_fewshot']}, "
f"batch_size: {config['batch_size']}{f' ({batch_sizes})' if batch_sizes else ''}"
)
print(make_table(results))
if "groups" in results:
print(make_table(results, "groups"))
if args.wandb_args:
if config["wandb_args"]:
# Tear down wandb run once all the logging is done.
wandb_logger.run.finish()
lm_eval/utils.py
import argparse
import collections
import fnmatch
import functools
@@ -11,7 +12,7 @@ import re
from dataclasses import asdict, is_dataclass
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
import numpy as np
import yaml
@@ -550,3 +551,96 @@ def weighted_f1_score(items):
preds = unzipped_list[1]
fscore = f1_score(golds, preds, average="weighted")
return fscore
def parse_namespace(namespace: argparse.Namespace) -> Tuple[Dict[str, Any], List[str]]:
"""
Convert an argparse Namespace object to a plain dictionary.
Comma-separated string arguments (wandb_args, wandb_config_args, hf_hub_log_args,
metadata, model_args) are parsed into dicts, and model_args/metadata are guaranteed
to be present.
Args:
namespace: The argparse.Namespace object to convert
Returns:
Tuple of the config dict containing all namespace arguments and their values, and
the list of argument names that were explicitly provided on the command line
"""
config = {key: value for key, value in vars(namespace).items()}
for key in config:
# TODO: pass this list as param
if key in [
"wandb_args",
"wandb_config_args",
"hf_hub_log_args",
"metadata",
"model_args",
]:
if not isinstance(config[key], dict):
config[key] = simple_parse_args_string(config[key])
if "model_args" not in config:
config["model_args"] = {}
if "metadata" not in config:
config["metadata"] = {}
non_default_args = []
if hasattr(namespace, "_explicit_args"):
non_default_args = namespace._explicit_args
return config, non_default_args
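A minimal usage sketch for parse_namespace, with hypothetical values: comma-separated string arguments are normalized into dicts, and the names of explicitly passed flags come back alongside the config dict.

example_ns = argparse.Namespace(
    model="hf",
    model_args="pretrained=gpt2,dtype=float32",
    wandb_args="",
    wandb_config_args="",
    hf_hub_log_args="",
    metadata={},
)
example_config, example_explicit = parse_namespace(example_ns)
assert example_config["model_args"] == {"pretrained": "gpt2", "dtype": "float32"}
assert example_explicit == []  # nothing carried the _explicit_args marker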
class TrackExplicitAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
# Create a set on the namespace to track explicitly set arguments if it doesn't exist
if not hasattr(namespace, "_explicit_args"):
setattr(namespace, "_explicit_args", set())
# Record that this argument was explicitly provided
namespace._explicit_args.add(self.dest)
setattr(namespace, self.dest, values)
class TrackExplicitStoreTrue(argparse.Action):
def __init__(
self, option_strings, dest, nargs=0, const=True, default=False, **kwargs
):
# Ensure that nargs is 0, as required by store_true actions.
if nargs != 0:
raise ValueError("nargs for store_true actions must be 0")
super().__init__(
option_strings, dest, nargs=0, const=const, default=default, **kwargs
)
def __call__(self, parser, namespace, values, option_string=None):
# Initialize or update the set of explicitly provided arguments.
if not hasattr(namespace, "_explicit_args"):
setattr(namespace, "_explicit_args", set())
namespace._explicit_args.add(self.dest)
setattr(namespace, self.dest, self.const)
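A minimal sketch of how the tracking actions behave, using a hypothetical parser: only destinations the user actually passes end up in namespace._explicit_args, so defaults left untouched can later be overridden by YAML values.

example_parser = argparse.ArgumentParser()
example_parser.add_argument("--model", default="hf", action=TrackExplicitAction)
example_parser.add_argument("--log_samples", action=TrackExplicitStoreTrue)
example_args = example_parser.parse_args(["--model", "vllm"])
assert example_args._explicit_args == {"model"}
assert example_args.model == "vllm" and example_args.log_samples is False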
def non_default_update(console_dict, local_dict, non_default_args):
"""
Merge console_dict and local_dict into a new dict, giving precedence to console
values the user passed explicitly and to local (YAML) values otherwise.
Args:
console_dict (dict): The dictionary built from the command-line namespace.
local_dict (dict): The dictionary loaded from the YAML config file.
non_default_args (list): The names of args explicitly passed by the user on the console.
Returns:
dict: The merged configuration.
"""
result_config = {}
for key in set(console_dict.keys()).union(local_dict.keys()):
if key in non_default_args:
result_config[key] = console_dict[key]
else:
if key in local_dict:
result_config[key] = local_dict[key]
else:
result_config[key] = console_dict[key]
return result_config
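A minimal sketch of the merge precedence with hypothetical values: explicitly passed console values win, YAML values override untouched console defaults, and keys present in only one of the two dicts are carried over unchanged.

console = {"model": "hf", "batch_size": 8, "limit": None}
local = {"model": "vllm", "limit": 5, "tasks": "gsm8k"}
merged = non_default_update(console, local, non_default_args=["batch_size"])
assert merged == {"model": "vllm", "batch_size": 8, "limit": 5, "tasks": "gsm8k"}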