Commit 47756947 authored by lintangsutawika's avatar lintangsutawika
Browse files

more comprehensive way to index tasks and groups

parent 49be8ea2
import os import os
import yaml import yaml
import collections
from typing import List, Union, Dict from typing import List, Union, Dict
from lm_eval import utils from lm_eval import utils
...@@ -224,22 +225,47 @@ def include_task_folder(task_dir: str, register_task: bool = True, task_name: st ...@@ -224,22 +225,47 @@ def include_task_folder(task_dir: str, register_task: bool = True, task_name: st
def get_task_and_group(task_dir: str): def get_task_and_group(task_dir: str):
task_list = {} tasks_and_groups = collections.defaultdict()
group_list = {}
for root, _, file_list in os.walk(task_dir): for root, _, file_list in os.walk(task_dir):
for f in file_list: for f in file_list:
if f.endswith(".yaml"): if f.endswith(".yaml"):
yaml_path = os.path.join(root, f) yaml_path = os.path.join(root, f)
config = utils.simple_load_yaml_config(yaml_path) config = utils.simple_load_yaml_config(yaml_path)
if "task" in config and isinstance(config["task"], str): if list(config.keys()) == ["group", "task"]:
task_list[config["task"]] = yaml_path # This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": yaml_path,
}
else:
# This is a task config
task = config["task"]
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
if "group" in config: if "group" in config:
groups = config["group"]
if isinstance(config["group"], str): if isinstance(config["group"], str):
group_list[config["group"]] = yaml_path groups = [groups]
elif isinstance(config["group"], list):
for group in config["group"]: for group in groups:
group_list[group] = yaml_path if group not in tasks_and_groups:
return task_list, group_list tasks_and_groups[group] = {
"type": "group",
"task": [task],
"yaml_path": -1,
}
else:
tasks_and_groups[group]["task"].append(task)
return tasks_and_groups
def initialize_tasks(verbosity="INFO", include_path=None): def initialize_tasks(verbosity="INFO", include_path=None):
eval_logger.setLevel(getattr(logging, f"{verbosity}")) eval_logger.setLevel(getattr(logging, f"{verbosity}"))
...@@ -249,14 +275,12 @@ def initialize_tasks(verbosity="INFO", include_path=None): ...@@ -249,14 +275,12 @@ def initialize_tasks(verbosity="INFO", include_path=None):
include_path = [include_path] include_path = [include_path]
all_paths.extend(include_path) all_paths.extend(include_path)
task_list = {} ALL_TASKS = {}
group_list = {}
for task_dir in all_paths: for task_dir in all_paths:
tasks, groups = get_task_and_group(task_dir) tasks = get_task_and_group(task_dir)
task_list = {**tasks, **task_list} ALL_TASKS = {**tasks, **ALL_TASKS}
group_list = {**groups, **group_list}
return task_list, group_list return ALL_TASKS
def get_task(task_name, config): def get_task(task_name, config):
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment