schemas.py 2.54 KB
Newer Older
Baber's avatar
Baber committed
1
from dataclasses import dataclass
Baber's avatar
Baber committed
2
from typing import Optional, Union
Baber's avatar
Baber committed
3
4


Baber's avatar
Baber committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# @dataclass
# class GenerateInput:
#     """
#     Inputs for the generate function.
#     """
#
#     prompt: str
#     gen_kwargs: dict
#     multimodal_arg: Optional[dict] = None
#
#     def __iter__(self):
#         return (
#             iter((self.prompt, self.gen_kwargs))
#             if not self.multimodal_arg
#             else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
#         )
#
#     def __getitem__(self, item: int):
#         return [self.prompt, self.gen_kwargs][item]
#
#
# @dataclass
# class GenerateOutput:
#     """
#     Outputs for the generate function.
#     """
#
#     text: str
#     metadata: dict = None
#
#
# @dataclass
# class LoglikelihoodInput:
#     """
#     Inputs for the loglikelihood function.
#     """
#
#     context: str
#     continuation: Optional[str] = None
#
#
# class LoglikelihoodOutput(NamedTuple):
#     """
#     Outputs for the loglikelihood function.
#     """
#
#     loglikelihood: float
#     is_greedy: Optional[bool] = None
#     ctx_tokens: Optional[list[int]] = None
#     cont_tokens: Optional[list[int]] = None
#     metadata: Optional[dict] = None

# def __iter__(self):
#     return iter((self.loglikelihood, self.is_greedy))
59
60
61
62
63
64
65
66


@dataclass
class MetricResult:
    """
    Outputs for the metric function.

    Attributes:
        doc_id: Identifier of the document these scores belong to.
        filter_key: Optional filter key associated with these scores.
        metric_name: Optional name of the metric that produced the scores.
        metadata: Optional free-form extra information.
        scores: Either a single ``{metric_key: value}`` dict or a list of such
            dicts. ``None`` means "no scores". All accessors below accept both
            shapes and never modify this attribute.
    """

    doc_id: Union[str, int]
    filter_key: Optional[str] = None
    metric_name: Optional[str] = None
    metadata: Optional[dict] = None
    scores: Optional[Union[list[dict[str, float]], dict]] = None

    def _score_dicts(self) -> list[dict]:
        """Return ``scores`` normalized to a list of dicts, without mutating self.

        ``None`` becomes ``[]`` and a single dict is wrapped in a one-element list.
        """
        if self.scores is None:
            return []
        if isinstance(self.scores, list):
            return self.scores
        return [self.scores]

    def __iter__(self):
        """Iterate over ``(metric_key, [values...])`` pairs.

        Values are grouped by metric key across every score dict, with keys in
        first-seen order. ``self.scores`` is left unchanged.
        """
        grouped: dict[str, list] = {}
        for score_dict in self._score_dicts():
            for key, value in score_dict.items():
                grouped.setdefault(key, []).append(value)
        return iter(grouped.items())

    def get_metric_results(self, metric_key) -> list[float]:
        """Return every value recorded under ``metric_key``.

        Score dicts that do not contain ``metric_key`` are skipped. Works whether
        ``scores`` is a single dict or a list of dicts; returns ``[]`` when
        ``scores`` is ``None``.
        """
        return [
            score_dict[metric_key]
            for score_dict in self._score_dicts()
            if metric_key in score_dict
        ]

    @property
    def metric_keys(self) -> list[str]:
        """Metric keys present in ``scores``, in first-seen order.

        Keys are collected across all score dicts (matching ``__iter__``), so an
        empty or ``None`` ``scores`` yields ``[]`` instead of raising.
        """
        # dict.fromkeys de-duplicates while preserving insertion order.
        return list(
            dict.fromkeys(
                key for score_dict in self._score_dicts() for key in score_dict
            )
        )