"src/lib/apis/models/index.ts" did not exist on "21c7f507908fe0bd49baf9575f35f16a95aa3f2f"
Commit 73c8c87c authored by lintangsutawika's avatar lintangsutawika
Browse files

further tidy up

parent 52e0148a
...@@ -3,7 +3,7 @@ import abc ...@@ -3,7 +3,7 @@ import abc
import yaml import yaml
import collections import collections
from functools import partial from functools import partial, lru_cache
from typing import List, Union, Dict from typing import List, Union, Dict
from lm_eval import utils from lm_eval import utils
...@@ -58,7 +58,7 @@ class TaskManager(abc.ABC): ...@@ -58,7 +58,7 @@ class TaskManager(abc.ABC):
ALL_TASKS = {} ALL_TASKS = {}
for task_dir in all_paths: for task_dir in all_paths:
tasks = get_task_and_group(task_dir) tasks = self._get_task_and_group(task_dir)
ALL_TASKS = {**tasks, **ALL_TASKS} ALL_TASKS = {**tasks, **ALL_TASKS}
return ALL_TASKS return ALL_TASKS
...@@ -81,25 +81,33 @@ class TaskManager(abc.ABC): ...@@ -81,25 +81,33 @@ class TaskManager(abc.ABC):
return False return False
return True return True
def _get_yaml_path(self, name):
assert name in self.ALL_TASKS
return self.ALL_TASKS[name]["yaml_path"]
def _get_config(self, name): def _get_config(self, name):
assert name in self.ALL_TASKS assert name in self.ALL_TASKS
yaml_path = self.ALL_TASKS[name]["yaml_path"] yaml_path = self._get_yaml_path(name)
return utils.load_yaml_config(yaml_path) return utils.load_yaml_config(yaml_path)
def _get_tasklist(self, name): def _get_tasklist(self, name):
assert self._name_is_task(name) == False assert self._name_is_task(name) == False
return self.ALL_TASKS[name]["task"] return self.ALL_TASKS[name]["task"]
@lru_cache(None)
def _load_individual_task_or_group(self, name_or_config: Union[str, dict] = None, parent_name: str = None) -> ConfigurableTask: def _load_individual_task_or_group(self, name_or_config: Union[str, dict] = None, parent_name: str = None) -> ConfigurableTask:
def load_task(config, task, group=None):
task_object = ConfigurableTask(config=config)
if group is not None:
task_object = (group, task_object)
return {task: task_object}
print("Loading", name_or_config) print("Loading", name_or_config)
if isinstance(name_or_config, str): if isinstance(name_or_config, str):
if self._name_is_task(name_or_config): if self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config) task_config = self._get_config(name_or_config)
task_object = ConfigurableTask(config=task_config) return load_task(task_config, task=name_or_config, group=parent_name)
if parent_name is not None:
task_object = (parent_name, task_object)
return {name_or_config: task_object}
else: else:
group_name = name_or_config group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config) subtask_list = self._get_tasklist(name_or_config)
...@@ -117,16 +125,18 @@ class TaskManager(abc.ABC): ...@@ -117,16 +125,18 @@ class TaskManager(abc.ABC):
} }
else: else:
task_config = name_or_config task_config = name_or_config
task_object = ConfigurableTask(config=task_config) return load_task(task_config, task=name_or_config, group=parent_name)
if parent_name is not None:
task_object = (parent_name, task_object)
return {task_name: task_object}
else: else:
group_name = name_or_config["group"] group_name = name_or_config["group"]
subtask_list = name_or_config["task"] subtask_list = name_or_config["task"]
if self._get_yaml_path(group_name) == -1:
all_subtasks = {group_name: (parent_name, None)}
else:
all_subtasks = {}
fn = partial(self._load_individual_task_or_group, parent_name=group_name) fn = partial(self._load_individual_task_or_group, parent_name=group_name)
all_subtasks = dict(collections.ChainMap(*map(fn, subtask_list))) all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))}
return all_subtasks return all_subtasks
...@@ -135,23 +145,58 @@ class TaskManager(abc.ABC): ...@@ -135,23 +145,58 @@ class TaskManager(abc.ABC):
if isinstance(task_list, str): if isinstance(task_list, str):
task_list = [task_list] task_list = [task_list]
all_loaded_tasks = {} all_loaded_tasks = dict(
for task in task_list: collections.ChainMap(
task_object = self._load_individual_task_or_group( *map(
name_or_config=task, self._load_individual_task_or_group,
task_list
)
)
) )
if isinstance(task, str): return all_loaded_tasks
task_name = task
elif isinstance(task, dict):
task_name = task["task"]
if isinstance(task_object, dict): def _get_task_and_group(self, task_dir: str):
all_loaded_tasks = {**task_object, **all_loaded_tasks} tasks_and_groups = collections.defaultdict()
for root, _, file_list in os.walk(task_dir):
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
config = utils.simple_load_yaml_config(yaml_path)
if list(config.keys()) == ["group", "task"]:
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": yaml_path,
}
else: else:
all_loaded_tasks[task_name] = task_object # This is a task config
task = config["task"]
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
return all_loaded_tasks if "group" in config:
groups = config["group"]
if isinstance(config["group"], str):
groups = [groups]
for group in groups:
if group not in tasks_and_groups:
tasks_and_groups[group] = {
"type": "group",
"task": [task],
"yaml_path": -1,
}
else:
tasks_and_groups[group]["task"].append(task)
return tasks_and_groups
def register_configurable_task(config: Dict[str, str]) -> int: def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type( SubClass = type(
...@@ -178,68 +223,6 @@ def register_configurable_task(config: Dict[str, str]) -> int: ...@@ -178,68 +223,6 @@ def register_configurable_task(config: Dict[str, str]) -> int:
return 0 return 0
def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int:
group = config["group"]
task_config_list = []
group_config_list = []
registered_task_or_group_list = []
for task in config["task"]:
if isinstance(task, str):
registered_task_or_group_list.append(task)
elif is_group(task):
group_config_list.append(task)
else:
task_config_list.append(task)
for task_config in task_config_list:
base_config = {}
task_name_config = {}
if "task" in task_config:
task_name = task_config["task"]
if task_name in TASK_REGISTRY:
task_obj = get_task_dict(task_name)[task_name]
if type(task_obj) == tuple:
_, task_obj = task_obj
if task_obj is not None:
base_config = task_obj._config.to_dict(keep_callable=True)
task_name_config["task"] = f"{group}_{task_name}"
task_config = utils.load_yaml_config(yaml_path, task_config)
var_configs = check_prompt_config(
{
**base_config,
**task_config,
**{"group": group},
**task_name_config,
},
yaml_path=os.path.dirname(yaml_path),
)
for config in var_configs:
register_configurable_task(config)
for group_config in group_config_list:
sub_group = group_config["group"]
register_configurable_group(group_config, yaml_path)
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(sub_group)
else:
GROUP_REGISTRY[group] = [sub_group]
self.ALL_TASKS.add(group)
task_names = utils.pattern_match(registered_task_or_group_list, self.ALL_TASKS)
for task in task_names:
if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(task)
else:
GROUP_REGISTRY[group] = [task]
self.ALL_TASKS.add(group)
return 0
def check_prompt_config( def check_prompt_config(
config: Dict[str, str], yaml_path: str = None config: Dict[str, str], yaml_path: str = None
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
...@@ -338,48 +321,7 @@ def include_task_folder(task_dir: str, register_task: bool = True, task_name: st ...@@ -338,48 +321,7 @@ def include_task_folder(task_dir: str, register_task: bool = True, task_name: st
return 0 return 0
def get_task_and_group(task_dir: str):
tasks_and_groups = collections.defaultdict()
for root, _, file_list in os.walk(task_dir):
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
config = utils.simple_load_yaml_config(yaml_path)
if list(config.keys()) == ["group", "task"]:
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": yaml_path,
}
else:
# This is a task config
task = config["task"]
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
if "group" in config:
groups = config["group"]
if isinstance(config["group"], str):
groups = [groups]
for group in groups:
if group not in tasks_and_groups:
tasks_and_groups[group] = {
"type": "group",
"task": [task],
"yaml_path": -1,
}
else:
tasks_and_groups[group]["task"].append(task)
return tasks_and_groups
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment