Commit 93cdc5bf authored by lintangsutawika's avatar lintangsutawika
Browse files

adjust how task objects are called

parent 0d6d1d7c
......@@ -168,7 +168,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
import sys; sys.exit()
elif args.tasks == "list":
eval_logger.info(
"Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS)))
"Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS.keys())))
)
else:
if os.path.isdir(args.tasks):
......@@ -181,7 +181,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
task_names.append(config)
else:
tasks_list = args.tasks.split(",")
task_names = utils.pattern_match(tasks_list, ALL_TASKS)
task_names = utils.pattern_match(tasks_list, ALL_TASKS.keys())
for task in [task for task in tasks_list if task not in task_names]:
if os.path.isfile(task):
config = utils.load_yaml_config(task)
......@@ -225,8 +225,18 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger.info(f"Selected Tasks: {task_names}")
eval_logger.info("Loading selected tasks...")
task_objects = {}
for task in task_names:
task_objects = load_task_or_group(ALL_TASKS[task])
if isinstance(task, str):
task_objects[task] = load_task_or_group(
ALL_TASKS,
task_name=task,
)
elif isinstance(task, dict):
task_objects[task["task"]] = load_task_or_group(
ALL_TASKS,
task_config=task,
)
results = evaluator.simple_evaluate(
model=args.model,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment