import os import json import argparse from lm_eval import evaluator, utils from lm_eval.api.registry import ALL_TASKS from lm_eval.logger import eval_logger os.environ["TOKENIZERS_PARALLELISM"] = "false" def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--model", required=True) parser.add_argument("--model_args", default="") parser.add_argument("--tasks", default=None, choices=utils.MultiChoice(sorted(ALL_TASKS))) parser.add_argument("--config", default=None) parser.add_argument("--num_fewshot", type=int, default=0) parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--max_batch_size", type=int, default=None, help="Maximal batch size to try with --batch_size auto") parser.add_argument("--device", type=str, default=None) parser.add_argument("--output_path", default=None) parser.add_argument("--limit", type=float, default=None, help="Limit the number of examples per task. " "If <1, limit is a percentage of the total number of examples.") parser.add_argument("--data_sampling", type=float, default=None) parser.add_argument("--no_cache", action="store_true") parser.add_argument("--decontamination_ngrams_path", default=None) parser.add_argument("--description_dict_path", default=None) parser.add_argument("--check_integrity", action="store_true") parser.add_argument("--write_out", action="store_true", default=False) parser.add_argument("--output_base_path", type=str, default=None) return parser.parse_args() def main(): args = parse_args() if args.limit: eval_logger.warning( " --limit SHOULD ONLY BE USED FOR TESTING." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." ) if args.tasks is None: task_names = ALL_TASKS else: if os.path.isdir(args.tasks): import glob task_names = [] yaml_path = os.path.join(args.tasks, "*.yaml") for yaml_file in glob.glob(yaml_path): config = utils.load_yaml_config(yaml_file) task_names.append(config) else: tasks_list = args.tasks.split(",") task_names = utils.pattern_match(tasks_list, ALL_TASKS) for task in [task for task in tasks_list if task not in task_names]: if os.path.isfile(task): config = utils.load_yaml_config(task) task_names.append(config) eval_logger.info(f"Selected Tasks: {task_names}") # TODO: description_dict? # description_dict = {} # if args.description_dict_path: # with open(args.description_dict_path, "r") as f: # description_dict = json.load(f) results = evaluator.simple_evaluate( model=args.model, model_args=args.model_args, tasks=task_names, num_fewshot=args.num_fewshot, batch_size=args.batch_size, max_batch_size=args.max_batch_size, device=args.device, no_cache=args.no_cache, limit=args.limit, # description_dict=description_dict, decontamination_ngrams_path=args.decontamination_ngrams_path, check_integrity=args.check_integrity, write_out=args.write_out, output_base_path=args.output_base_path, ) if results is not None: dumped = json.dumps(results, indent=2) print(dumped) if args.output_path: os.makedirs(os.path.dirname(args.output_path), exist_ok=True) with open(args.output_path, "w") as f: f.write(dumped) batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) print( f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}" ) print(evaluator.make_table(results)) if __name__ == "__main__": main()