instance.py 1.15 KB
Newer Older
1
from dataclasses import dataclass, field
2
3
4
5
6
7
from typing import Literal, Optional, Tuple


# Kinds of request an ``Instance`` can represent (see ``Instance.request_type``).
OutputType = Literal[
    "loglikelihood",
    "loglikelihood_rolling",
    "generate_until",
    "multiple_choice",
]
# Input modality of an ``Instance`` (see ``Instance.input_type``, default "text").
InputType = Literal[
    "text",
    "text_image",
]
@dataclass
class Instance:
    """A single model request plus the bookkeeping that travels with it.

    Fields:
        request_type: which kind of request this is (one of ``OutputType``).
        doc: the source document for this request.
        arguments: the request payload. Annotated as ``tuple``, but ``args``
            also accepts a bare non-tuple value and wraps it.
        idx: integer index supplied by the caller.
            NOTE(review): presumably the position of this request among a
            document's requests -- confirm against callers.
        input_type: input modality; ``"text"`` unless set otherwise.
        metadata: ``(task_name, doc_id, repeats)``; unpacked into the
            attributes of the same names by ``__post_init__``.
        resps: starts as an empty list (presumably filled with raw model
            responses downstream -- confirm).
        filtered_resps: starts as an empty dict (presumably filled with
            post-processed responses downstream -- confirm).
    """

    request_type: OutputType
    doc: dict
    arguments: tuple
    idx: int

    # Input type for multimodal requests; plain text by default.
    input_type: InputType = "text"

    # (task_name, doc_id, repeats). A factory is used so every instance gets
    # its own default tuple; unpacked in __post_init__.
    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
        default_factory=lambda: (None, None, None)
    )

    # Mutable containers need default_factory so instances don't share them.
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)

    # Initialized after __init__ from `metadata`; any values passed directly
    # for these fields are overwritten by __post_init__.
    task_name: Optional[str] = None
    doc_id: Optional[int] = None
    repeats: Optional[int] = None

    def __post_init__(self) -> None:
        """Unpack ``metadata`` into ``task_name``, ``doc_id`` and ``repeats``."""
        # Raises ValueError/TypeError if metadata is not a 3-element iterable.
        self.task_name, self.doc_id, self.repeats = self.metadata

    @property
    def args(self) -> tuple:
        """Return ``arguments`` normalized to a tuple.

        A tuple ``arguments`` is returned unchanged; any other value is
        wrapped in a 1-tuple, so callers can always unpack the result.
        """
        return (
            self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
        )