Commit 0aca6958 authored by Baber's avatar Baber
Browse files

refactor: replace ConfigurableGroup with GroupConfig

parent 7fcfb4ac
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, Callable, List, Optional, Union
from typing import Callable, List, Optional, Union
@dataclass
......@@ -22,7 +22,7 @@ class AggMetricConfig(dict):
@dataclass
class GroupConfig(dict):
class GroupConfig:
group: Optional[str] = None
group_alias: Optional[str] = None
task: Optional[Union[str, list]] = None
......@@ -39,6 +39,24 @@ class GroupConfig(dict):
def __setitem__(self, item, value):
return setattr(self, item, value)
def __contains__(self, item):
"""Support 'in' operator for dict-like behavior."""
return hasattr(self, item)
def get(self, key, default=None):
"""Dict-like get method."""
return getattr(self, key, default)
def __hash__(self):
"""Make GroupConfig hashable based on group name."""
return hash(self.group)
def __eq__(self, other):
"""Equality comparison based on group name."""
if not isinstance(other, GroupConfig):
return False
return self.group == other.group
def __post_init__(self):
if self.aggregate_metric_list is not None:
if isinstance(self.aggregate_metric_list, dict):
......@@ -87,34 +105,5 @@ class GroupConfig(dict):
"""Returns the version of the group configuration."""
return self.metadata.get("version", "1.0")
@dataclass
class ConfigurableGroup:
def __init__(
self,
config: Optional[dict] = None,
) -> None:
self._config = GroupConfig(**config)
@property
def group(self):
return self._config.group
@property
def group_alias(self):
return self._config.group_alias
@property
def version(self):
return self._config.version
@property
def config(self):
return self._config.to_dict()
@property
def group_name(self) -> Any:
return self._config.group
def __repr__(self):
return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
return f"GroupConfig(group={self.group},group_alias={self.group_alias})"
......@@ -151,14 +151,14 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
def get_subtask_list(task_dict, task_root=None, depth=0):
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.group import GroupConfig
from lm_eval.api.task import Task
subtask_list = {}
for group_obj, task_obj in task_dict.items():
if isinstance(group_obj, ConfigurableGroup):
# group_name = group_obj.group_name
group_name = group_obj.group_name
if isinstance(group_obj, GroupConfig):
# group_name = group_obj.group
group_name = group_obj.group
else:
group_name = group_obj
if isinstance(task_obj, dict):
......@@ -176,9 +176,9 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
subtask_list = {**subtask_list, **_subtask_list}
else:
if isinstance(task_obj, ConfigurableGroup):
# group_or_task_name = task_obj.group_name
group_or_task_name = task_obj.group_name
if isinstance(task_obj, GroupConfig):
# group_or_task_name = task_obj.group
group_or_task_name = task_obj.group
elif isinstance(task_obj, Task):
# group_or_task_name = task_obj.task_name
group_or_task_name = task_obj.task_name
......@@ -241,7 +241,7 @@ def prepare_print_tasks(
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
"""
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.group import GroupConfig
def _sort_task_dict(task_dict):
"""
......@@ -252,8 +252,8 @@ def prepare_print_tasks(
return dict(
sorted(
task_dict.items(),
key=lambda item: item[0].group_name
if isinstance(item[0], ConfigurableGroup)
key=lambda item: item[0].group
if isinstance(item[0], GroupConfig)
else item[0],
)
)
......@@ -263,9 +263,9 @@ def prepare_print_tasks(
task_dict = _sort_task_dict(task_dict)
for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup):
# string_name = task_or_group_name.group_name
name = task_or_group_name.group_name
if isinstance(task_or_group_name, GroupConfig):
# string_name = task_or_group_name.group
name = task_or_group_name.group
from_configurable_group = True
task_or_group_obj = _sort_task_dict(task_or_group_obj)
elif isinstance(task_or_group_name, str):
......@@ -399,7 +399,7 @@ def consolidate_group_results(
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
In the top-level invocation of this function, task_aggregation_list is ignored.
"""
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.group import GroupConfig
from lm_eval.api.task import Task
if task_root is None:
......@@ -410,9 +410,9 @@ def consolidate_group_results(
for group_or_task, group_or_task_info in task_dict.items():
# Convert to string
if isinstance(group_or_task, ConfigurableGroup):
group_config = group_or_task.config
group_or_task = group_or_task.group_name
if isinstance(group_or_task, GroupConfig):
group_config = group_or_task.to_dict()
group_or_task = group_or_task.group
else:
group_config = None
......@@ -441,7 +441,7 @@ def consolidate_group_results(
)
if (group_config is None) or (
group_config["aggregate_metric_list"] is None
group_config.get("aggregate_metric_list") is None
):
results[group_or_task][" "] = " "
continue
......@@ -450,7 +450,7 @@ def consolidate_group_results(
agg_metric_list = group_config["aggregate_metric_list"]
show_group_table = show_group_table | bool(
group_config["aggregate_metric_list"]
group_config.get("aggregate_metric_list")
)
task_list = _task_aggregation_list[group_or_task]
......
......@@ -49,7 +49,7 @@ from typing import (
import yaml
from yaml import YAMLError
from lm_eval.api.group import ConfigurableGroup, GroupConfig
from lm_eval.api.group import GroupConfig
from lm_eval.evaluator_utils import get_subtask_list
from lm_eval.utils import pattern_match, setup_logging
......@@ -767,17 +767,17 @@ class TaskManager:
self,
cfg: dict,
parent_name: str | None = None,
) -> tuple[ConfigurableGroup, list[Union[str, dict]]]:
) -> tuple[GroupConfig, list[Union[str, dict]]]:
"""
Build ConfigurableGroup and return (group_obj, subtask_names).
Build GroupConfig and return (group_obj, subtask_names).
Resolves tag expansion.
"""
if self.metadata is not None:
cfg["metadata"] = cfg.get("metadata", {}) | self.metadata
grp = ConfigurableGroup(config=cfg)
grp = GroupConfig(**cfg)
subtasks: list[Union[str, dict]] = []
for t in grp.config["task"]:
for t in grp.task:
if isinstance(t, str) and self._name_is_tag(t):
subtasks.extend(self._get_tasklist(t))
else:
......@@ -787,7 +787,7 @@ class TaskManager:
def _load_subtasks(
self,
subtasks: list[Union[str, dict]],
parent_name: Union[str, ConfigurableGroup, None],
parent_name: Union[str, GroupConfig, None],
update_config: dict | None,
) -> Mapping:
"""Return merged mapping of all subtasks, handling duplicates."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment