Commit 15d07121 authored by Baber's avatar Baber
Browse files

add temlplateconfigs

parent 3ba4e897
......@@ -5,7 +5,7 @@ import random
import re
from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from functools import cached_property
from inspect import getsource
from typing import (
......@@ -87,6 +87,12 @@ class MetricConfig:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(values)
@dataclass
class RepeatConfig:
......@@ -111,6 +117,70 @@ class FewshotConfig:
sampler: str
samples: list[dict]
process_docs: Optional[Callable] = None
fewshot_indices: Optional[list[int]] = None
@dataclass
class TemplateConfig:
"""Encapsulates information about a template."""
template: str
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
description: str
context_prefix: str
prefix_delimiter: str
context_delimiter: str
answer_suffix: str
target_delimiter: str
choice_format: Optional[str]
choice_delimiter: Optional[str]
fewshot_delimiter: str
metric_list: Optional[Union[list[str], list[MetricConfig]]] = field(
default_factory=lambda: ["acc", "acc_norm"]
)
@dataclass
class MCQTemplateConfig:
"""Encapsulates information about a template."""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
template = "mcq"
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = "\n"
choice_format: Optional[str] = "letters"
choice_delimiter: Optional[str] = "\n"
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list[MetricConfig]] = field(default_factory=lambda: ["acc"])
@dataclass
class ClozeTemplateConfig:
"""Encapsulates information about a template."""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
template: str = "cloze"
description: str = ""
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = " "
choice_format: Optional[str] = None
choice_delimiter: Optional[str] = None
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list[MetricConfig]] = field(
default_factory=lambda: ["acc", "acc_norm"]
)
@dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment