from dataclasses import dataclass, field
from typing import Generic, Literal, Optional, Tuple, TypeVar

from lm_eval.api.types import GenerateInput, LoglikelihoodInput


# The closed set of request kinds accepted by `Instance.request_type`.
OutputType = Literal[
    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]

# Type of the `arguments` payload; constrained to the two input types above.
T = TypeVar("T", LoglikelihoodInput, GenerateInput)


@dataclass
class Instance(Generic[T]):
    """One request record, generic over the type of its ``arguments`` payload.

    Required fields are ``request_type``, ``doc``, ``arguments`` and ``idx``.
    ``metadata`` is a 3-tuple ``(task_name, doc_id, repeats)`` that is unpacked
    into the attributes of the same names once the dataclass-generated
    ``__init__`` has finished.

    Note:
        ``task_name``, ``doc_id`` and ``repeats`` are always overwritten from
        ``metadata`` in ``__post_init__``; values passed for them directly to
        the constructor are therefore discarded.
    """

    request_type: OutputType
    doc: dict
    arguments: T
    idx: int
    # (task_name, doc_id, repeats); every slot defaults to None.
    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
        default_factory=lambda: (None, None, None)
    )
    # Both containers start empty and get a fresh object per instance.
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)

    # Populated from `metadata` by __post_init__.
    task_name: Optional[str] = None
    doc_id: Optional[int] = None
    repeats: Optional[int] = None

    def __post_init__(self) -> None:
        """Spread the ``metadata`` tuple over ``task_name``, ``doc_id``, ``repeats``."""
        name, document_id, repeat_count = self.metadata
        self.task_name = name
        self.doc_id = document_id
        self.repeats = repeat_count

    @property
    def args(self) -> T:
        """Return the ``arguments`` payload unchanged (read-only alias)."""
        return self.arguments