Commit 17223113 authored by Baber's avatar Baber
Browse files

type hints

parent 24b7e2d6
......@@ -160,7 +160,7 @@ def register_aggregation(name: str):
return decorate
def get_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None:
def get_aggregation(name: str) -> Callable[..., Any] | None:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
......
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import cached_property
from typing import Any
......@@ -12,7 +12,7 @@ class MetricConfig:
name: str
fn: Callable | None = None
kwargs: dict | None = None
kwargs: Mapping[str, Any] | None = None
aggregation_fn: Callable | None = None
higher_is_better: bool = True
hf_evaluate: bool = False
......@@ -23,7 +23,7 @@ class MetricConfig:
return self.name
@cached_property
def aggregation(self) -> Callable:
def aggregation(self) -> Callable[..., Any] | None:
from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None:
......@@ -31,7 +31,7 @@ class MetricConfig:
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool:
def _higher_is_better(self) -> bool | None:
from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None:
......@@ -42,7 +42,7 @@ class MetricConfig:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs})
return self.fn(*args, **{**(self.kwargs or {}), **kwargs})
def compute_aggregation(self, values: list[Any]) -> Any:
"""Computes the aggregation of the metric values."""
......
from __future__ import annotations
import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from typing import TYPE_CHECKING, Callable
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
......@@ -20,8 +23,8 @@ class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats: int = 1
metric_fn: Union[str, Callable] = "pass@N"
kwargs: Optional[dict] = field(default_factory=dict)
metric_fn: str | Callable = "pass@N"
kwargs: dict | None = field(default_factory=dict)
@dataclass
......@@ -38,11 +41,11 @@ class FewshotConfig:
# hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
num_fewshot: Callable[[], int]
split: Optional[str] = None
sampler: Union[str, Callable] = "default"
samples: Union[Callable[[], list[dict]], list[dict], None] = None
process_docs: Optional[Callable[[list[dict]], Iterable[dict]]] = None
fewshot_indices: Optional[list[int]] = None
split: str | None = None
sampler: str | Callable = "default"
samples: Callable[[], list[dict]] | list[dict] | None = None
process_docs: Callable[[list[dict]], Iterable[dict]] | None = None
fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False)
def __post_init__(self) -> None:
......@@ -65,22 +68,20 @@ class FewshotConfig:
def _get_raw_docs(
self, dataset
) -> Union[list[dict], Callable[[], Iterable[dict]], None]:
) -> list[dict] | Callable[[], Iterable[dict]] | None:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
if self.samples is not None:
if isinstance(self.samples, list):
return self.samples
elif callable(self.samples):
if isinstance(self.samples, list) or callable(self.samples):
return self.samples
else:
raise TypeError(
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> Optional[Iterable[dict]]:
def get_docs(self, dataset) -> Iterable[dict] | None:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
......@@ -100,8 +101,8 @@ class FewshotConfig:
return self.sampler
def init_sampler(
self, docs: list[dict], task: "Task", rnd=None, fewshot_indices=None
) -> "ContextSampler":
self, docs: list[dict], task: Task, rnd=None, fewshot_indices=None
) -> ContextSampler:
"""Initialize the sampler with the given documents and task."""
if rnd is None:
raise ValueError(
......@@ -120,49 +121,49 @@ class FewshotConfig:
@dataclass
class TaskConfig(dict):
# task naming/registry
task: Optional[str] = None
task_alias: Optional[str] = None
tag: Optional[Union[str, list]] = None
task: str | None = None
task_alias: str | None = None
tag: str | list | None = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Optional[Callable] = None
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = field(default_factory=dict)
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[str] = (
custom_dataset: Callable | None = None
dataset_path: str | None = None
dataset_name: str | None = None
dataset_kwargs: dict | None = field(default_factory=dict)
training_split: str | None = None
validation_split: str | None = None
test_split: str | None = None
fewshot_split: str | None = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str, None] = None
doc_to_audio: Union[Callable, str, None] = None
process_docs: Callable | None = None
doc_to_text: Callable | str | None = None
doc_to_target: Callable | str | None = None
doc_to_image: Callable | str | None = None
doc_to_audio: Callable | str | None = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
doc_to_choice: Callable | str | dict | list | None = None
process_results: Callable | str | None = None
use_prompt: str | None = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None
fewshot_config: dict | None = None
# runtime configuration options
num_fewshot: Optional[int] = 0
generation_kwargs: Optional[dict] = None
num_fewshot: int | None = 0
generation_kwargs: dict | None = None
# scoring options
metric_list: Optional[list] = None
metric_list: list | None = None
output_type: OutputType = "generate_until"
repeats: int = 1
filter_list: Optional[list[dict]] = None
filter_list: list[dict] | None = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None
metadata: Optional[dict] = field(
doc_to_decontamination_query: str | None = None
gen_prefix: str | None = None
metadata: dict | None = field(
default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks
......@@ -215,9 +216,7 @@ class TaskConfig(dict):
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
)
def _get_metric(
self, metric_list: Optional[list[dict]] = None
) -> list["MetricConfig"]:
def _get_metric(self, metric_list: list[dict] | None = None) -> list[MetricConfig]:
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
......@@ -314,7 +313,7 @@ class TaskConfig(dict):
return metrics
@property
def get_filters(self) -> list["FilterConfig"]:
def get_filters(self) -> list[FilterConfig]:
from lm_eval.filters import build_filter_ensemble
if not self.filter_list:
......@@ -354,7 +353,7 @@ class TaskConfig(dict):
return x
@classmethod
def from_yaml(cls, data: dict) -> "TaskConfig":
def from_yaml(cls, data: dict) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data)
......
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable
if TYPE_CHECKING:
......@@ -11,19 +13,19 @@ class TemplateConfig:
"""Encapsulates information about a template."""
template: str
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
description: str
context_prefix: str
prefix_delimiter: str
context_delimiter: str
answer_suffix: str
target_delimiter: str
choice_format: Optional[str]
choice_delimiter: Optional[str]
choice_format: str | None
choice_delimiter: str | None
fewshot_delimiter: str
metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field(
metric_list: list[str] | list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"]
)
......@@ -40,19 +42,19 @@ class MCQTemplateConfig:
Answer:` doc_to_choice(doc)` for each choice.
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
template = "mcq"
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = "\n"
choice_format: Optional[str] = "letters"
choice_delimiter: Optional[str] = "\n"
choice_format: str | None = "letters"
choice_delimiter: str | None = "\n"
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(default_factory=lambda: ["acc"])
metric_list: list[MetricConfig] | None = field(default_factory=lambda: ["acc"])
@dataclass
......@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
Answer:` <doc_to_target(doc)>`
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
template: str = "cloze"
description: str = ""
context_prefix: str = "Question:"
......@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = " "
choice_format: Optional[str] = None
choice_delimiter: Optional[str] = None
choice_format: str | None = None
choice_delimiter: str | None = None
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(
metric_list: list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"]
)
from __future__ import annotations
from inspect import getsource
from typing import Any, Callable, Union
from typing import Any, Callable
def serialize_callable(
value: Union[Callable[..., Any], str], keep_callable=False
) -> Union[Callable[..., Any], str]:
value: Callable[..., Any] | str, keep_callable=False
) -> Callable[..., Any] | str:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
......@@ -20,9 +22,7 @@ def serialize_callable(
return str(value)
def maybe_serialize(
val: Union[Callable, Any], keep_callable=False
) -> Union[Callable, Any]:
def maybe_serialize(val: Callable | Any, keep_callable=False) -> Callable | Any:
"""Conditionally serializes a value if it is callable."""
return (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment