Commit fe9f0d46 authored by lintangsutawika's avatar lintangsutawika
Browse files

more comprehensive way to index tasks and groups

parent 711eddcf
import os
import yaml
import collections
from typing import List, Union, Dict
from lm_eval import utils
......@@ -224,22 +225,47 @@ def include_task_folder(task_dir: str, register_task: bool = True, task_name: st
def get_task_and_group(task_dir: str):
task_list = {}
group_list = {}
tasks_and_groups = collections.defaultdict()
for root, _, file_list in os.walk(task_dir):
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
config = utils.simple_load_yaml_config(yaml_path)
if "task" in config and isinstance(config["task"], str):
task_list[config["task"]] = yaml_path
if "group" in config:
if isinstance(config["group"], str):
group_list[config["group"]] = yaml_path
elif isinstance(config["group"], list):
for group in config["group"]:
group_list[group] = yaml_path
return task_list, group_list
if list(config.keys()) == ["group", "task"]:
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": yaml_path,
}
else:
# This is a task config
task = config["task"]
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
if "group" in config:
groups = config["group"]
if isinstance(config["group"], str):
groups = [groups]
for group in groups:
if group not in tasks_and_groups:
tasks_and_groups[group] = {
"type": "group",
"task": [task],
"yaml_path": -1,
}
else:
tasks_and_groups[group]["task"].append(task)
return tasks_and_groups
def initialize_tasks(verbosity="INFO", include_path=None):
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
......@@ -249,14 +275,12 @@ def initialize_tasks(verbosity="INFO", include_path=None):
include_path = [include_path]
all_paths.extend(include_path)
task_list = {}
group_list = {}
ALL_TASKS = {}
for task_dir in all_paths:
tasks, groups = get_task_and_group(task_dir)
task_list = {**tasks, **task_list}
group_list = {**groups, **group_list}
tasks = get_task_and_group(task_dir)
ALL_TASKS = {**tasks, **ALL_TASKS}
return task_list, group_list
return ALL_TASKS
def get_task(task_name, config):
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment