Commit 17223113 authored by Baber
Browse files

type hints

parent 24b7e2d6
...@@ -160,7 +160,7 @@ def register_aggregation(name: str): ...@@ -160,7 +160,7 @@ def register_aggregation(name: str):
return decorate return decorate
def get_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None: def get_aggregation(name: str) -> Callable[..., Any] | None:
try: try:
return AGGREGATION_REGISTRY[name] return AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
......
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from typing import Any from typing import Any
...@@ -12,7 +12,7 @@ class MetricConfig: ...@@ -12,7 +12,7 @@ class MetricConfig:
name: str name: str
fn: Callable | None = None fn: Callable | None = None
kwargs: dict | None = None kwargs: Mapping[str, Any] | None = None
aggregation_fn: Callable | None = None aggregation_fn: Callable | None = None
higher_is_better: bool = True higher_is_better: bool = True
hf_evaluate: bool = False hf_evaluate: bool = False
...@@ -23,7 +23,7 @@ class MetricConfig: ...@@ -23,7 +23,7 @@ class MetricConfig:
return self.name return self.name
@cached_property @cached_property
def aggregation(self) -> Callable: def aggregation(self) -> Callable[..., Any] | None:
from lm_eval.api.registry import get_aggregation from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None: if self.aggregation_fn is None:
...@@ -31,7 +31,7 @@ class MetricConfig: ...@@ -31,7 +31,7 @@ class MetricConfig:
return self.aggregation_fn return self.aggregation_fn
@cached_property @cached_property
def _higher_is_better(self) -> bool: def _higher_is_better(self) -> bool | None:
from lm_eval.api.registry import is_higher_better from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None: if self.higher_is_better is None:
...@@ -42,7 +42,7 @@ class MetricConfig: ...@@ -42,7 +42,7 @@ class MetricConfig:
"""Calculates the metric using the provided function and arguments.""" """Calculates the metric using the provided function and arguments."""
if self.fn is None: if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.") raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs}) return self.fn(*args, **{**(self.kwargs or {}), **kwargs})
def compute_aggregation(self, values: list[Any]) -> Any: def compute_aggregation(self, values: list[Any]) -> Any:
"""Computes the aggregation of the metric values.""" """Computes the aggregation of the metric values."""
......
from __future__ import annotations
import logging import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union from typing import TYPE_CHECKING, Callable
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType from lm_eval.api.instance import OutputType
...@@ -20,8 +23,8 @@ class RepeatConfig: ...@@ -20,8 +23,8 @@ class RepeatConfig:
"""Encapsulates information about a single repeat.""" """Encapsulates information about a single repeat."""
repeats: int = 1 repeats: int = 1
metric_fn: Union[str, Callable] = "pass@N" metric_fn: str | Callable = "pass@N"
kwargs: Optional[dict] = field(default_factory=dict) kwargs: dict | None = field(default_factory=dict)
@dataclass @dataclass
...@@ -38,11 +41,11 @@ class FewshotConfig: ...@@ -38,11 +41,11 @@ class FewshotConfig:
# hack: this returns task.config.num_fewshot # hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified # to keep in sync as it is runtime-modified
num_fewshot: Callable[[], int] num_fewshot: Callable[[], int]
split: Optional[str] = None split: str | None = None
sampler: Union[str, Callable] = "default" sampler: str | Callable = "default"
samples: Union[Callable[[], list[dict]], list[dict], None] = None samples: Callable[[], list[dict]] | list[dict] | None = None
process_docs: Optional[Callable[[list[dict]], Iterable[dict]]] = None process_docs: Callable[[list[dict]], Iterable[dict]] | None = None
fewshot_indices: Optional[list[int]] = None fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False) rnd: int = field(init=False, default=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
...@@ -65,22 +68,20 @@ class FewshotConfig: ...@@ -65,22 +68,20 @@ class FewshotConfig:
def _get_raw_docs( def _get_raw_docs(
self, dataset self, dataset
) -> Union[list[dict], Callable[[], Iterable[dict]], None]: ) -> list[dict] | Callable[[], Iterable[dict]] | None:
"""Get raw documents from configured source.""" """Get raw documents from configured source."""
if self.split is not None: if self.split is not None:
return dataset[self.split] return dataset[self.split]
if self.samples is not None: if self.samples is not None:
if isinstance(self.samples, list): if isinstance(self.samples, list) or callable(self.samples):
return self.samples
elif callable(self.samples):
return self.samples return self.samples
else: else:
raise TypeError( raise TypeError(
"samples must be either a list of dicts or a callable returning a list" "samples must be either a list of dicts or a callable returning a list"
) )
def get_docs(self, dataset) -> Optional[Iterable[dict]]: def get_docs(self, dataset) -> Iterable[dict] | None:
"""Get processed documents from configured source.""" """Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset) raw_docs = self._get_raw_docs(dataset)
if raw_docs is None: if raw_docs is None:
...@@ -100,8 +101,8 @@ class FewshotConfig: ...@@ -100,8 +101,8 @@ class FewshotConfig:
return self.sampler return self.sampler
def init_sampler( def init_sampler(
self, docs: list[dict], task: "Task", rnd=None, fewshot_indices=None self, docs: list[dict], task: Task, rnd=None, fewshot_indices=None
) -> "ContextSampler": ) -> ContextSampler:
"""Initialize the sampler with the given documents and task.""" """Initialize the sampler with the given documents and task."""
if rnd is None: if rnd is None:
raise ValueError( raise ValueError(
...@@ -120,49 +121,49 @@ class FewshotConfig: ...@@ -120,49 +121,49 @@ class FewshotConfig:
@dataclass @dataclass
class TaskConfig(dict): class TaskConfig(dict):
# task naming/registry # task naming/registry
task: Optional[str] = None task: str | None = None
task_alias: Optional[str] = None task_alias: str | None = None
tag: Optional[Union[str, list]] = None tag: str | list | None = None
# HF dataset options. # HF dataset options.
# which dataset to use, # which dataset to use,
# and what splits for what purpose # and what splits for what purpose
custom_dataset: Optional[Callable] = None custom_dataset: Callable | None = None
dataset_path: Optional[str] = None dataset_path: str | None = None
dataset_name: Optional[str] = None dataset_name: str | None = None
dataset_kwargs: Optional[dict] = field(default_factory=dict) dataset_kwargs: dict | None = field(default_factory=dict)
training_split: Optional[str] = None training_split: str | None = None
validation_split: Optional[str] = None validation_split: str | None = None
test_split: Optional[str] = None test_split: str | None = None
fewshot_split: Optional[str] = ( fewshot_split: str | None = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?) None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
) )
# formatting / prompting options. # formatting / prompting options.
# see docs/advanced_task_guide.md for more info # see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None process_docs: Callable | None = None
doc_to_text: Optional[Union[Callable, str]] = None doc_to_text: Callable | str | None = None
doc_to_target: Optional[Union[Callable, str]] = None doc_to_target: Callable | str | None = None
doc_to_image: Union[Callable, str, None] = None doc_to_image: Callable | str | None = None
doc_to_audio: Union[Callable, str, None] = None doc_to_audio: Callable | str | None = None
unsafe_code: bool = False unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None doc_to_choice: Callable | str | dict | list | None = None
process_results: Optional[Union[Callable, str]] = None process_results: Callable | str | None = None
use_prompt: Optional[str] = None use_prompt: str | None = None
description: str = "" description: str = ""
target_delimiter: str = " " target_delimiter: str = " "
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None fewshot_config: dict | None = None
# runtime configuration options # runtime configuration options
num_fewshot: Optional[int] = 0 num_fewshot: int | None = 0
generation_kwargs: Optional[dict] = None generation_kwargs: dict | None = None
# scoring options # scoring options
metric_list: Optional[list] = None metric_list: list | None = None
output_type: OutputType = "generate_until" output_type: OutputType = "generate_until"
repeats: int = 1 repeats: int = 1
filter_list: Optional[list[dict]] = None filter_list: list[dict] | None = None
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None doc_to_decontamination_query: str | None = None
gen_prefix: Optional[str] = None gen_prefix: str | None = None
metadata: Optional[dict] = field( metadata: dict | None = field(
default_factory=dict default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks ) # by default, not used in the code. allows for users to pass arbitrary info to tasks
...@@ -215,9 +216,7 @@ class TaskConfig(dict): ...@@ -215,9 +216,7 @@ class TaskConfig(dict):
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None), fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
) )
def _get_metric( def _get_metric(self, metric_list: list[dict] | None = None) -> list[MetricConfig]:
self, metric_list: Optional[list[dict]] = None
) -> list["MetricConfig"]:
from lm_eval.api.registry import ( from lm_eval.api.registry import (
AGGREGATION_REGISTRY, AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY, DEFAULT_METRIC_REGISTRY,
...@@ -314,7 +313,7 @@ class TaskConfig(dict): ...@@ -314,7 +313,7 @@ class TaskConfig(dict):
return metrics return metrics
@property @property
def get_filters(self) -> list["FilterConfig"]: def get_filters(self) -> list[FilterConfig]:
from lm_eval.filters import build_filter_ensemble from lm_eval.filters import build_filter_ensemble
if not self.filter_list: if not self.filter_list:
...@@ -354,7 +353,7 @@ class TaskConfig(dict): ...@@ -354,7 +353,7 @@ class TaskConfig(dict):
return x return x
@classmethod @classmethod
def from_yaml(cls, data: dict) -> "TaskConfig": def from_yaml(cls, data: dict) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary.""" """Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data) return cls(**data)
......
from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union from typing import TYPE_CHECKING, Callable
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -11,19 +13,19 @@ class TemplateConfig: ...@@ -11,19 +13,19 @@ class TemplateConfig:
"""Encapsulates information about a template.""" """Encapsulates information about a template."""
template: str template: str
doc_to_text: Union[str, Callable[[dict], str]] doc_to_text: str | Callable[[dict], str]
doc_to_choice: Union[str, list, Callable[[dict], list]] doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: Union[int, Callable[[dict], int]] doc_to_target: int | Callable[[dict], int]
description: str description: str
context_prefix: str context_prefix: str
prefix_delimiter: str prefix_delimiter: str
context_delimiter: str context_delimiter: str
answer_suffix: str answer_suffix: str
target_delimiter: str target_delimiter: str
choice_format: Optional[str] choice_format: str | None
choice_delimiter: Optional[str] choice_delimiter: str | None
fewshot_delimiter: str fewshot_delimiter: str
metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field( metric_list: list[str] | list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"] default_factory=lambda: ["acc", "acc_norm"]
) )
...@@ -40,19 +42,19 @@ class MCQTemplateConfig: ...@@ -40,19 +42,19 @@ class MCQTemplateConfig:
Answer:` doc_to_choice(doc)` for each choice. Answer:` doc_to_choice(doc)` for each choice.
""" """
doc_to_text: Union[str, Callable[[dict], str]] doc_to_text: str | Callable[[dict], str]
doc_to_choice: Union[str, list, Callable[[dict], list]] doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: Union[int, Callable[[dict], int]] doc_to_target: int | Callable[[dict], int]
template = "mcq" template = "mcq"
context_prefix: str = "Question:" context_prefix: str = "Question:"
prefix_delimiter: str = " " prefix_delimiter: str = " "
context_delimiter: str = "\n" context_delimiter: str = "\n"
answer_suffix: str = "Answer:" answer_suffix: str = "Answer:"
target_delimiter: str = "\n" target_delimiter: str = "\n"
choice_format: Optional[str] = "letters" choice_format: str | None = "letters"
choice_delimiter: Optional[str] = "\n" choice_delimiter: str | None = "\n"
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(default_factory=lambda: ["acc"]) metric_list: list[MetricConfig] | None = field(default_factory=lambda: ["acc"])
@dataclass @dataclass
...@@ -63,9 +65,9 @@ class ClozeTemplateConfig: ...@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
Answer:` <doc_to_target(doc)>` Answer:` <doc_to_target(doc)>`
""" """
doc_to_text: Union[str, Callable[[dict], str]] doc_to_text: str | Callable[[dict], str]
doc_to_choice: Union[str, list, Callable[[dict], list]] doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: Union[int, Callable[[dict], int]] doc_to_target: int | Callable[[dict], int]
template: str = "cloze" template: str = "cloze"
description: str = "" description: str = ""
context_prefix: str = "Question:" context_prefix: str = "Question:"
...@@ -73,9 +75,9 @@ class ClozeTemplateConfig: ...@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
context_delimiter: str = "\n" context_delimiter: str = "\n"
answer_suffix: str = "Answer:" answer_suffix: str = "Answer:"
target_delimiter: str = " " target_delimiter: str = " "
choice_format: Optional[str] = None choice_format: str | None = None
choice_delimiter: Optional[str] = None choice_delimiter: str | None = None
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field( metric_list: list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"] default_factory=lambda: ["acc", "acc_norm"]
) )
from __future__ import annotations
from inspect import getsource from inspect import getsource
from typing import Any, Callable, Union from typing import Any, Callable
def serialize_callable( def serialize_callable(
value: Union[Callable[..., Any], str], keep_callable=False value: Callable[..., Any] | str, keep_callable=False
) -> Union[Callable[..., Any], str]: ) -> Callable[..., Any] | str:
"""Serializes a given function or string. """Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned. If 'keep_callable' is True, the original callable is returned.
...@@ -20,9 +22,7 @@ def serialize_callable( ...@@ -20,9 +22,7 @@ def serialize_callable(
return str(value) return str(value)
def maybe_serialize( def maybe_serialize(val: Callable | Any, keep_callable=False) -> Callable | Any:
val: Union[Callable, Any], keep_callable=False
) -> Union[Callable, Any]:
"""Conditionally serializes a value if it is callable.""" """Conditionally serializes a value if it is callable."""
return ( return (
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment