Commit 0aca6958 authored by Baber
Browse files

refactor: replace ConfigurableGroup with GroupConfig

parent 7fcfb4ac
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from inspect import getsource from inspect import getsource
from typing import Any, Callable, List, Optional, Union from typing import Callable, List, Optional, Union
@dataclass @dataclass
...@@ -22,7 +22,7 @@ class AggMetricConfig(dict): ...@@ -22,7 +22,7 @@ class AggMetricConfig(dict):
@dataclass @dataclass
class GroupConfig(dict): class GroupConfig:
group: Optional[str] = None group: Optional[str] = None
group_alias: Optional[str] = None group_alias: Optional[str] = None
task: Optional[Union[str, list]] = None task: Optional[Union[str, list]] = None
...@@ -39,6 +39,24 @@ class GroupConfig(dict): ...@@ -39,6 +39,24 @@ class GroupConfig(dict):
def __setitem__(self, item, value): def __setitem__(self, item, value):
return setattr(self, item, value) return setattr(self, item, value)
def __contains__(self, item):
    """Support the ``in`` operator for dict-like behavior.

    Args:
        item: Candidate attribute name.

    Returns:
        bool: True when ``item`` is a string naming an attribute of this
        object (anything ``hasattr`` finds), False otherwise.
    """
    # hasattr() raises TypeError for non-string names; a dict-style
    # membership test should simply answer False for such keys.
    return isinstance(item, str) and hasattr(self, item)
def get(self, key, default=None):
    """Dict-like ``get``: return attribute ``key``, or ``default`` if absent.

    Args:
        key: Attribute name to look up.
        default: Value returned when no such attribute exists.

    Returns:
        The attribute value, or ``default`` when ``key`` is not a string or
        does not name an attribute of this object. An attribute that exists
        with value ``None`` is returned as ``None``, not as ``default``.
    """
    # getattr() raises TypeError for non-string names; mirror dict.get(),
    # which never raises for a key that is simply not present.
    if not isinstance(key, str):
        return default
    return getattr(self, key, default)
def __hash__(self):
    """Hash this object by its ``group`` attribute alone."""
    group_name = self.group
    return hash(group_name)
def __eq__(self, other):
    """Compare two group configs by group name only.

    Args:
        other: Object to compare against.

    Returns:
        ``True``/``False`` according to whether the two ``group`` values
        are equal when ``other`` is a ``GroupConfig``; ``NotImplemented``
        for any other type.
    """
    if not isinstance(other, GroupConfig):
        # Returning NotImplemented (rather than False) lets Python try the
        # reflected comparison on ``other``; when that also declines, ``==``
        # falls back to identity and still evaluates to False.
        return NotImplemented
    return self.group == other.group
def __post_init__(self): def __post_init__(self):
if self.aggregate_metric_list is not None: if self.aggregate_metric_list is not None:
if isinstance(self.aggregate_metric_list, dict): if isinstance(self.aggregate_metric_list, dict):
...@@ -87,34 +105,5 @@ class GroupConfig(dict): ...@@ -87,34 +105,5 @@ class GroupConfig(dict):
"""Returns the version of the group configuration.""" """Returns the version of the group configuration."""
return self.metadata.get("version", "1.0") return self.metadata.get("version", "1.0")
@dataclass
class ConfigurableGroup:
def __init__(
self,
config: Optional[dict] = None,
) -> None:
self._config = GroupConfig(**config)
@property
def group(self):
return self._config.group
@property
def group_alias(self):
return self._config.group_alias
@property
def version(self):
return self._config.version
@property
def config(self):
return self._config.to_dict()
@property
def group_name(self) -> Any:
return self._config.group
def __repr__(self): def __repr__(self):
return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})" return f"GroupConfig(group={self.group},group_alias={self.group_alias})"
...@@ -151,14 +151,14 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]: ...@@ -151,14 +151,14 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
def get_subtask_list(task_dict, task_root=None, depth=0): def get_subtask_list(task_dict, task_root=None, depth=0):
from lm_eval.api.group import ConfigurableGroup from lm_eval.api.group import GroupConfig
from lm_eval.api.task import Task from lm_eval.api.task import Task
subtask_list = {} subtask_list = {}
for group_obj, task_obj in task_dict.items(): for group_obj, task_obj in task_dict.items():
if isinstance(group_obj, ConfigurableGroup): if isinstance(group_obj, GroupConfig):
# group_name = group_obj.group_name # group_name = group_obj.group
group_name = group_obj.group_name group_name = group_obj.group
else: else:
group_name = group_obj group_name = group_obj
if isinstance(task_obj, dict): if isinstance(task_obj, dict):
...@@ -176,9 +176,9 @@ def get_subtask_list(task_dict, task_root=None, depth=0): ...@@ -176,9 +176,9 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
subtask_list = {**subtask_list, **_subtask_list} subtask_list = {**subtask_list, **_subtask_list}
else: else:
if isinstance(task_obj, ConfigurableGroup): if isinstance(task_obj, GroupConfig):
# group_or_task_name = task_obj.group_name # group_or_task_name = task_obj.group
group_or_task_name = task_obj.group_name group_or_task_name = task_obj.group
elif isinstance(task_obj, Task): elif isinstance(task_obj, Task):
# group_or_task_name = task_obj.task_name # group_or_task_name = task_obj.task_name
group_or_task_name = task_obj.task_name group_or_task_name = task_obj.task_name
...@@ -241,7 +241,7 @@ def prepare_print_tasks( ...@@ -241,7 +241,7 @@ def prepare_print_tasks(
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
""" """
from lm_eval.api.group import ConfigurableGroup from lm_eval.api.group import GroupConfig
def _sort_task_dict(task_dict): def _sort_task_dict(task_dict):
""" """
...@@ -252,8 +252,8 @@ def prepare_print_tasks( ...@@ -252,8 +252,8 @@ def prepare_print_tasks(
return dict( return dict(
sorted( sorted(
task_dict.items(), task_dict.items(),
key=lambda item: item[0].group_name key=lambda item: item[0].group
if isinstance(item[0], ConfigurableGroup) if isinstance(item[0], GroupConfig)
else item[0], else item[0],
) )
) )
...@@ -263,9 +263,9 @@ def prepare_print_tasks( ...@@ -263,9 +263,9 @@ def prepare_print_tasks(
task_dict = _sort_task_dict(task_dict) task_dict = _sort_task_dict(task_dict)
for task_or_group_name, task_or_group_obj in task_dict.items(): for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else "" tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup): if isinstance(task_or_group_name, GroupConfig):
# string_name = task_or_group_name.group_name # string_name = task_or_group_name.group
name = task_or_group_name.group_name name = task_or_group_name.group
from_configurable_group = True from_configurable_group = True
task_or_group_obj = _sort_task_dict(task_or_group_obj) task_or_group_obj = _sort_task_dict(task_or_group_obj)
elif isinstance(task_or_group_name, str): elif isinstance(task_or_group_name, str):
...@@ -399,7 +399,7 @@ def consolidate_group_results( ...@@ -399,7 +399,7 @@ def consolidate_group_results(
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple. The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
In the top-level invocation of this function, task_aggregation_list is ignored. In the top-level invocation of this function, task_aggregation_list is ignored.
""" """
from lm_eval.api.group import ConfigurableGroup from lm_eval.api.group import GroupConfig
from lm_eval.api.task import Task from lm_eval.api.task import Task
if task_root is None: if task_root is None:
...@@ -410,9 +410,9 @@ def consolidate_group_results( ...@@ -410,9 +410,9 @@ def consolidate_group_results(
for group_or_task, group_or_task_info in task_dict.items(): for group_or_task, group_or_task_info in task_dict.items():
# Convert to string # Convert to string
if isinstance(group_or_task, ConfigurableGroup): if isinstance(group_or_task, GroupConfig):
group_config = group_or_task.config group_config = group_or_task.to_dict()
group_or_task = group_or_task.group_name group_or_task = group_or_task.group
else: else:
group_config = None group_config = None
...@@ -441,7 +441,7 @@ def consolidate_group_results( ...@@ -441,7 +441,7 @@ def consolidate_group_results(
) )
if (group_config is None) or ( if (group_config is None) or (
group_config["aggregate_metric_list"] is None group_config.get("aggregate_metric_list") is None
): ):
results[group_or_task][" "] = " " results[group_or_task][" "] = " "
continue continue
...@@ -450,7 +450,7 @@ def consolidate_group_results( ...@@ -450,7 +450,7 @@ def consolidate_group_results(
agg_metric_list = group_config["aggregate_metric_list"] agg_metric_list = group_config["aggregate_metric_list"]
show_group_table = show_group_table | bool( show_group_table = show_group_table | bool(
group_config["aggregate_metric_list"] group_config.get("aggregate_metric_list")
) )
task_list = _task_aggregation_list[group_or_task] task_list = _task_aggregation_list[group_or_task]
......
...@@ -49,7 +49,7 @@ from typing import ( ...@@ -49,7 +49,7 @@ from typing import (
import yaml import yaml
from yaml import YAMLError from yaml import YAMLError
from lm_eval.api.group import ConfigurableGroup, GroupConfig from lm_eval.api.group import GroupConfig
from lm_eval.evaluator_utils import get_subtask_list from lm_eval.evaluator_utils import get_subtask_list
from lm_eval.utils import pattern_match, setup_logging from lm_eval.utils import pattern_match, setup_logging
...@@ -767,17 +767,17 @@ class TaskManager: ...@@ -767,17 +767,17 @@ class TaskManager:
self, self,
cfg: dict, cfg: dict,
parent_name: str | None = None, parent_name: str | None = None,
) -> tuple[ConfigurableGroup, list[Union[str, dict]]]: ) -> tuple[GroupConfig, list[Union[str, dict]]]:
""" """
Build ConfigurableGroup and return (group_obj, subtask_names). Build GroupConfig and return (group_obj, subtask_names).
Resolves tag expansion. Resolves tag expansion.
""" """
if self.metadata is not None: if self.metadata is not None:
cfg["metadata"] = cfg.get("metadata", {}) | self.metadata cfg["metadata"] = cfg.get("metadata", {}) | self.metadata
grp = ConfigurableGroup(config=cfg) grp = GroupConfig(**cfg)
subtasks: list[Union[str, dict]] = [] subtasks: list[Union[str, dict]] = []
for t in grp.config["task"]: for t in grp.task:
if isinstance(t, str) and self._name_is_tag(t): if isinstance(t, str) and self._name_is_tag(t):
subtasks.extend(self._get_tasklist(t)) subtasks.extend(self._get_tasklist(t))
else: else:
...@@ -787,7 +787,7 @@ class TaskManager: ...@@ -787,7 +787,7 @@ class TaskManager:
def _load_subtasks( def _load_subtasks(
self, self,
subtasks: list[Union[str, dict]], subtasks: list[Union[str, dict]],
parent_name: Union[str, ConfigurableGroup, None], parent_name: Union[str, GroupConfig, None],
update_config: dict | None, update_config: dict | None,
) -> Mapping: ) -> Mapping:
"""Return merged mapping of all subtasks, handling duplicates.""" """Return merged mapping of all subtasks, handling duplicates."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment