"doc/src/vscode:/vscode.git/clone" did not exist on "3c95b34d96670893218bfa6b84488eb8a3e8b169"
Commit 24b7e2d6 authored by Baber's avatar Baber
Browse files

type hints

parent 9f345f33
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
......@@ -36,13 +38,14 @@ def register_model(*names):
return decorate
def get_model(model_name: str) -> type["LM"]:
def get_model(model_name: str) -> type[LM]:
try:
return MODEL_REGISTRY[model_name]
except KeyError:
raise ValueError(
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
)
except KeyError as err:
available_models = ", ".join(MODEL_REGISTRY.keys())
raise KeyError(
f"Model '{model_name}' not found. Available models: {available_models}"
) from err
TASK_REGISTRY = {}
......@@ -81,7 +84,7 @@ def register_group(name):
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
AGGREGATION_REGISTRY: dict[str, Callable[[], dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {}
......@@ -125,7 +128,7 @@ def register_metric(**args):
return decorate
def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
def get_metric(name: str, hf_evaluate_metric=False) -> Callable[..., Any] | None:
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
......@@ -157,21 +160,21 @@ def register_aggregation(name: str):
return decorate
def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
def get_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None:
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name: str) -> Optional[bool]:
def is_higher_better(metric_name: str) -> bool | None:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
......@@ -192,7 +195,7 @@ def register_filter(name: str):
return decorate
def get_filter(filter_name: Union[str, Callable]) -> Callable:
def get_filter(filter_name: str | Callable) -> Callable:
try:
return FILTER_REGISTRY[filter_name]
except KeyError as e:
......
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Callable, List, Optional
from typing import Any
@dataclass
......@@ -8,9 +11,9 @@ class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
aggregation_fn: Optional[Callable] = None
fn: Callable | None = None
kwargs: dict | None = None
aggregation_fn: Callable | None = None
higher_is_better: bool = True
hf_evaluate: bool = False
is_elementwise: bool = True
......@@ -41,7 +44,7 @@ class MetricConfig:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any:
def compute_aggregation(self, values: list[Any]) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment