instance.py 1.44 KB
Newer Older
1
from dataclasses import dataclass, field
2
3
4
5
6
7
from typing import Literal, Optional, Tuple


# Closed set of string tags accepted for `Instance.request_type`.
OutputType = Literal[
    "loglikelihood",
    "loglikelihood_rolling",
    "generate_until",
    "multiple_choice",
]
8

lintangsutawika's avatar
lintangsutawika committed
9

10
11
@dataclass
class Instance:
    """A single model request plus the responses collected for it.

    Fields are declared in constructor order. ``metadata`` packs
    ``(task_name, doc_id, repeats)``; ``__post_init__`` unpacks it into the
    attributes of the same names so they can be read directly.
    """

    request_type: OutputType
    doc: dict
    arguments: tuple
    idx: int
    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
        default_factory=lambda: (None, None, None),
        metadata={
            "description": "Metadata tuple containing task name, document ID, and number of repeats."
        },
    )
    resps: list = field(
        default_factory=list,
        metadata={
            "description": "List of responses from the model for this instance."
        },
    )
    filtered_resps: dict = field(
        default_factory=dict,
        metadata={
            "description": "Dict of filtered responses for this instance, keyed by filter name."
        },
    )

    # Normally populated from `metadata` in __post_init__. They remain
    # constructor parameters for backward compatibility; an explicitly passed
    # value is kept only when the matching `metadata` slot is None.
    task_name: Optional[str] = None
    doc_id: Optional[int] = None
    repeats: Optional[int] = None

    def __post_init__(self) -> None:
        """Unpack ``metadata`` into ``task_name``, ``doc_id`` and ``repeats``.

        A non-None metadata slot always overrides the corresponding attribute.
        A None slot leaves the attribute as passed to the constructor, so an
        explicit ``task_name=...`` is no longer silently reset to None when
        ``metadata`` is left at its default.

        Raises:
            ValueError: if ``metadata`` does not unpack into exactly three items.
        """
        meta_task_name, meta_doc_id, meta_repeats = self.metadata
        # Compare against None rather than truthiness: 0 is a valid doc_id.
        if meta_task_name is not None:
            self.task_name = meta_task_name
        if meta_doc_id is not None:
            self.doc_id = meta_doc_id
        if meta_repeats is not None:
            self.repeats = meta_repeats

    @property
    def args(self) -> tuple:
        """Return ``arguments`` as a tuple.

        A value that is already a tuple is returned unchanged; anything else
        is wrapped in a one-element tuple.
        """
        if isinstance(self.arguments, tuple):
            return self.arguments
        return (self.arguments,)