from dataclasses import asdict, dataclass, field
from inspect import getsource
from typing import Callable, Optional, Union


@dataclass
class AggMetricConfig(dict):
    """Configuration of one metric aggregated across a group's subtasks.

    NOTE: the class inherits from ``dict`` but the dataclass-generated
    ``__init__`` only sets attributes; the underlying mapping stays empty.
    Consequently ``bool(instance)`` is ``False`` and ``**instance`` unpacks
    nothing -- always read the fields as attributes.

    Raises:
        ValueError: if ``aggregation`` is neither ``"mean"`` nor a callable.
    """

    metric: Optional[str] = None
    # Either the string "mean" or a user-supplied callable.
    aggregation: Optional[Union[str, Callable]] = "mean"
    weight_by_size: bool = False
    # list of filter names which should be incorporated into the aggregated metric.
    filter_list: Optional[Union[str, list]] = "none"

    def __post_init__(self):
        """Validate ``aggregation`` and normalise ``filter_list`` to a list."""
        if self.aggregation != "mean" and not callable(self.aggregation):
            raise ValueError(
                f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
            )

        # A single filter name is accepted as shorthand for a one-element list.
        if isinstance(self.filter_list, str):
            self.filter_list = [self.filter_list]


@dataclass
class GroupConfig:
    """Configuration of a task group, with a small dict-like access layer.

    Fields can be read and written either as attributes or with
    ``cfg["name"]`` / ``cfg.get("name")`` / ``"name" in cfg``.
    Equality and hashing are based on ``group`` only.
    """

    group: Optional[str] = None
    group_alias: Optional[str] = None
    task: Union[str, list] = field(default_factory=list)
    aggregate_metric_list: Optional[
        Union[list[AggMetricConfig], AggMetricConfig, dict]
    ] = None
    version: Optional[str] = None
    # by default, not used in the code.
    # allows for users to pass arbitrary info to tasks
    metadata: Optional[dict] = None

    def __getitem__(self, item):
        """Return the attribute named ``item`` (``cfg[item]``)."""
        return getattr(self, item)

    def __setitem__(self, item, value):
        """Set the attribute named ``item`` to ``value`` (``cfg[item] = value``)."""
        return setattr(self, item, value)

    def __contains__(self, item):
        """Support 'in' operator for dict-like behavior.

        Implemented with ``hasattr``, so any attribute name (including
        method names) is reported as contained.
        """
        return hasattr(self, item)

    def get(self, key, default=None):
        """Dict-like get method: attribute ``key`` or ``default`` if absent."""
        return getattr(self, key, default)

    def __hash__(self):
        """Make GroupConfig hashable based on group name."""
        return hash(self.group)

    def __eq__(self, other):
        """Equality comparison based on group name."""
        if not isinstance(other, GroupConfig):
            return False
        return self.group == other.group

    def __post_init__(self):
        """Normalise ``aggregate_metric_list`` and resolve ``version``.

        * A single dict / ``AggMetricConfig`` is wrapped in a list.
        * Plain dict entries are converted to ``AggMetricConfig``; entries
          that already are ``AggMetricConfig`` are kept untouched.
        * ``version`` resolves to the explicit value, else
          ``metadata["version"]``, else ``"1.0"``.
        """
        if self.aggregate_metric_list is not None:
            if isinstance(self.aggregate_metric_list, (AggMetricConfig, dict)):
                self.aggregate_metric_list = [self.aggregate_metric_list]

            normalised = []
            for item in self.aggregate_metric_list:
                # AggMetricConfig subclasses dict but is empty as a mapping,
                # so it must be checked first: re-building it from ``**item``
                # would silently reset every field to its default.
                if isinstance(item, AggMetricConfig):
                    normalised.append(item)
                elif isinstance(item, dict):
                    normalised.append(AggMetricConfig(**item))
                else:
                    normalised.append(item)
            self.aggregate_metric_list = normalised

        # An explicit version always wins; the previous expression
        # ``a or b if c else d`` parsed as ``(a or b) if c else d`` and
        # discarded the explicit version whenever metadata was missing.
        self.version = self.version or (self.metadata or {}).get("version", "1.0")

    def to_dict(self, keep_callable: bool = False) -> dict:
        """Dump the current config as a dictionary, in a printable format.

        Used for dumping results alongside the full configuration. Fields
        whose value is ``None`` are kept. Only top-level callable values are
        passed through ``serialize_function``; nested values are returned as
        produced by ``dataclasses.asdict``.

        :param keep_callable: if True, callables are left as-is instead of
            being replaced by their source code.
        :return: dict
            A printable dictionary version of the GroupConfig object.

        # TODO: should any default value in the GroupConfig not be printed?
        """
        cfg_dict = asdict(self)
        # Replace top-level callables by a printable representation.
        for k, v in list(cfg_dict.items()):
            if callable(v):
                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
        return cfg_dict

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serializes a given function or string.

        If 'keep_callable' is True, the original callable is returned.
        Otherwise, attempts to return the source code of the callable using
        'getsource'; if the source cannot be retrieved (``TypeError`` or
        ``OSError``), ``str(value)`` is returned instead.
        """
        if keep_callable:
            return value
        try:
            return getsource(value)
        except (TypeError, OSError):
            return str(value)

    def __repr__(self):
        """Short representation showing only the group name and alias."""
        return f"GroupConfig(group={self.group},group_alias={self.group_alias})"