Commit 47756947 authored by lintangsutawika's avatar lintangsutawika
Browse files

more comprehensive way to index tasks and groups

parent 49be8ea2
import os import os
import yaml import yaml
import collections
from typing import List, Union, Dict from typing import List, Union, Dict
from lm_eval import utils from lm_eval import utils
...@@ -224,22 +225,47 @@ def include_task_folder(task_dir: str, register_task: bool = True, task_name: st ...@@ -224,22 +225,47 @@ def include_task_folder(task_dir: str, register_task: bool = True, task_name: st
def get_task_and_group(task_dir: str): def get_task_and_group(task_dir: str):
task_list = {} tasks_and_groups = collections.defaultdict()
group_list = {}
for root, _, file_list in os.walk(task_dir): for root, _, file_list in os.walk(task_dir):
for f in file_list: for f in file_list:
if f.endswith(".yaml"): if f.endswith(".yaml"):
yaml_path = os.path.join(root, f) yaml_path = os.path.join(root, f)
config = utils.simple_load_yaml_config(yaml_path) config = utils.simple_load_yaml_config(yaml_path)
if "task" in config and isinstance(config["task"], str): if list(config.keys()) == ["group", "task"]:
task_list[config["task"]] = yaml_path # This is a group config
if "group" in config: tasks_and_groups[config["group"]] = {
if isinstance(config["group"], str): "type": "group",
group_list[config["group"]] = yaml_path "task": -1, # This signals that
elif isinstance(config["group"], list): # we don't need to know
for group in config["group"]: # the task list for indexing
group_list[group] = yaml_path # as it can be loaded
return task_list, group_list # when called.
"yaml_path": yaml_path,
}
else:
# This is a task config
task = config["task"]
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
if "group" in config:
groups = config["group"]
if isinstance(config["group"], str):
groups = [groups]
for group in groups:
if group not in tasks_and_groups:
tasks_and_groups[group] = {
"type": "group",
"task": [task],
"yaml_path": -1,
}
else:
tasks_and_groups[group]["task"].append(task)
return tasks_and_groups
def initialize_tasks(verbosity="INFO", include_path=None): def initialize_tasks(verbosity="INFO", include_path=None):
eval_logger.setLevel(getattr(logging, f"{verbosity}")) eval_logger.setLevel(getattr(logging, f"{verbosity}"))
...@@ -249,14 +275,12 @@ def initialize_tasks(verbosity="INFO", include_path=None): ...@@ -249,14 +275,12 @@ def initialize_tasks(verbosity="INFO", include_path=None):
include_path = [include_path] include_path = [include_path]
all_paths.extend(include_path) all_paths.extend(include_path)
task_list = {} ALL_TASKS = {}
group_list = {}
for task_dir in all_paths: for task_dir in all_paths:
tasks, groups = get_task_and_group(task_dir) tasks = get_task_and_group(task_dir)
task_list = {**tasks, **task_list} ALL_TASKS = {**tasks, **ALL_TASKS}
group_list = {**groups, **group_list}
return task_list, group_list return ALL_TASKS
def get_task(task_name, config): def get_task(task_name, config):
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment