"official/vision/tasks/__init__.py" did not exist on "2d35330680451e079d24d6bddd06f4b322ca160b"
Commit b0173d57 authored by Baber's avatar Baber
Browse files

add temlplateconfigs

parent bbf79d44
...@@ -5,7 +5,7 @@ import random ...@@ -5,7 +5,7 @@ import random
import re import re
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass, field
from functools import cached_property from functools import cached_property
from inspect import getsource from inspect import getsource
from typing import ( from typing import (
...@@ -87,6 +87,12 @@ class MetricConfig: ...@@ -87,6 +87,12 @@ class MetricConfig:
raise ValueError(f"Metric function for {self.name} is not defined.") raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs}) return self.fn(*args, **{**self.kwargs, **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(values)
@dataclass @dataclass
class RepeatConfig: class RepeatConfig:
...@@ -111,6 +117,82 @@ class FewshotConfig: ...@@ -111,6 +117,82 @@ class FewshotConfig:
sampler: str sampler: str
samples: list[dict] samples: list[dict]
process_docs: Optional[Callable] = None process_docs: Optional[Callable] = None
fewshot_indices: Optional[list[int]] = None
@dataclass
class TemplateConfig:
"""Encapsulates information about a template."""
template: str
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
description: str
context_prefix: str
prefix_delimiter: str
context_delimiter: str
answer_suffix: str
target_delimiter: str
choice_format: Optional[str]
choice_delimiter: Optional[str]
fewshot_delimiter: str
metric_list: Optional[Union[list[str], list[MetricConfig]]] = field(
default_factory=lambda: ["acc", "acc_norm"]
)
@dataclass
class MCQTemplateConfig:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
A. <doc_to_choice(doc)[0]>
B. <doc_to_choice(doc)[1]>
C. <doc_to_choice(doc)[2]>
D. <doc_to_choice(doc)[3]>
Answer:` doc_to_choice(doc)` for each choice.
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
template = "mcq"
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = "\n"
choice_format: Optional[str] = "letters"
choice_delimiter: Optional[str] = "\n"
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list[MetricConfig]] = field(default_factory=lambda: ["acc"])
@dataclass
class ClozeTemplateConfig:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
Answer:` <doc_to_target(doc)>`
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
template: str = "cloze"
description: str = ""
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = " "
choice_format: Optional[str] = None
choice_delimiter: Optional[str] = None
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list[MetricConfig]] = field(
default_factory=lambda: ["acc", "acc_norm"]
)
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment