Commit 5dfc887d authored by lintangsutawika's avatar lintangsutawika
Browse files

simplified code

parent 9c865819
...@@ -2,6 +2,8 @@ import os ...@@ -2,6 +2,8 @@ import os
import abc import abc
import yaml import yaml
import collections import collections
from functools import partial
from typing import List, Union, Dict from typing import List, Union, Dict
from lm_eval import utils from lm_eval import utils
...@@ -30,15 +32,10 @@ from .scrolls.task import ( ...@@ -30,15 +32,10 @@ from .scrolls.task import (
eval_logger = utils.eval_logger eval_logger = utils.eval_logger
def is_group(task):
if list(task.keys()) == ["group", "task"]:
return True
return False
class TaskManager(abc.ABC): class TaskManager(abc.ABC):
def __init__( def __init__(
self, self,
verbosity="INFO", verbosity="INFO",
include_path=None include_path=None
) -> None: ) -> None:
...@@ -52,7 +49,7 @@ class TaskManager(abc.ABC): ...@@ -52,7 +49,7 @@ class TaskManager(abc.ABC):
) )
def initialize_tasks(self, include_path=None): def initialize_tasks(self, include_path=None):
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"] all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
if include_path is not None: if include_path is not None:
if isinstance(include_path, str): if isinstance(include_path, str):
...@@ -69,78 +66,69 @@ class TaskManager(abc.ABC): ...@@ -69,78 +66,69 @@ class TaskManager(abc.ABC):
def all_tasks(self): def all_tasks(self):
return sorted(list(self.ALL_TASKS.keys())) return sorted(list(self.ALL_TASKS.keys()))
def _load_individual_task_or_group(self, task_name_or_config: Union[str, dict] = None) -> ConfigurableTask: def _name_is_registered(self, name):
if name in self.ALL_TASKS:
print("Loading", task_name_or_config) return True
if isinstance(task_name_or_config, str): return False
task_info = self.ALL_TASKS[task_name_or_config]
yaml_path = task_info["yaml_path"] def _name_is_task(self, name):
task_type = task_info["type"] if self.ALL_TASKS[name]["type"] == "task":
subtask_list = task_info["task"] if "task" in task_info else -1 return True
if task_type == "task": return False
task_config = utils.load_yaml_config(yaml_path)
return ConfigurableTask(config=task_config) def _config_is_task(self, config):
if list(config.keys()) == ["group", "task"]:
return False
return True
def _get_config(self, name):
assert name in self.ALL_TASKS
yaml_path = self.ALL_TASKS[name]["yaml_path"]
return utils.load_yaml_config(yaml_path)
def _get_tasklist(self, name):
assert self._name_is_task(name) == False
return self.ALL_TASKS[name]["task"]
def _load_individual_task_or_group(self, name_or_config: Union[str, dict] = None, parent_name: str = None) -> ConfigurableTask:
print("Loading", name_or_config)
if isinstance(name_or_config, str):
if self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config)
task_object = ConfigurableTask(config=task_config)
if parent_name is not None:
task_object = (parent_name, task_object)
return {name_or_config: task_object}
else: else:
group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config)
if subtask_list == -1: if subtask_list == -1:
task_config = utils.load_yaml_config(yaml_path) subtask_list = self._get_config(name_or_config)["task"]
group_name = task_config["group"]
subtask_list = task_config["task"] elif isinstance(name_or_config, dict):
if self._config_is_task(name_or_config):
task_name = name_or_config["task"]
if self._name_is_registered(task_name):
base_task_config = self._get_config(task_name)
task_config={
**base_task_config,
**name_or_config,
}
else: else:
group_name = task_name_or_config task_config = name_or_config
task_object = ConfigurableTask(config=task_config)
all_subtasks = {} if parent_name is not None:
for task_or_config in subtask_list: task_object = (parent_name, task_object)
if isinstance(task_or_config, str): return {task_name: task_object}
all_subtasks[task_or_config] = (group_name, None)
task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
elif isinstance(task_or_config, dict):
if "group" in task_or_config:
all_subtasks[task_or_config["group"]] = (group_name, None)
elif "task" in task_or_config:
all_subtasks[task_or_config["task"]] = (group_name, None)
task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
if isinstance(task_object, dict):
all_subtasks = {**task_object, **all_subtasks}
else:
task_name = task_object._config["task"]
all_subtasks[task_name] = (group_name, task_object)
# if group_name is not None:
# all_subtasks[task_name] = (group_name, task_object)
# else:
# all_subtasks[task_name] = task_object
return all_subtasks
elif isinstance(task_name_or_config, dict):
if is_group(task_name_or_config):
group_name = task_name_or_config["group"]
subtask_list = task_name_or_config["task"]
all_subtasks = {}
for task_or_config in subtask_list:
if isinstance(task_or_config, str):
task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
task_name = task_or_config
elif isinstance(task_or_config, dict):
task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
if isinstance(task_object, dict):
all_subtasks = {**task_object, **all_subtasks}
else:
task_name = task_object._config["task"]
all_subtasks[task_name] = (group_name, task_object)
return all_subtasks
else: else:
task_type = "task" group_name = name_or_config["group"]
task_name = task_name_or_config["task"] subtask_list = name_or_config["task"]
base_task_info = self.ALL_TASKS[task_name]
base_yaml_path = base_task_info["yaml_path"] fn = partial(self._load_individual_task_or_group, parent_name=group_name)
base_task_config = utils.load_yaml_config(base_yaml_path) all_subtasks = dict(collections.ChainMap(*map(fn, subtask_list)))
return all_subtasks
return ConfigurableTask(
config={
**base_task_config,
**task_name_or_config,
}
)
def load_task_or_group(self, task_list: Union[str, list] = None) -> dict: def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
...@@ -150,7 +138,7 @@ class TaskManager(abc.ABC): ...@@ -150,7 +138,7 @@ class TaskManager(abc.ABC):
all_loaded_tasks = {} all_loaded_tasks = {}
for task in task_list: for task in task_list:
task_object = self._load_individual_task_or_group( task_object = self._load_individual_task_or_group(
task_name_or_config=task, name_or_config=task,
) )
if isinstance(task, str): if isinstance(task, str):
task_name = task task_name = task
...@@ -161,7 +149,7 @@ class TaskManager(abc.ABC): ...@@ -161,7 +149,7 @@ class TaskManager(abc.ABC):
all_loaded_tasks = {**task_object, **all_loaded_tasks} all_loaded_tasks = {**task_object, **all_loaded_tasks}
else: else:
all_loaded_tasks[task_name] = task_object all_loaded_tasks[task_name] = task_object
return all_loaded_tasks return all_loaded_tasks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment