Commit 8299ab3b authored by lintangsutawika's avatar lintangsutawika
Browse files

process yaml and registered tasks through args.task

parent 41ec3bc4
......@@ -5,7 +5,7 @@ import fnmatch
import yaml
import os
from lm_eval import evaluator, tasks
from lm_eval import evaluator, utils
from lm_eval.tasks import ALL_TASKS
logging.getLogger("openai").setLevel(logging.WARNING)
......@@ -68,22 +68,38 @@ def main():
)
if args.tasks is None:
if args.config:
if os.path.isdir(args.tasks):
import glob
task_names = []
for config_files in args.config.split(","):
config = get_yaml_config(config_files)
if args.num_fewshot != 0:
config["num_fewshot"] = args.num_fewshot
if args.batch_size != None:
config["batch_size"] = args.batch_size
yaml_path = os.path.join(args.tasks, "*.yaml")
for yaml_file in glob.glob(yaml_path):
config = utils.get_yaml_config(yaml_file)
task_names.append(config)
else:
task_names = ALL_TASKS
else:
task_names = pattern_match(args.tasks.split(","), ALL_TASKS)
tasks_list = args.tasks.split(",")
task_names = pattern_match(tasks_list, ALL_TASKS)
for task in [task for task in tasks_list if task not in task_names]:
if os.path.isfile(task):
config = utils.get_yaml_config(config_files)
task_names.append(config)
# # Tas
# if args.config:
# task_names = []
# for config_files in args.config.split(","):
# config = utils.get_yaml_config(config_files)
# if args.num_fewshot != 0:
# config["num_fewshot"] = args.num_fewshot
# if args.batch_size != None:
# config["batch_size"] = args.batch_size
# task_names.append(config)
# else:
# task_names = ALL_TASKS
# else:
print(f"Selected Tasks: {task_names}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment