group.py 3.73 KB
Newer Older
Lintang Sutawika's avatar
Lintang Sutawika committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, Callable, List, Optional, Union


@dataclass
class AggMetricConfig(dict):
    """Settings describing how one metric is aggregated across a group's subtasks.

    Validation and normalization happen in ``__post_init__``:
    ``aggregation`` must be the string ``"mean"`` or a callable, and a
    string ``filter_list`` is wrapped into a one-element list.
    """

    # presumably the name of the metric to aggregate; not interpreted in this class
    metric: Optional[str] = None
    # either the pre-defined "mean" or a user-supplied callable (see __post_init__)
    aggregation: Optional[Union[str, Callable]] = "mean"
    # boolean flag, off by default; how it is consumed is not visible in this class
    weight_by_size: Optional[bool] = False
    # list of filter names which should be incorporated into the aggregated metric.
    filter_list: Optional[Union[str, list]] = "none"

    def __post_init__(self):
        """Validate ``aggregation`` and normalize ``filter_list`` to a list.

        :raises ValueError: if ``aggregation`` is neither ``"mean"`` nor callable.
        """
        if self.aggregation != "mean" and not callable(self.aggregation):
            raise ValueError(
                f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
            )

        # accept a bare filter name as shorthand for a single-entry list
        if isinstance(self.filter_list, str):
            self.filter_list = [self.filter_list]


@dataclass
class GroupConfig(dict):
    """Configuration for a group of tasks.

    Fields can be read and written either as attributes or with
    subscript syntax (``cfg["group"]``); both go through ``getattr`` /
    ``setattr``. ``__post_init__`` normalizes ``aggregate_metric_list``
    and resolves ``version``.
    """

    group: Optional[str] = None
    group_alias: Optional[str] = None
    task: Optional[Union[str, list]] = None
    aggregate_metric_list: Optional[
        Union[List["AggMetricConfig"], "AggMetricConfig", dict]
    ] = None
    version: Optional[str] = None
    metadata: Optional[dict] = (
        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
    )

    def __getitem__(self, item):
        """Return the attribute named ``item`` (subscript-style read access)."""
        return getattr(self, item)

    def __setitem__(self, item, value):
        """Set the attribute named ``item`` to ``value`` (subscript-style write access)."""
        return setattr(self, item, value)

    def __post_init__(self):
        """Normalize ``aggregate_metric_list`` and resolve ``version``.

        ``aggregate_metric_list`` ends up as a list in which every plain
        dict has been converted to an ``AggMetricConfig``. ``version`` is
        taken, in order of preference, from the explicit ``version``
        field, from ``metadata["version"]``, or defaults to ``"1.0"``.
        """
        if self.aggregate_metric_list is not None:
            # a single mapping (plain dict or AggMetricConfig) is shorthand for a one-item list
            if isinstance(self.aggregate_metric_list, dict):
                self.aggregate_metric_list = [self.aggregate_metric_list]

            # AggMetricConfig subclasses dict but keeps its fields as attributes,
            # so `**item` on an existing instance would be empty and silently
            # reset it to defaults. Only plain dicts are converted.
            self.aggregate_metric_list = [
                AggMetricConfig(**item)
                if isinstance(item, dict) and not isinstance(item, AggMetricConfig)
                else item
                for item in self.aggregate_metric_list
            ]

        # An explicit version must win even when no metadata is supplied.
        if not self.version:
            self.version = (self.metadata or {}).get("version", "1.0")

    def to_dict(self, keep_callable: bool = False) -> dict:
        """Return this config as a plain dictionary.

        The dictionary is built with ``dataclasses.asdict``; fields whose
        value is ``None`` are kept. Top-level values that are callable are
        passed through ``serialize_function``. Values nested inside lists
        or dicts are not inspected, so nested callables are left as-is.

        :param keep_callable: forwarded to ``serialize_function``; when True,
            callables are kept instead of being converted to strings.
        :return: dict
            A dictionary version of the GroupConfig object.
        """
        cfg_dict = asdict(self)
        # snapshot the items so entries can be reassigned while looping
        for k, v in list(cfg_dict.items()):
            if callable(v):
                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
        return cfg_dict

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serializes a given function or string.

        If 'keep_callable' is True, the original callable is returned.
        Otherwise, attempts to return the source code of the callable using 'getsource'.
        If the source cannot be retrieved (``TypeError`` or ``OSError``),
        ``str(value)`` is returned instead.
        """
        if keep_callable:
            return value
        else:
            try:
                return getsource(value)
            except (TypeError, OSError):
                return str(value)


Baber's avatar
Baber committed
92
class ConfigurableGroup:
    """Wrapper that builds a ``GroupConfig`` from a dict and exposes its fields read-only."""

    def __init__(
        self,
        config: Optional[dict] = None,
    ) -> None:
        """Build the underlying ``GroupConfig`` from ``config``.

        :param config: keyword arguments for ``GroupConfig``; ``None`` (the
            default) is treated as an empty mapping.
        """
        # `**None` raises TypeError, so the declared default must be mapped to {}
        self._config = GroupConfig(**(config or {}))

    @property
    def group(self):
        """The ``group`` field of the underlying config."""
        return self._config.group

    @property
    def group_alias(self):
        """The ``group_alias`` field of the underlying config."""
        return self._config.group_alias

    @property
    def version(self):
        """The ``version`` field of the underlying config."""
        return self._config.version

    @property
    def config(self):
        """The underlying config converted with its ``to_dict()`` method."""
        return self._config.to_dict()

    @property
    def group_name(self) -> Any:
        """Same value as ``group``."""
        return self._config.group

    def __repr__(self):
        return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"