instance.py 885 Bytes
Newer Older
1
from dataclasses import dataclass, field
haileyschoelkopf's avatar
haileyschoelkopf committed
2
from typing import Literal
3
4
5

@dataclass
class Instance:
    """One request instance together with its responses and bookkeeping.

    ``metadata`` is unpacked in ``__post_init__`` into ``task_name``,
    ``doc_id`` and ``repeats`` (in that order) when it is provided.
    """

    # Kind of request; the Literal annotation documents the allowed values.
    # This was previously (mis)used as the *default value*, which silently
    # stored a typing object when the field was omitted; it is now required.
    request_type: Literal["loglikelihood", "loglikelihood_rolling", "greedy_until"]
    doc: dict = None
    arguments: tuple = None
    id_: int = None
    # Expected to be a 3-item sequence: (task_name, doc_id, repeats).
    metadata: tuple = None  # TODO: better typehints here
    # default_factory gives every instance its own fresh container.
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)

    # initialized after init, from ``metadata``
    # NOTE(review): annotated as str, but the actual types are whatever the
    # caller places in ``metadata`` — confirm and tighten these hints.
    task_name: str = None
    doc_id: str = None
    repeats: str = None

    def __post_init__(self):
        # Unpack the metadata triple. Without metadata, keep the three fields
        # as passed (default None) instead of failing with an opaque
        # "cannot unpack non-iterable NoneType object" TypeError.
        if self.metadata is not None:
            self.task_name, self.doc_id, self.repeats = self.metadata

    @property
    def args(self):
        """Return ``arguments`` normalized to a tuple.

        A tuple is returned unchanged; any other value (including ``None``)
        is wrapped in a one-element tuple, so callers can always unpack it.
        """
        return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)