instance.py 1.43 KB
Newer Older
1
from dataclasses import dataclass, field
2
from typing import Literal, Optional, Tuple
3

lintangsutawika's avatar
lintangsutawika committed
4

5
6
@dataclass
class Instance:
    """A single model request together with the responses collected for it.

    ``metadata`` is unpacked in ``__post_init__`` into ``task_name``,
    ``doc_id`` and ``repeats``; any values passed directly for those three
    fields are overwritten by that unpacking.
    """

    request_type: Literal["loglikelihood", "loglikelihood_rolling", "generate_until"]
    doc: dict
    arguments: tuple
    idx: int
    # (task_name, doc_id, repeats); every element defaults to None.
    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
        default_factory=lambda: (None, None, None)
    )
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)

    # initialized after init, by unpacking ``metadata``
    task_name: Optional[str] = None
    doc_id: Optional[int] = None
    repeats: Optional[int] = None

    def __post_init__(self) -> None:
        # unpack metadata field
        self.task_name, self.doc_id, self.repeats = self.metadata

    @property
    def args(self) -> tuple:
        """Return ``arguments`` as a tuple.

        A non-tuple ``arguments`` value is wrapped in a one-element tuple.
        """
        return (
            self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
        )

    @args.setter
    def args(self, new_arguments: tuple) -> None:
        """Replace ``arguments`` with ``new_arguments``.

        Raises:
            ValueError: if ``new_arguments`` is not a tuple, or if its length
                differs from the length of the current ``args``.
        """
        if not isinstance(new_arguments, tuple):
            raise ValueError("Must set new Instance args to a tuple!")
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        # Only the length is compared; element types are not checked.
        if len(new_arguments) != len(self.args):
            raise ValueError(
                "Must set new Instance arguments to have same size + types as old arguments"
            )
        self.arguments = new_arguments