Commit 5dfc887d authored by lintangsutawika's avatar lintangsutawika
Browse files

simplified code

parent 9c865819
......@@ -2,6 +2,8 @@ import os
import abc
import yaml
import collections
from functools import partial
from typing import List, Union, Dict
from lm_eval import utils
......@@ -30,15 +32,10 @@ from .scrolls.task import (
eval_logger = utils.eval_logger
def is_group(task):
if list(task.keys()) == ["group", "task"]:
return True
return False
class TaskManager(abc.ABC):
def __init__(
self,
self,
verbosity="INFO",
include_path=None
) -> None:
......@@ -52,7 +49,7 @@ class TaskManager(abc.ABC):
)
def initialize_tasks(self, include_path=None):
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
if include_path is not None:
if isinstance(include_path, str):
......@@ -69,78 +66,69 @@ class TaskManager(abc.ABC):
def all_tasks(self):
return sorted(list(self.ALL_TASKS.keys()))
def _load_individual_task_or_group(self, task_name_or_config: Union[str, dict] = None) -> ConfigurableTask:
print("Loading", task_name_or_config)
if isinstance(task_name_or_config, str):
task_info = self.ALL_TASKS[task_name_or_config]
yaml_path = task_info["yaml_path"]
task_type = task_info["type"]
subtask_list = task_info["task"] if "task" in task_info else -1
if task_type == "task":
task_config = utils.load_yaml_config(yaml_path)
return ConfigurableTask(config=task_config)
def _name_is_registered(self, name):
if name in self.ALL_TASKS:
return True
return False
def _name_is_task(self, name):
if self.ALL_TASKS[name]["type"] == "task":
return True
return False
def _config_is_task(self, config):
if list(config.keys()) == ["group", "task"]:
return False
return True
def _get_config(self, name):
assert name in self.ALL_TASKS
yaml_path = self.ALL_TASKS[name]["yaml_path"]
return utils.load_yaml_config(yaml_path)
def _get_tasklist(self, name):
assert self._name_is_task(name) == False
return self.ALL_TASKS[name]["task"]
def _load_individual_task_or_group(self, name_or_config: Union[str, dict] = None, parent_name: str = None) -> ConfigurableTask:
print("Loading", name_or_config)
if isinstance(name_or_config, str):
if self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config)
task_object = ConfigurableTask(config=task_config)
if parent_name is not None:
task_object = (parent_name, task_object)
return {name_or_config: task_object}
else:
group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config)
if subtask_list == -1:
task_config = utils.load_yaml_config(yaml_path)
group_name = task_config["group"]
subtask_list = task_config["task"]
subtask_list = self._get_config(name_or_config)["task"]
elif isinstance(name_or_config, dict):
if self._config_is_task(name_or_config):
task_name = name_or_config["task"]
if self._name_is_registered(task_name):
base_task_config = self._get_config(task_name)
task_config={
**base_task_config,
**name_or_config,
}
else:
group_name = task_name_or_config
all_subtasks = {}
for task_or_config in subtask_list:
if isinstance(task_or_config, str):
all_subtasks[task_or_config] = (group_name, None)
task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
elif isinstance(task_or_config, dict):
if "group" in task_or_config:
all_subtasks[task_or_config["group"]] = (group_name, None)
elif "task" in task_or_config:
all_subtasks[task_or_config["task"]] = (group_name, None)
task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
if isinstance(task_object, dict):
all_subtasks = {**task_object, **all_subtasks}
else:
task_name = task_object._config["task"]
all_subtasks[task_name] = (group_name, task_object)
# if group_name is not None:
# all_subtasks[task_name] = (group_name, task_object)
# else:
# all_subtasks[task_name] = task_object
return all_subtasks
elif isinstance(task_name_or_config, dict):
if is_group(task_name_or_config):
group_name = task_name_or_config["group"]
subtask_list = task_name_or_config["task"]
all_subtasks = {}
for task_or_config in subtask_list:
if isinstance(task_or_config, str):
task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
task_name = task_or_config
elif isinstance(task_or_config, dict):
task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
if isinstance(task_object, dict):
all_subtasks = {**task_object, **all_subtasks}
else:
task_name = task_object._config["task"]
all_subtasks[task_name] = (group_name, task_object)
return all_subtasks
task_config = name_or_config
task_object = ConfigurableTask(config=task_config)
if parent_name is not None:
task_object = (parent_name, task_object)
return {task_name: task_object}
else:
task_type = "task"
task_name = task_name_or_config["task"]
base_task_info = self.ALL_TASKS[task_name]
base_yaml_path = base_task_info["yaml_path"]
base_task_config = utils.load_yaml_config(base_yaml_path)
return ConfigurableTask(
config={
**base_task_config,
**task_name_or_config,
}
)
group_name = name_or_config["group"]
subtask_list = name_or_config["task"]
fn = partial(self._load_individual_task_or_group, parent_name=group_name)
all_subtasks = dict(collections.ChainMap(*map(fn, subtask_list)))
return all_subtasks
def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
......@@ -150,7 +138,7 @@ class TaskManager(abc.ABC):
all_loaded_tasks = {}
for task in task_list:
task_object = self._load_individual_task_or_group(
task_name_or_config=task,
name_or_config=task,
)
if isinstance(task, str):
task_name = task
......@@ -161,7 +149,7 @@ class TaskManager(abc.ABC):
all_loaded_tasks = {**task_object, **all_loaded_tasks}
else:
all_loaded_tasks[task_name] = task_object
return all_loaded_tasks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment