group.py 3.74 KB
Newer Older
1
from dataclasses import asdict, dataclass, field
Lintang Sutawika's avatar
Lintang Sutawika committed
2
from inspect import getsource
Baber's avatar
Baber committed
3
4
from typing import Callable, Optional, Union

Lintang Sutawika's avatar
Lintang Sutawika committed
5
6
7
8
9

@dataclass
class AggMetricConfig(dict):
    """Settings describing how one metric is aggregated across a group's subtasks.

    ``aggregation`` must be the string ``"mean"`` or a callable; any other
    value is rejected at construction time. ``filter_list`` may be given as a
    single string, in which case it is wrapped into a one-element list.
    """

    metric: Optional[str] = None
    aggregation: Optional[str] = "mean"
    weight_by_size: bool = False
    # Names of the filters whose results are folded into the aggregated metric.
    filter_list: Optional[Union[str, list]] = "none"

    def __post_init__(self):
        """Validate ``aggregation`` and normalize ``filter_list`` to a list.

        Raises:
            ValueError: if ``aggregation`` is neither ``"mean"`` nor callable.
        """
        chosen = self.aggregation
        unsupported = chosen != "mean" and not callable(chosen)
        if unsupported:
            raise ValueError(
                f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
            )

        # A bare filter name is shorthand for a single-entry list.
        filters = self.filter_list
        if isinstance(filters, str):
            self.filter_list = [filters]


@dataclass
class GroupConfig:
    """Configuration of a task group and how its subtask metrics are aggregated.

    Besides plain attribute access, instances expose a small dict-like
    surface (``cfg["group"]``, ``cfg["group"] = ...``, ``"group" in cfg``,
    ``cfg.get("group")``) implemented on top of ``getattr`` / ``setattr`` /
    ``hasattr``. Equality and hashing consider the ``group`` name only.
    """

    group: Optional[str] = None
    group_alias: Optional[str] = None
    task: Union[str, list] = field(default_factory=list)
    aggregate_metric_list: Optional[
        Union[list[AggMetricConfig], AggMetricConfig, dict]
    ] = None
    version: Optional[str] = None
    metadata: Optional[dict] = (
        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
    )

    def __getitem__(self, item):
        """Return the attribute named ``item`` (dict-style read access)."""
        return getattr(self, item)

    def __setitem__(self, item, value):
        """Set the attribute named ``item`` to ``value`` (dict-style write access)."""
        return setattr(self, item, value)

    def __contains__(self, item):
        """Support 'in' operator for dict-like behavior."""
        return hasattr(self, item)

    def get(self, key, default=None):
        """Dict-like get method."""
        return getattr(self, key, default)

    def __hash__(self):
        """Make GroupConfig hashable based on group name."""
        return hash(self.group)

    def __eq__(self, other):
        """Equality comparison based on group name."""
        if not isinstance(other, GroupConfig):
            return False
        return self.group == other.group

    def __post_init__(self):
        """Normalize ``aggregate_metric_list`` and resolve ``version``.

        * A single aggregation entry (a plain dict or an ``AggMetricConfig``)
          is wrapped into a list, and plain dict entries are converted to
          ``AggMetricConfig`` instances.
        * ``version`` keeps an explicitly supplied value; otherwise it falls
          back to ``metadata["version"]`` and finally to ``"1.0"``.
        """
        if self.aggregate_metric_list is not None:
            if isinstance(self.aggregate_metric_list, dict):
                self.aggregate_metric_list = [self.aggregate_metric_list]

            # AggMetricConfig subclasses dict but stores its fields as
            # attributes, so its mapping view is empty: re-building it via
            # ``AggMetricConfig(**item)`` would silently reset every field to
            # its default. Only plain dicts are converted.
            self.aggregate_metric_list = [
                AggMetricConfig(**item)
                if isinstance(item, dict) and not isinstance(item, AggMetricConfig)
                else item
                for item in self.aggregate_metric_list
            ]

        # An explicit version always wins. The previous expression parsed as
        # ``(version or metadata.get(...)) if metadata else "1.0"`` and thus
        # discarded an explicit version whenever metadata was empty/None.
        if not self.version:
            self.version = (
                self.metadata.get("version", "1.0") if self.metadata else "1.0"
            )

    def to_dict(self, keep_callable: bool = False) -> dict:
        """Dump the current config as a dictionary, in a printable format.

        Used for dumping results alongside the full configuration. Fields
        whose value is ``None`` are kept in the output.

        :param keep_callable: if True, top-level callable values are returned
            unchanged; otherwise they are replaced by their source code (or
            ``str(value)`` when the source is unavailable).
        :return: dict
            A dictionary version of the GroupConfig object.

        # TODO: should any default value in the GroupConfig not be printed?
        """
        cfg_dict = asdict(self)
        # Only top-level values are inspected here; a callable nested inside
        # ``aggregate_metric_list`` entries is left as-is.
        for k, v in list(cfg_dict.items()):
            if callable(v):
                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
        return cfg_dict

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serializes a given function or string.

        If 'keep_callable' is True, the original callable is returned.
        Otherwise, attempts to return the source code of the callable using 'getsource'.
        When the source cannot be retrieved (``TypeError`` / ``OSError``),
        ``str(value)`` is returned instead.
        """
        if keep_callable:
            return value
        else:
            try:
                return getsource(value)
            except (TypeError, OSError):
                return str(value)

    def __repr__(self):
        """Return a short representation showing only group and group_alias."""
        return f"GroupConfig(group={self.group},group_alias={self.group_alias})"