Commit 5fdeb436 authored by Baber's avatar Baber
Browse files

types

parent f21a0b81
......@@ -3,7 +3,9 @@ from __future__ import annotations
import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Union
import datasets
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
......@@ -18,6 +20,8 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__)
DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]
@dataclass
class RepeatConfig:
......@@ -44,10 +48,8 @@ class FewshotConfig:
num_fewshot: Callable[[], int]
split: str | None = None
sampler: str | Callable = "default"
samples: Callable[[], Iterable[dict]] | Iterable[dict] | None = None
process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = (
None
)
samples: Callable[[], DataSet] | DataSet | None = None
process_docs: Callable[[DataSet], DataSet] | None = None
fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False)
......@@ -84,7 +86,7 @@ class FewshotConfig:
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> Iterable[dict[str, Any]] | None:
def get_docs(self, dataset) -> DataSet | None:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
......@@ -130,34 +132,34 @@ class TaskConfig:
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Callable | None = None
custom_dataset: Callable[..., DataSet] | None = None
dataset_path: str | None = None
dataset_name: str | None = None
dataset_kwargs: dict | None = field(default_factory=dict)
training_split: str | None = None
validation_split: str | None = None
test_split: str | None = None
fewshot_split: str | None = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
fewshot_split: str | None = None
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Callable | None = None
doc_to_text: Callable | str | None = None
doc_to_target: Callable | str | None = None
doc_to_image: Callable | str | None = None
doc_to_audio: Callable | str | None = None
process_docs: Callable[[DataSet], DataSet] | None = None
doc_to_text: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_target: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_image: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_audio: Callable[[dict[str, Any]], Any] | str | None = None
unsafe_code: bool = False
doc_to_choice: Callable | str | dict | list | None = None
process_results: Callable | str | None = None
doc_to_choice: Callable[[dict[str, Any]], Any] | str | dict | list | None = None
process_results: (
Callable[[dict[str, Any], list[Any]], dict[str, Any]] | str | None
) = None
use_prompt: str | None = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: dict | None = None
fewshot_config: dict[str, Any] | None = None
# runtime configuration options
num_fewshot: int | None = 0
generation_kwargs: dict | None = None
generation_kwargs: dict[str, Any] | None = None
# scoring options
metric_list: list | None = None
output_type: OutputType = "generate_until"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment