Commit 5fdeb436 authored by Baber
Browse files

types

parent f21a0b81
...@@ -3,7 +3,9 @@ from __future__ import annotations ...@@ -3,7 +3,9 @@ from __future__ import annotations
import logging import logging
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable, Union
import datasets
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType from lm_eval.api.instance import OutputType
...@@ -18,6 +20,8 @@ if TYPE_CHECKING: ...@@ -18,6 +20,8 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__) eval_logger = logging.getLogger(__name__)
DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]
@dataclass @dataclass
class RepeatConfig: class RepeatConfig:
...@@ -44,10 +48,8 @@ class FewshotConfig: ...@@ -44,10 +48,8 @@ class FewshotConfig:
num_fewshot: Callable[[], int] num_fewshot: Callable[[], int]
split: str | None = None split: str | None = None
sampler: str | Callable = "default" sampler: str | Callable = "default"
samples: Callable[[], Iterable[dict]] | Iterable[dict] | None = None samples: Callable[[], DataSet] | DataSet | None = None
process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = ( process_docs: Callable[[DataSet], DataSet] | None = None
None
)
fewshot_indices: list[int] | None = None fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False) rnd: int = field(init=False, default=False)
...@@ -84,7 +86,7 @@ class FewshotConfig: ...@@ -84,7 +86,7 @@ class FewshotConfig:
"samples must be either a list of dicts or a callable returning a list" "samples must be either a list of dicts or a callable returning a list"
) )
def get_docs(self, dataset) -> Iterable[dict[str, Any]] | None: def get_docs(self, dataset) -> DataSet | None:
"""Get processed documents from configured source.""" """Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset) raw_docs = self._get_raw_docs(dataset)
if raw_docs is None: if raw_docs is None:
...@@ -130,34 +132,34 @@ class TaskConfig: ...@@ -130,34 +132,34 @@ class TaskConfig:
# HF dataset options. # HF dataset options.
# which dataset to use, # which dataset to use,
# and what splits for what purpose # and what splits for what purpose
custom_dataset: Callable | None = None custom_dataset: Callable[..., DataSet] | None = None
dataset_path: str | None = None dataset_path: str | None = None
dataset_name: str | None = None dataset_name: str | None = None
dataset_kwargs: dict | None = field(default_factory=dict) dataset_kwargs: dict | None = field(default_factory=dict)
training_split: str | None = None training_split: str | None = None
validation_split: str | None = None validation_split: str | None = None
test_split: str | None = None test_split: str | None = None
fewshot_split: str | None = ( fewshot_split: str | None = None
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
# formatting / prompting options. # formatting / prompting options.
# see docs/advanced_task_guide.md for more info # see docs/advanced_task_guide.md for more info
process_docs: Callable | None = None process_docs: Callable[[DataSet], DataSet] | None = None
doc_to_text: Callable | str | None = None doc_to_text: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_target: Callable | str | None = None doc_to_target: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_image: Callable | str | None = None doc_to_image: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_audio: Callable | str | None = None doc_to_audio: Callable[[dict[str, Any]], Any] | str | None = None
unsafe_code: bool = False unsafe_code: bool = False
doc_to_choice: Callable | str | dict | list | None = None doc_to_choice: Callable[[dict[str, Any]], Any] | str | dict | list | None = None
process_results: Callable | str | None = None process_results: (
Callable[[dict[str, Any], list[Any]], dict[str, Any]] | str | None
) = None
use_prompt: str | None = None use_prompt: str | None = None
description: str = "" description: str = ""
target_delimiter: str = " " target_delimiter: str = " "
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
fewshot_config: dict | None = None fewshot_config: dict[str, Any] | None = None
# runtime configuration options # runtime configuration options
num_fewshot: int | None = 0 num_fewshot: int | None = 0
generation_kwargs: dict | None = None generation_kwargs: dict[str, Any] | None = None
# scoring options # scoring options
metric_list: list | None = None metric_list: list | None = None
output_type: OutputType = "generate_until" output_type: OutputType = "generate_until"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment