group.py 3.71 KB
Newer Older
Lintang Sutawika's avatar
Lintang Sutawika committed
1
2
from dataclasses import asdict, dataclass
from inspect import getsource
3
from typing import Callable, List, Optional, Union
Lintang Sutawika's avatar
Lintang Sutawika committed
4
5
6
7
8
9
10
11
12
13
14


@dataclass
class AggMetricConfig(dict):
    """How one metric is aggregated across a group's subtasks.

    Values are stored as dataclass fields (instance attributes), not as dict
    entries: the generated ``__init__`` never populates the ``dict`` base, so
    the instance itself is an empty dict. Read values via attributes.
    """

    # name of the metric to aggregate
    metric: Optional[str] = None
    # "mean" (the only pre-defined option) or a user-supplied callable;
    # anything else is rejected in __post_init__
    aggregation: Optional[Union[str, Callable]] = "mean"
    # presumably: weight subtask scores by subtask size when aggregating
    # (consumed elsewhere; TODO confirm against the aggregation code)
    weight_by_size: bool = False
    # list of filter names which should be incorporated into the aggregated metric;
    # a single string is normalized to a one-element list in __post_init__
    filter_list: Optional[Union[str, list]] = "none"

    def __post_init__(self):
        """Validate ``aggregation`` and normalize ``filter_list`` to a list.

        :raises ValueError: if ``aggregation`` is neither ``"mean"`` nor callable.
        """
        if self.aggregation != "mean" and not callable(self.aggregation):
            raise ValueError(
                f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
            )

        if isinstance(self.filter_list, str):
            self.filter_list = [self.filter_list]


@dataclass
class GroupConfig:
    """Configuration for a named group of tasks.

    Besides attribute access, instances support a small dict-like API
    (``cfg["group"]``, ``cfg["group"] = ...``, ``"group" in cfg``,
    ``cfg.get("group")``). Hashing and equality consider only ``group``.
    """

    group: Optional[str] = None
    group_alias: Optional[str] = None
    task: Optional[Union[str, list]] = None
    # a dict, a single AggMetricConfig, or a list of either; normalized to a
    # list of AggMetricConfig in __post_init__
    aggregate_metric_list: Optional[
        Union[List[AggMetricConfig], AggMetricConfig, dict]
    ] = None
    # arbitrary user-supplied info; within this class only the "version" key
    # is read (see the `version` property)
    metadata: Optional[dict] = None

    def __getitem__(self, item):
        """Dict-style read: return attribute ``item`` (``AttributeError`` if absent)."""
        return getattr(self, item)

    def __setitem__(self, item, value):
        """Dict-style write: set attribute ``item`` to ``value``."""
        setattr(self, item, value)

    def __contains__(self, item):
        """Support 'in' operator for dict-like behavior.

        Implemented with ``hasattr``, so it is also True for method and
        property names (e.g. ``"to_dict" in cfg``), not only for fields.
        """
        return hasattr(self, item)

    def get(self, key, default=None):
        """Dict-like get: attribute ``key``, or ``default`` if it does not exist."""
        return getattr(self, key, default)

    def __hash__(self):
        """Make GroupConfig hashable based on group name."""
        # Defined explicitly so @dataclass keeps it; with eq=True and no
        # explicit __hash__, dataclass would set __hash__ to None.
        return hash(self.group)

    def __eq__(self, other):
        """Equality comparison based on group name."""
        if not isinstance(other, GroupConfig):
            # let Python try the reflected comparison, then fall back to identity
            return NotImplemented
        return self.group == other.group

    def __post_init__(self):
        """Normalize ``aggregate_metric_list`` into a list of AggMetricConfig.

        Plain dicts are converted via ``AggMetricConfig(**d)``; existing
        AggMetricConfig instances are kept as-is; other items pass through.
        """
        if self.aggregate_metric_list is None:
            return

        # AggMetricConfig subclasses dict, so a single instance is wrapped too.
        if isinstance(self.aggregate_metric_list, dict):
            self.aggregate_metric_list = [self.aggregate_metric_list]

        normalized = []
        for item in self.aggregate_metric_list:
            if isinstance(item, AggMetricConfig):
                # Must precede the dict branch: an AggMetricConfig is an
                # (empty) dict whose values live in attributes, so
                # AggMetricConfig(**item) would silently reset it to defaults.
                normalized.append(item)
            elif isinstance(item, dict):
                normalized.append(AggMetricConfig(**item))
            else:
                normalized.append(item)
        self.aggregate_metric_list = normalized

    def to_dict(self, keep_callable: bool = False) -> dict:
        """Dump the current config as a dictionary, in a printable format.

        Used for dumping results alongside full task configuration. Fields
        whose value is ``None`` are kept. Every callable in the result,
        including ones nested inside lists/dicts (such as a custom
        ``aggregation`` inside ``aggregate_metric_list``), is passed through
        ``serialize_function``.

        :param keep_callable: if True, callables are returned unchanged.
        :return: dict
            A printable dictionary version of the GroupConfig object.
        """
        cfg_dict = asdict(self)
        return {
            k: self._serialize_nested(v, keep_callable=keep_callable)
            for k, v in cfg_dict.items()
        }

    def _serialize_nested(self, value, keep_callable: bool = False):
        """Apply ``serialize_function`` to callables inside dicts/lists/tuples."""
        if callable(value):
            return self.serialize_function(value, keep_callable=keep_callable)
        if isinstance(value, dict):
            return {
                k: self._serialize_nested(v, keep_callable=keep_callable)
                for k, v in value.items()
            }
        if isinstance(value, list):
            return [
                self._serialize_nested(v, keep_callable=keep_callable) for v in value
            ]
        if type(value) is tuple:
            return tuple(
                self._serialize_nested(v, keep_callable=keep_callable) for v in value
            )
        return value

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serializes a given function or string.

        If 'keep_callable' is True, the original callable is returned.
        Otherwise, attempts to return the source code of the callable using
        'getsource', falling back to ``str(value)`` when no source is
        available (e.g. builtins or interactively defined objects).
        """
        if keep_callable:
            return value
        try:
            return getsource(value)
        except (TypeError, OSError):
            return str(value)

    @property
    def version(self) -> str:
        """Return ``metadata["version"]``, or ``"1.0"`` when it is not set.

        ``metadata`` defaults to ``None``; that case also yields ``"1.0"``
        instead of raising ``AttributeError``.
        """
        return (self.metadata or {}).get("version", "1.0")

    def __repr__(self):
        return f"GroupConfig(group={self.group},group_alias={self.group_alias})"