Commit 24b7e2d6 authored by Baber's avatar Baber
Browse files

type hints

parent 9f345f33
from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -36,13 +38,14 @@ def register_model(*names): ...@@ -36,13 +38,14 @@ def register_model(*names):
return decorate return decorate
def get_model(model_name: str) -> type["LM"]: def get_model(model_name: str) -> type[LM]:
try: try:
return MODEL_REGISTRY[model_name] return MODEL_REGISTRY[model_name]
except KeyError: except KeyError as err:
raise ValueError( available_models = ", ".join(MODEL_REGISTRY.keys())
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}" raise KeyError(
) f"Model '{model_name}' not found. Available models: {available_models}"
) from err
TASK_REGISTRY = {} TASK_REGISTRY = {}
...@@ -81,7 +84,7 @@ def register_group(name): ...@@ -81,7 +84,7 @@ def register_group(name):
OUTPUT_TYPE_REGISTRY = {} OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {} METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {} METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {} AGGREGATION_REGISTRY: dict[str, Callable[[], dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {} HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {} FILTER_REGISTRY = {}
...@@ -125,7 +128,7 @@ def register_metric(**args): ...@@ -125,7 +128,7 @@ def register_metric(**args):
return decorate return decorate
def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]: def get_metric(name: str, hf_evaluate_metric=False) -> Callable[..., Any] | None:
if not hf_evaluate_metric: if not hf_evaluate_metric:
if name in METRIC_REGISTRY: if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name] return METRIC_REGISTRY[name]
...@@ -157,21 +160,21 @@ def register_aggregation(name: str): ...@@ -157,21 +160,21 @@ def register_aggregation(name: str):
return decorate return decorate
def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]: def get_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None:
try: try:
return AGGREGATION_REGISTRY[name] return AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!") eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]: def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None:
try: try:
return METRIC_AGGREGATION_REGISTRY[name] return METRIC_AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!") eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name: str) -> Optional[bool]: def is_higher_better(metric_name: str) -> bool | None:
try: try:
return HIGHER_IS_BETTER_REGISTRY[metric_name] return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError: except KeyError:
...@@ -192,7 +195,7 @@ def register_filter(name: str): ...@@ -192,7 +195,7 @@ def register_filter(name: str):
return decorate return decorate
def get_filter(filter_name: Union[str, Callable]) -> Callable: def get_filter(filter_name: str | Callable) -> Callable:
try: try:
return FILTER_REGISTRY[filter_name] return FILTER_REGISTRY[filter_name]
except KeyError as e: except KeyError as e:
......
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from typing import Any, Callable, List, Optional from typing import Any
@dataclass @dataclass
...@@ -8,9 +11,9 @@ class MetricConfig: ...@@ -8,9 +11,9 @@ class MetricConfig:
"""Encapsulates information about a single metric.""" """Encapsulates information about a single metric."""
name: str name: str
fn: Optional[Callable] = None fn: Callable | None = None
kwargs: Optional[dict] = None kwargs: dict | None = None
aggregation_fn: Optional[Callable] = None aggregation_fn: Callable | None = None
higher_is_better: bool = True higher_is_better: bool = True
hf_evaluate: bool = False hf_evaluate: bool = False
is_elementwise: bool = True is_elementwise: bool = True
...@@ -41,7 +44,7 @@ class MetricConfig: ...@@ -41,7 +44,7 @@ class MetricConfig:
raise ValueError(f"Metric function for {self.name} is not defined.") raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs}) return self.fn(*args, **{**self.kwargs, **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any: def compute_aggregation(self, values: list[Any]) -> Any:
"""Computes the aggregation of the metric values.""" """Computes the aggregation of the metric values."""
if self.aggregation_fn is None: if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.") raise ValueError(f"Aggregation function for {self.name} is not defined.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment