Commit 3473e196 authored by lintangsutawika's avatar lintangsutawika
Browse files

adjust group to also be a configurable group

parent 96336194
...@@ -5,7 +5,9 @@ from functools import partial ...@@ -5,7 +5,9 @@ from functools import partial
from typing import Dict, List, Mapping, Optional, Union from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils from lm_eval import utils
from lm_eval.api.task import ConfigurableTask, Task from lm_eval.api.task import ConfigurableTask, ConfigurableGroup, GroupConfig, Task
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
class TaskManager: class TaskManager:
...@@ -132,13 +134,11 @@ class TaskManager: ...@@ -132,13 +134,11 @@ class TaskManager:
update_config: Optional[dict] = None, update_config: Optional[dict] = None,
yaml_path: Optional[str] = None, yaml_path: Optional[str] = None,
) -> Mapping: ) -> Mapping:
def load_task(config, task, group=None, yaml_path=None): def load_task(config, task):
if "include" in config: if "include" in config:
if yaml_path is None:
raise ValueError
config = { config = {
**utils.load_yaml_config( **utils.load_yaml_config(
yaml_path, yaml_path=None,
yaml_config={"include": config.pop("include")}, yaml_config={"include": config.pop("include")},
mode="full", mode="full",
), ),
...@@ -147,10 +147,7 @@ class TaskManager: ...@@ -147,10 +147,7 @@ class TaskManager:
if self._config_is_python_task(config): if self._config_is_python_task(config):
task_object = config["class"]() task_object = config["class"]()
else: else:
config = self._process_alias(config, group=group)
task_object = ConfigurableTask(config=config) task_object = ConfigurableTask(config=config)
if group is not None:
task_object = (group, task_object)
return {task: task_object} return {task: task_object}
if isinstance(name_or_config, str): if isinstance(name_or_config, str):
...@@ -159,7 +156,7 @@ class TaskManager: ...@@ -159,7 +156,7 @@ class TaskManager:
name_or_config = {"task": name_or_config, **update_config} name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config): elif self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config) task_config = self._get_config(name_or_config)
return load_task(task_config, task=name_or_config, group=parent_name) return load_task(task_config, task=name_or_config)
else: else:
group_name = name_or_config group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config) subtask_list = self._get_tasklist(name_or_config)
...@@ -167,19 +164,7 @@ class TaskManager: ...@@ -167,19 +164,7 @@ class TaskManager:
group_config = self._get_config(name_or_config) group_config = self._get_config(name_or_config)
subtask_list = group_config["task"] subtask_list = group_config["task"]
# This checks if we're at the root. group_name = ConfigurableGroup(config=group_config)
if parent_name is None:
group_config = self._get_config(name_or_config)
if set(group_config.keys()) > {"task", "group"}:
update_config = {
k: v
for k, v in group_config.items()
if k not in ["task", "group"]
}
yaml_path = self._get_yaml_path(group_name)
if (update_config is not None) and ("group_alias" in update_config):
update_config.pop("group_alias")
if isinstance(name_or_config, dict): if isinstance(name_or_config, dict):
if update_config is not None: if update_config is not None:
...@@ -225,9 +210,7 @@ class TaskManager: ...@@ -225,9 +210,7 @@ class TaskManager:
} }
else: else:
task_config = name_or_config task_config = name_or_config
return load_task( return load_task(task_config, task=name)
task_config, task=name, group=parent_name, yaml_path=yaml_path
)
else: else:
group_name = name_or_config["group"] group_name = name_or_config["group"]
subtask_list = name_or_config["task"] subtask_list = name_or_config["task"]
...@@ -235,15 +218,9 @@ class TaskManager: ...@@ -235,15 +218,9 @@ class TaskManager:
update_config = { update_config = {
k: v k: v
for k, v in name_or_config.items() for k, v in name_or_config.items()
if k not in ["task", "group"] if k not in GROUP_ONLY_KEYS
} }
group_name = ConfigurableGroup(config=name_or_config)
all_subtasks = {}
if parent_name is not None:
parent_group_config = self._get_config(parent_name)
if "group_alias" in parent_group_config:
parent_name = parent_group_config["group_alias"]
all_subtasks = {group_name: (parent_name, parent_group_config)}
fn = partial( fn = partial(
self._load_individual_task_or_group, self._load_individual_task_or_group,
...@@ -251,11 +228,7 @@ class TaskManager: ...@@ -251,11 +228,7 @@ class TaskManager:
update_config=update_config, update_config=update_config,
yaml_path=yaml_path, yaml_path=yaml_path,
) )
all_subtasks = { return {group_name: dict(collections.ChainMap(*map(fn, subtask_list)))}
**all_subtasks,
**dict(collections.ChainMap(*map(fn, subtask_list))),
}
return all_subtasks
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict: def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
"""Loads a dictionary of task objects from a list """Loads a dictionary of task objects from a list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment