from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union


if TYPE_CHECKING:
    # Imported only for static type checking; at runtime the annotations below
    # refer to MetricConfig by its string name, so no import is needed.
    from lm_eval.config.metric import MetricConfig


@dataclass
class TemplateConfig:
    """Encapsulates information about a template.

    Every field is required except ``metric_list``. ``MCQTemplateConfig`` and
    ``ClozeTemplateConfig`` below expose the same set of fields with preset
    defaults.

    NOTE(review): the prompt-rendering code is not in this module. The field
    roles described below are read off the field names, the preset defaults,
    and the sample layouts in the preset docstrings. Confirm against the
    renderer.
    """

    # Template style name ("mcq" / "cloze" in the presets below).
    template: str
    # Question text: a string or a callable applied to the document dict.
    doc_to_text: Union[str, Callable[[dict], str]]
    # Answer choices: a string, a fixed list, or a callable returning a list.
    doc_to_choice: Union[str, list, Callable[[dict], list]]
    # Gold answer as an int (presumably an index into the choices), or a
    # callable on the document returning it.
    doc_to_target: Union[int, Callable[[dict], int]]
    # Free-form description text (presets default to "").
    description: str
    # Label placed before the question text, e.g. "Question:".
    context_prefix: str
    # Separator between context_prefix and the question text.
    prefix_delimiter: str
    # Separator after the question text (and between it and listed choices).
    context_delimiter: str
    # Label cueing the answer, e.g. "Answer:".
    answer_suffix: str
    # Separator between answer_suffix and the answer continuation.
    target_delimiter: str
    # How choices are labelled in the prompt ("letters" for MCQ); the cloze
    # preset uses None, presumably meaning choices are not listed.
    choice_format: Optional[str]
    # Separator between listed choices (None when choices are not listed).
    choice_delimiter: Optional[str]
    # Separator between consecutive few-shot examples.
    fewshot_delimiter: str
    # Metric names or MetricConfig objects; default_factory gives each
    # instance its own list.
    metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field(
        default_factory=lambda: ["acc", "acc_norm"]
    )


@dataclass
class MCQTemplateConfig:
    """Preset template for multiple-choice questions with lettered options.

    Would return a sample with the following format (shown for four
    choices)::

        Question: <doc_to_text(doc)>
        A. <doc_to_choice(doc)[0]>
        B. <doc_to_choice(doc)[1]>
        C. <doc_to_choice(doc)[2]>
        D. <doc_to_choice(doc)[3]>
        Answer:

    followed by `` doc_to_choice(doc)`` for each choice.

    Only ``doc_to_text``, ``doc_to_choice`` and ``doc_to_target`` are
    required; every other field has a preset default. Scoring defaults to
    ``["acc"]``.
    """

    doc_to_text: Union[str, Callable[[dict], str]]
    doc_to_choice: Union[str, list, Callable[[dict], list]]
    doc_to_target: Union[int, Callable[[dict], int]]
    context_prefix: str = "Question:"
    prefix_delimiter: str = " "
    context_delimiter: str = "\n"
    answer_suffix: str = "Answer:"
    target_delimiter: str = "\n"
    choice_format: Optional[str] = "letters"
    choice_delimiter: Optional[str] = "\n"
    fewshot_delimiter: str = "\n\n"
    # The default holds metric names (str), so the annotation must admit
    # list[str] as well as MetricConfig objects, matching TemplateConfig.
    metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field(
        default_factory=lambda: ["acc"]
    )
    # ``template`` and ``description`` are real dataclass fields, consistent
    # with the sibling configs. They are declared last (not after
    # doc_to_target as in ClozeTemplateConfig) so that existing positional
    # constructor calls keep their argument order. ``template`` was
    # previously an unannotated class attribute; its value is unchanged.
    template: str = "mcq"
    description: str = ""


@dataclass
class ClozeTemplateConfig:
    """Preset template for cloze-style (fill-in / continuation) scoring.

    Would return a sample with the following format::

        Question: <doc_to_text(doc)>
        Answer: <answer continuation>

    The answer continuation is preceded by ``target_delimiter`` (a single
    space). It is assumed to be each candidate from ``doc_to_choice(doc)``;
    confirm with the scorer. No choice list is rendered
    (``choice_format`` and ``choice_delimiter`` are None). Scoring defaults
    to ``["acc", "acc_norm"]``.
    """

    doc_to_text: Union[str, Callable[[dict], str]]
    doc_to_choice: Union[str, list, Callable[[dict], list]]
    doc_to_target: Union[int, Callable[[dict], int]]
    template: str = "cloze"
    description: str = ""
    context_prefix: str = "Question:"
    prefix_delimiter: str = " "
    context_delimiter: str = "\n"
    answer_suffix: str = "Answer:"
    target_delimiter: str = " "
    choice_format: Optional[str] = None
    choice_delimiter: Optional[str] = None
    fewshot_delimiter: str = "\n\n"
    # The default holds metric names (str), so the annotation must admit
    # list[str] as well as MetricConfig objects, matching TemplateConfig.
    metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field(
        default_factory=lambda: ["acc", "acc_norm"]
    )