import argparse
import fnmatch
import json
import logging
import os

from lm_eval import evaluator, utils
from lm_eval.api.registry import GROUP_REGISTRY, TASK_REGISTRY
from lm_eval.logger import eval_logger

# Avoid tokenizer parallelism warnings when the evaluator forks worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
print("ALL tasks: ", ALL_TASKS)


class MultiChoice:
    def __init__(self, choices):
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values):
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                eval_logger.warning("{} is not in task list.".format(value))
                eval_logger.info("Available tasks to choose:")
                # for choice in self.choices:
                #     eval_logger.info(f"  {choice}")
                eval_logger.info(ALL_TASKS)
        return True

    def __iter__(self):
        for choice in self.choices:
            yield choice


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--model_args", default="")
    parser.add_argument("--tasks", default=None, choices=MultiChoice(ALL_TASKS))
    parser.add_argument("--config", default=None)
    parser.add_argument("--provide_description", action="store_true")
    parser.add_argument("--num_fewshot", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--output_path", default=None)
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--decontamination_ngrams_path", default=None)
    parser.add_argument("--description_dict_path", default=None)
    parser.add_argument("--check_integrity", action="store_true")
    return parser.parse_args()


# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(list(task_names))


def setup_example_logger(output_path):
    """Sets up a logger that will save each example and prediction."""
    example_logger = logging.getLogger("examples")
    if output_path:
        filename = f"./{os.path.dirname(output_path)}/examples.jsonl"
        formatter = logging.Formatter("%(message)s")
        handler = logging.FileHandler(filename)
        handler.setFormatter(formatter)
        example_logger.addHandler(handler)
        example_logger.setLevel(logging.INFO)


def main():
    args = parse_args()

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.output_path:
        os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
        setup_example_logger(args.output_path)

    if args.tasks is not None:
        if os.path.isdir(args.tasks):
            # A directory was passed: load every YAML task config inside it.
            import glob

            task_names = []
            yaml_path = os.path.join(args.tasks, "*.yaml")
            for yaml_file in glob.glob(yaml_path):
                config = utils.load_yaml_config(yaml_file)
                task_names.append(config)
        else:
            # A comma-separated list: resolve registered task names (with
            # wildcard patterns) first, then treat any leftover entries as
            # paths to individual YAML task configs.
            tasks_list = args.tasks.split(",")
            task_names = pattern_match(tasks_list, ALL_TASKS)
            for task in [task for task in tasks_list if task not in task_names]:
                if os.path.isfile(task):
                    config = utils.load_yaml_config(task)
                    task_names.append(config)
    else:
        # No --tasks given: fall back to every registered task so task_names
        # is always defined before calling the evaluator.
        task_names = ALL_TASKS

    eval_logger.info(f"Selected Tasks: {task_names}")

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        limit=args.limit,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
    )

    if results is not None:
        dumped = json.dumps(results, indent=2)
        print(dumped)

        if args.output_path:
            with open(args.output_path, "w") as f:
                f.write(dumped)

        print(
            f"{args.model} ({args.model_args}), limit: {args.limit}, "
            f"provide_description: {args.provide_description}, "
            f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
        )
        print(evaluator.make_table(results))


if __name__ == "__main__":
    main()
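
# Example invocation (illustrative sketch only; the model type, model_args, and
# task names below are assumptions about what is registered in lm_eval, not
# values fixed by this script):
#
#   python main.py --model hf \
#       --model_args pretrained=gpt2 \
#       --tasks lambada_openai,hellaswag \
#       --batch_size 8 \
#       --output_path results/gpt2.json
#
# --tasks also accepts wildcard patterns (e.g. "hellaswag*"), a path to a
# single task YAML file, or a directory of YAML files, as handled in main().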