instance.py 1.1 KB
Newer Older
1
from dataclasses import dataclass, field
Baber's avatar
Baber committed
2
3
4
from typing import Generic, Literal, Optional, Tuple, TypeVar

from lm_eval.api.types import GenerateInput, LoglikelihoodInput
5
6
7
8
9


# Closed set of request kinds; used to annotate ``Instance.request_type``.
OutputType = Literal[
    "loglikelihood",
    "loglikelihood_rolling",
    "generate_until",
    "multiple_choice",
]
10

Baber's avatar
Baber committed
11
12
# Constrained type variable for the ``arguments`` payload of an ``Instance``:
# it resolves to either ``LoglikelihoodInput`` or ``GenerateInput``.
T = TypeVar(
    "T",
    LoglikelihoodInput,
    GenerateInput,
)

lintangsutawika's avatar
lintangsutawika committed
13

14
@dataclass
class Instance(Generic[T]):
    """A single request record, generic over the type of its arguments.

    ``T`` is the type of ``arguments`` (``LoglikelihoodInput`` or
    ``GenerateInput``, per the constraints on the module-level ``TypeVar``).

    ``metadata`` is a ``(task_name, doc_id, repeats)`` tuple. ``__post_init__``
    unpacks it into the attributes of the same names, so any value passed
    directly for ``task_name``, ``doc_id`` or ``repeats`` is overwritten by the
    corresponding ``metadata`` element (``None`` when ``metadata`` is omitted).
    """

    request_type: OutputType  # one of the ``OutputType`` literal strings
    doc: dict  # presumably the source document for this request -- confirm with callers
    arguments: T  # request payload; also readable through the ``args`` property
    idx: int  # NOTE(review): index semantics are not visible in this file -- confirm
    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
        default_factory=lambda: (None, None, None)
    )
    # Both containers start empty for every instance (fresh object per instance).
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)

    # initialized after init (populated from ``metadata`` in ``__post_init__``)
    task_name: Optional[str] = None
    doc_id: Optional[int] = None
    repeats: Optional[int] = None

    def __post_init__(self) -> None:
        """Unpack ``metadata`` into ``task_name``, ``doc_id`` and ``repeats``.

        Raises:
            ValueError: if ``metadata`` does not contain exactly three items
                (raised by the tuple unpacking).
        """
        # unpack metadata field
        self.task_name, self.doc_id, self.repeats = self.metadata

    @property
    def args(self) -> T:
        """Return ``arguments`` unchanged.

        The value has whatever type this instance was parameterized with, so it
        is not necessarily a loglikelihood input.
        """
        return self.arguments