Commit 3473e196 authored by lintangsutawika's avatar lintangsutawika
Browse files

adjust group to also be a configurable group

parent 96336194
......@@ -5,7 +5,9 @@ from functools import partial
from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.api.task import ConfigurableTask, ConfigurableGroup, GroupConfig, Task
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
class TaskManager:
......@@ -132,13 +134,11 @@ class TaskManager:
update_config: Optional[dict] = None,
yaml_path: Optional[str] = None,
) -> Mapping:
def load_task(config, task, group=None, yaml_path=None):
def load_task(config, task):
if "include" in config:
if yaml_path is None:
raise ValueError
config = {
**utils.load_yaml_config(
yaml_path,
yaml_path=None,
yaml_config={"include": config.pop("include")},
mode="full",
),
......@@ -147,10 +147,7 @@ class TaskManager:
if self._config_is_python_task(config):
task_object = config["class"]()
else:
config = self._process_alias(config, group=group)
task_object = ConfigurableTask(config=config)
if group is not None:
task_object = (group, task_object)
return {task: task_object}
if isinstance(name_or_config, str):
......@@ -159,7 +156,7 @@ class TaskManager:
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config)
return load_task(task_config, task=name_or_config, group=parent_name)
return load_task(task_config, task=name_or_config)
else:
group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config)
......@@ -167,19 +164,7 @@ class TaskManager:
group_config = self._get_config(name_or_config)
subtask_list = group_config["task"]
# This checks if we're at the root.
if parent_name is None:
group_config = self._get_config(name_or_config)
if set(group_config.keys()) > {"task", "group"}:
update_config = {
k: v
for k, v in group_config.items()
if k not in ["task", "group"]
}
yaml_path = self._get_yaml_path(group_name)
if (update_config is not None) and ("group_alias" in update_config):
update_config.pop("group_alias")
group_name = ConfigurableGroup(config=group_config)
if isinstance(name_or_config, dict):
if update_config is not None:
......@@ -225,9 +210,7 @@ class TaskManager:
}
else:
task_config = name_or_config
return load_task(
task_config, task=name, group=parent_name, yaml_path=yaml_path
)
return load_task(task_config, task=name)
else:
group_name = name_or_config["group"]
subtask_list = name_or_config["task"]
......@@ -235,15 +218,9 @@ class TaskManager:
update_config = {
k: v
for k, v in name_or_config.items()
if k not in ["task", "group"]
if k not in GROUP_ONLY_KEYS
}
all_subtasks = {}
if parent_name is not None:
parent_group_config = self._get_config(parent_name)
if "group_alias" in parent_group_config:
parent_name = parent_group_config["group_alias"]
all_subtasks = {group_name: (parent_name, parent_group_config)}
group_name = ConfigurableGroup(config=name_or_config)
fn = partial(
self._load_individual_task_or_group,
......@@ -251,11 +228,7 @@ class TaskManager:
update_config=update_config,
yaml_path=yaml_path,
)
all_subtasks = {
**all_subtasks,
**dict(collections.ChainMap(*map(fn, subtask_list))),
}
return all_subtasks
return {group_name: dict(collections.ChainMap(*map(fn, subtask_list)))}
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
"""Loads a dictionary of task objects from a list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment