import argparse
import json
import logging
import os
import sys
from functools import partial
from pathlib import Path
from typing import Union


def try_parse_json(value: str) -> Union[str, dict, None]:
    if value is None:
        return None
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        if "{" in value:
            raise argparse.ArgumentTypeError(
                f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
            )
        return value


def _int_or_none_list_arg_type(
    min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
):
    def parse_value(item):
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{item} is not an integer or None")

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # Makes downstream handling the same for single and multiple values
        items = items * max_len
    elif num_items < min_len or num_items > max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )
    elif num_items != max_len:
        logging.warning(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
            "Missing values will be filled with defaults."
        )

        default_items = [parse_value(v) for v in defaults.split(split_char)]
        items.extend(
            default_items[num_items:]
        )  # extend items list with missing defaults

    return items


def check_argument_types(parser: argparse.ArgumentParser):
    """
    Check to make sure all CLI args are typed, raises error if not
    """
    for action in parser._actions:
        if action.dest != "help" and not action.const:
            if action.type is None:
                raise ValueError(
                    f"Argument '{action.dest}' doesn't have a type specified."
                )
            else:
                continue


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "--config",
        "-C",
        default=None,
        type=str,
        metavar="DIR/file.yaml",
        action=TrackExplicitAction,
        help="Path to config with all arguments for `lm-eval`",
    )
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="hf",
        action=TrackExplicitAction,
        help="Name of model e.g. `hf`",
    )
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        type=str,
        action=TrackExplicitAction,
        metavar="task1,task2",
        help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        action=TrackExplicitAction,
        type=try_parse_json,
        help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        action=TrackExplicitAction,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        action=TrackExplicitAction,
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        action=TrackExplicitAction,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        action=TrackExplicitAction,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        action=TrackExplicitAction,
        metavar="DIR|DIR/file.json",
        help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        action=TrackExplicitAction,
        metavar="N|0<N<1",
    )
    # ... remaining argument definitions (caching, logging, seeds, etc.) are truncated in this file ...
    return parser


def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
    check_argument_types(parser)
    return parser.parse_args()


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    if not args:
        # we allow for args to be passed externally, else we parse them ourselves
        parser = setup_parser()
        args = parse_eval_args(parser)

    cfg = EvaluationConfig.from_cli(args)

    # defer loading `lm_eval` submodules for faster CLI load
    from lm_eval import evaluator, utils
    from lm_eval.evaluator import request_caching_arg_to_dict
    from lm_eval.loggers import EvaluationTracker, WandbLogger
    from lm_eval.tasks import TaskManager
    from lm_eval.utils import (
        handle_non_serializable,
        make_table,
        simple_parse_args_string,
    )

    if cfg.wandb_args:
        wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args)

    utils.setup_logging(cfg.verbosity)
    eval_logger = logging.getLogger(__name__)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # update the evaluation tracker args with the output path and the HF token
    if cfg.output_path:
        cfg.hf_hub_log_args["output_path"] = cfg.output_path
    if os.environ.get("HF_TOKEN", None):
        cfg.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")

    evaluation_tracker_args = cfg.hf_hub_log_args
    evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)

    if cfg.predict_only:
        cfg.log_samples = True
    if (cfg.log_samples or cfg.predict_only) and not cfg.output_path:
        raise ValueError(
            "Specify --output_path if providing --log_samples or --predict_only"
        )

    if cfg.fewshot_as_multiturn and cfg.apply_chat_template is False:
        raise ValueError(
            "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
        )

    if cfg.include_path is not None:
        eval_logger.info(f"Including path: {cfg.include_path}")

    metadata = cfg.model_args | cfg.metadata
    cfg.metadata = metadata

    # task_manager = TaskManager(include_path=config["include_path"], metadata=metadata)
    task_manager = TaskManager(include_path=cfg.include_path, metadata=metadata)

    if "push_samples_to_hub" in evaluation_tracker_args and not cfg.log_samples:
        eval_logger.warning(
            "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
        )

    if cfg.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if cfg.samples:
        assert cfg.limit is None, "If --samples is not None, then --limit must be None."
        if (samples := Path(cfg.samples)).is_file():
            cfg.samples = json.loads(samples.read_text())
        else:
            cfg.samples = json.loads(cfg.samples)

    if cfg.tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        sys.exit()
    elif cfg.tasks == "list":
        print(task_manager.list_all_tasks())
        sys.exit()
    elif cfg.tasks == "list_groups":
        print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
        sys.exit()
    elif cfg.tasks == "list_tags":
        print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
        sys.exit()
    elif cfg.tasks == "list_subtasks":
        print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
        sys.exit()
    else:
        if os.path.isdir(cfg.tasks):
            import glob

            task_names = []
            yaml_path = os.path.join(cfg.tasks, "*.yaml")
            for yaml_file in glob.glob(yaml_path):
                config = utils.load_yaml_config(yaml_file)
                task_names.append(config)
        else:
            task_list = cfg.tasks.split(",")
            task_names = task_manager.match_tasks(task_list)
            for task in [task for task in task_list if task not in task_names]:
                if os.path.isfile(task):
                    config = utils.load_yaml_config(task)
                    task_names.append(config)
            task_missing = [
                task for task in task_list if task not in task_names and "*" not in task
            ]  # we don't want errors if a wildcard ("*") task name was used

            if task_missing:
                missing = ", ".join(task_missing)
                eval_logger.error(
                    f"Tasks were not found: {missing}\n"
                    f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
                )
                raise ValueError(
                    f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
                )

    cfg.tasks = task_names

    # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
    if cfg.trust_remote_code:
        eval_logger.info(
            "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
        )
        # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
        # because it's already been determined based on the prior env var before launching our
        # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
        import datasets

        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

        cfg.model_args["trust_remote_code"] = True

    (
        eval_logger.info(f"Selected Tasks: {task_names}")
        if eval_logger.getEffectiveLevel() >= logging.INFO
        else print(f"Selected Tasks: {task_names}")
    )

    request_caching_args = request_caching_arg_to_dict(
        cache_requests=cfg.cache_requests
    )
    cfg.request_caching_args = request_caching_args

    results = evaluator.simple_evaluate(
        model=cfg.model,
        model_args=cfg.model_args,
        tasks=cfg.tasks,
        num_fewshot=cfg.num_fewshot,
        batch_size=cfg.batch_size,
        max_batch_size=cfg.max_batch_size,
        device=cfg.device,
        use_cache=cfg.use_cache,
        cache_requests=cfg.request_caching_args.get("cache_requests", False),
        rewrite_requests_cache=cfg.request_caching_args.get(
            "rewrite_requests_cache", False
        ),
        delete_requests_cache=cfg.request_caching_args.get(
            "delete_requests_cache", False
        ),
        limit=cfg.limit,
        samples=cfg.samples,
        check_integrity=cfg.check_integrity,
        write_out=cfg.write_out,
        log_samples=cfg.log_samples,
        evaluation_tracker=evaluation_tracker,
        system_instruction=cfg.system_instruction,
        apply_chat_template=cfg.apply_chat_template,
        fewshot_as_multiturn=cfg.fewshot_as_multiturn,
        gen_kwargs=cfg.gen_kwargs,
        task_manager=task_manager,
        verbosity=cfg.verbosity,
        predict_only=cfg.predict_only,
        random_seed=cfg.seed[0] if cfg.seed else None,
        numpy_random_seed=cfg.seed[1] if cfg.seed else None,
        torch_random_seed=cfg.seed[2] if cfg.seed else None,
        fewshot_random_seed=cfg.seed[3] if cfg.seed else None,
        confirm_run_unsafe_code=cfg.confirm_run_unsafe_code,
        metadata=cfg.metadata,
    )

    if results is not None:
        if cfg.log_samples:
            samples = results.pop("samples")

        dumped = json.dumps(
            results, indent=2, default=handle_non_serializable, ensure_ascii=False
        )
        if cfg.show_config:
            print(dumped)

        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

        # Add W&B logging
        if cfg.wandb_args:
            try:
                wandb_logger.post_init(results)
                wandb_logger.log_eval_result()
                if cfg.log_samples:
                    wandb_logger.log_eval_samples(samples)
            except Exception as e:
                eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

        evaluation_tracker.save_results_aggregated(
            results=results, samples=samples if cfg.log_samples else None
        )

        if cfg.log_samples:
            for task_name, _ in results["configs"].items():
                evaluation_tracker.save_results_samples(
                    task_name=task_name, samples=samples[task_name]
                )

        if (
            evaluation_tracker.push_results_to_hub
            or evaluation_tracker.push_samples_to_hub
        ):
            evaluation_tracker.recreate_metadata_card()

        print(
            f"{cfg.model} ({cfg.model_args}), gen_kwargs: ({cfg.gen_kwargs}), limit: {cfg.limit}, num_fewshot: {cfg.num_fewshot}, "
            f"batch_size: {cfg.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))

    if cfg.wandb_args:
        # Tear down wandb run once all the logging is done.
        wandb_logger.run.finish()


if __name__ == "__main__":
    cli_evaluate()
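
# Usage sketch (illustrative only): `cli_evaluate` accepts an externally built
# namespace, so an evaluation can also be driven programmatically rather than
# from the shell. The flag values below (model, model_args, tasks, batch_size)
# are placeholder assumptions based on the help strings above, not a
# recommended configuration:
#
#     parser = setup_parser()
#     args = parser.parse_args(
#         [
#             "--model", "hf",
#             "--model_args", "pretrained=EleutherAI/pythia-160m",
#             "--tasks", "task1,task2",
#             "--batch_size", "8",
#         ]
#     )
#     cli_evaluate(args)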