Commit 48eabc04 authored by Baber's avatar Baber
Browse files

add better type safety

parent 93b2ab37
...@@ -3,6 +3,7 @@ from __future__ import annotations ...@@ -3,6 +3,7 @@ from __future__ import annotations
import importlib import importlib
import inspect import inspect
import threading import threading
import warnings
from collections.abc import Iterable, Mapping, MutableMapping from collections.abc import Iterable, Mapping, MutableMapping
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
...@@ -12,6 +13,7 @@ from typing import ( ...@@ -12,6 +13,7 @@ from typing import (
Callable, Callable,
Generic, Generic,
TypeVar, TypeVar,
overload,
) )
...@@ -92,41 +94,56 @@ class Registry(Generic[T]): ...@@ -92,41 +94,56 @@ class Registry(Generic[T]):
# Registration helpers (decorator or direct call) # Registration helpers (decorator or direct call)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@overload
def register( def register(
self, self,
*aliases: str, *aliases: str,
lazy: str | md.EntryPoint | None = None, lazy: None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]: ) -> Callable[[T], T]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``. """Register as decorator: @registry.register("foo")."""
...
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
def _do_register(target: T | str | md.EntryPoint) -> None: @overload
def register(
self,
*aliases: str,
lazy: str | md.EntryPoint,
metadata: dict[str, Any] | None = None,
) -> Callable[[Any], Any]:
"""Register lazy: registry.register("foo", lazy="a.b:C")(None)."""
...
def _resolve_aliases(
self, target: T | str | md.EntryPoint, aliases: tuple[str, ...]
) -> tuple[str, ...]:
"""Resolve aliases for registration."""
if not aliases: if not aliases:
_aliases = (getattr(target, "__name__", str(target)),) return (getattr(target, "__name__", str(target)),)
else: return aliases
_aliases = aliases
def _check_and_store(
self,
alias: str,
target: T | str | md.EntryPoint,
metadata: dict[str, Any] | None,
) -> None:
"""Check constraints and store the target with optional metadata.
Collision policy:
1. If alias doesn't exist → store it
2. If identical value → silently succeed (idempotent)
3. If lazy placeholder + matching concrete class → replace with concrete
4. Otherwise → raise ValueError
Type checking:
- Eager for concrete classes at registration time
- Deferred for lazy placeholders until materialization
"""
with self._lock: with self._lock:
for alias in _aliases: # Case 1: New alias
if alias in self._objects: if alias not in self._objects:
# If it's a lazy placeholder being replaced by the concrete object, allow it # Type check concrete classes before storing
existing = self._objects[alias]
if isinstance(existing, (str, md.EntryPoint)) and isinstance(
target, type
):
# Allow replacing lazy placeholder with concrete class
pass
else:
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"({self._objects[alias]})"
)
# Eager type check only when we have a concrete class
if self._base_cls is not None and isinstance(target, type): if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type] if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError( raise TypeError(
...@@ -134,62 +151,81 @@ class Registry(Generic[T]): ...@@ -134,62 +151,81 @@ class Registry(Generic[T]):
f"to be registered as a {self._name}" f"to be registered as a {self._name}"
) )
self._objects[alias] = target self._objects[alias] = target
# Store metadata if provided
if metadata: if metadata:
self._metadata[alias] = metadata self._metadata[alias] = metadata
return
# ─── decorator path ─── existing = self._objects[alias]
def decorator(obj: T) -> T: # type: ignore[valid-type]
_do_register(obj)
return obj
# ─── direct‑call path with lazy placeholder ─── # Case 2: Identical value - idempotent
if lazy is not None: if existing == target:
_do_register(lazy) return
return lambda x: x # no‑op decorator for accidental use
# Case 3: Lazy placeholder being replaced by its concrete class
if isinstance(existing, str) and isinstance(target, type):
mod_path, _, cls_name = existing.partition(":")
if (
cls_name
and hasattr(target, "__module__")
and hasattr(target, "__name__")
):
expected_path = f"{target.__module__}:{target.__name__}"
if existing == expected_path:
self._objects[alias] = target
if metadata:
self._metadata[alias] = metadata
return
return decorator # Case 4: Collision - different values
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"(existing: {existing}, new: {target})"
)
def register_bulk( def register(
self, self,
items: dict[str, T | str | md.EntryPoint], *aliases: str,
metadata: dict[str, dict[str, Any]] | None = None, lazy: str | md.EntryPoint | None = None,
) -> None: metadata: dict[str, Any] | None = None,
"""Register multiple items at once. ) -> Callable[[T], T]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
Args: * If called as a **decorator**, supply an object and *no* ``lazy``.
items: Dictionary mapping aliases to objects/lazy paths * If called as a **plain function** and you want lazy import, leave the
metadata: Optional dictionary mapping aliases to metadata object out and pass ``lazy=``.
""" """
with self._lock: # ─── direct‑call path with lazy placeholder ───
for alias, target in items.items(): if lazy is not None:
if alias in self._objects: for alias in self._resolve_aliases(lazy, aliases):
# If it's a lazy placeholder being replaced by the concrete object, allow it self._check_and_store(alias, lazy, metadata)
existing = self._objects[alias] return lambda x: x # no‑op decorator for accidental use
if isinstance(existing, (str, md.EntryPoint)) and isinstance(
target, type
):
# Allow replacing lazy placeholder with concrete class
pass
else:
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"({self._objects[alias]})"
)
# Eager type check only when we have a concrete class # ─── decorator path ───
if self._base_cls is not None and isinstance(target, type): def decorator(obj: T) -> T: # type: ignore[valid-type]
if not issubclass(target, self._base_cls): # type: ignore[arg-type] for alias in self._resolve_aliases(obj, aliases):
raise TypeError( self._check_and_store(alias, obj, metadata)
f"{target} must inherit from {self._base_cls} " return obj
f"to be registered as a {self._name}"
)
self._objects[alias] = target return decorator
# Store metadata if provided # def register_bulk(
if metadata and alias in metadata: # self,
self._metadata[alias] = metadata[alias] # items: dict[str, T | str | md.EntryPoint],
# metadata: dict[str, dict[str, Any]] | None = None,
# ) -> None:
# """Register multiple items at once.
#
# Args:
# items: Dictionary mapping aliases to objects/lazy paths
# metadata: Optional dictionary mapping aliases to metadata
# """
# for alias, target in items.items():
# meta = metadata.get(alias, {}) if metadata else {}
# # For lazy registration, check if it's a string or EntryPoint
# if isinstance(target, (str, md.EntryPoint)):
# self.register(alias, lazy=target, metadata=meta)(None)
# else:
# self.register(alias, metadata=meta)(target)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Lookup & materialisation # Lookup & materialisation
...@@ -211,6 +247,13 @@ class Registry(Generic[T]): ...@@ -211,6 +247,13 @@ class Registry(Generic[T]):
return target # concrete already return target # concrete already
def get(self, alias: str) -> T: def get(self, alias: str) -> T:
# Fast path: check if already materialized without lock
target = self._objects.get(alias)
if target is not None and not isinstance(target, (str, md.EntryPoint)):
# Already materialized and validated, return immediately
return target
# Slow path: acquire lock for materialization
with self._lock: with self._lock:
try: try:
target = self._objects[alias] target = self._objects[alias]
...@@ -220,15 +263,23 @@ class Registry(Generic[T]): ...@@ -220,15 +263,23 @@ class Registry(Generic[T]):
f"{', '.join(self._objects)}" f"{', '.join(self._objects)}"
) from exc ) from exc
# Only materialize if it's a string or EntryPoint (lazy placeholder) # Double-check after acquiring lock (may have been materialized by another thread)
if isinstance(target, (str, md.EntryPoint)): if not isinstance(target, (str, md.EntryPoint)):
return target
# Materialize the lazy placeholder
concrete: T = self._materialise(target) concrete: T = self._materialise(target)
# First‑touch: swap placeholder with concrete obj for future calls
# Swap placeholder with concrete object (with race condition check)
if concrete is not target: if concrete is not target:
# Final check: another thread might have materialized while we were working
current = self._objects.get(alias)
if isinstance(current, (str, md.EntryPoint)):
# Still a placeholder, safe to replace
self._objects[alias] = concrete self._objects[alias] = concrete
else: else:
# Already materialized, just return it # Another thread already materialized it, use their result
concrete = target concrete = current # type: ignore[assignment]
# Late type check (for placeholders) # Late type check (for placeholders)
if self._base_cls is not None and not issubclass(concrete, self._base_cls): # type: ignore[arg-type] if self._base_cls is not None and not issubclass(concrete, self._base_cls): # type: ignore[arg-type]
...@@ -237,8 +288,8 @@ class Registry(Generic[T]): ...@@ -237,8 +288,8 @@ class Registry(Generic[T]):
f"(registered under alias '{alias}')" f"(registered under alias '{alias}')"
) )
# Custom validation # Custom validation - run on materialization
if self._validator is not None and not self._validator(concrete): if self._validator and not self._validator(concrete):
raise ValueError( raise ValueError(
f"{concrete} failed custom validation for {self._name} registry " f"{concrete} failed custom validation for {self._name} registry "
f"(registered under alias '{alias}')" f"(registered under alias '{alias}')"
...@@ -301,7 +352,7 @@ class Registry(Generic[T]): ...@@ -301,7 +352,7 @@ class Registry(Generic[T]):
raise RuntimeError("Cannot clear a frozen registry") raise RuntimeError("Cannot clear a frozen registry")
self._objects.clear() self._objects.clear()
self._metadata.clear() self._metadata.clear()
self._materialise.cache_clear() # type: ignore[attr-defined] # Added by lru_cache self._materialise.cache_clear() # type: ignore[attr-defined]
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
...@@ -327,7 +378,7 @@ class MetricSpec: ...@@ -327,7 +378,7 @@ class MetricSpec:
from lm_eval.api.model import LM # noqa: E402 from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[type[LM]] = Registry("model", base_cls=LM) model_registry: Registry[LM] = Registry("model", base_cls=LM)
task_registry: Registry[Callable[..., Any]] = Registry("task") task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric") metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], Mapping[str, float]]] = ( metric_agg_registry: Registry[Callable[[Iterable[Any]], Mapping[str, float]]] = (
...@@ -347,8 +398,31 @@ DEFAULT_METRIC_REGISTRY = { ...@@ -347,8 +398,31 @@ DEFAULT_METRIC_REGISTRY = {
"generate_until": ["exact_match"], "generate_until": ["exact_match"],
} }
# Aggregation registry (will be populated by register_aggregation)
AGGREGATION_REGISTRY: dict[str, Callable] = {} def default_metrics_for(output_type: str) -> list[str]:
"""Get default metrics for a given output type dynamically.
This walks the metric registry to find metrics that match the output type.
Falls back to DEFAULT_METRIC_REGISTRY if no dynamic matches found.
"""
# First check static defaults
if output_type in DEFAULT_METRIC_REGISTRY:
return DEFAULT_METRIC_REGISTRY[output_type]
# Walk metric registry for matching output types
matching_metrics = []
for name, metric_spec in metric_registry.items():
if (
isinstance(metric_spec, MetricSpec)
and metric_spec.output_type == output_type
):
matching_metrics.append(name)
return matching_metrics if matching_metrics else []
# Aggregation registry - alias to the canonical registry for backward compatibility
AGGREGATION_REGISTRY = metric_agg_registry # The registry itself is dict-like
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API) # Public helper aliases (legacy API)
...@@ -373,35 +447,39 @@ def register_metric(**kwargs): ...@@ -373,35 +447,39 @@ def register_metric(**kwargs):
if not metric_name: if not metric_name:
raise ValueError("metric name is required") raise ValueError("metric name is required")
# Create MetricSpec with the function and metadata # Determine aggregation function
spec = MetricSpec( aggregate_fn = None
compute=fn, if "aggregation" in kwargs:
aggregate=lambda x: {}, # Default aggregation returns empty dict agg_name = kwargs["aggregation"]
higher_is_better=kwargs.get("higher_is_better", True), try:
output_type=kwargs.get("output_type"), aggregate_fn = metric_agg_registry.get(agg_name)
requires=kwargs.get("requires"), except KeyError:
raise ValueError(f"Unknown aggregation: {agg_name}")
else:
# No aggregation specified - use a function that raises NotImplementedError
def not_implemented_agg(values):
raise NotImplementedError(
f"No aggregation function specified for metric '{metric_name}'. "
"Please specify an 'aggregation' parameter."
) )
# Register in metric registry aggregate_fn = not_implemented_agg
metric_registry._objects[metric_name] = spec
# Also handle aggregation if specified # Create MetricSpec with the function and metadata
if "aggregation" in kwargs:
agg_name = kwargs["aggregation"]
# Try to get aggregation from AGGREGATION_REGISTRY
if agg_name in AGGREGATION_REGISTRY:
spec = MetricSpec( spec = MetricSpec(
compute=fn, compute=fn,
aggregate=AGGREGATION_REGISTRY[agg_name], aggregate=aggregate_fn,
higher_is_better=kwargs.get("higher_is_better", True), higher_is_better=kwargs.get("higher_is_better", True),
output_type=kwargs.get("output_type"), output_type=kwargs.get("output_type"),
requires=kwargs.get("requires"), requires=kwargs.get("requires"),
) )
metric_registry._objects[metric_name] = spec
# Handle higher_is_better registry # Use proper registry API with metadata
metric_registry.register(metric_name, metadata=kwargs)(spec)
# Also register in higher_is_better registry if specified
if "higher_is_better" in kwargs: if "higher_is_better" in kwargs:
higher_is_better_registry._objects[metric_name] = kwargs["higher_is_better"] higher_is_better_registry.register(metric_name)(kwargs["higher_is_better"])
return fn return fn
...@@ -444,18 +522,22 @@ register_metric_aggregation = metric_agg_registry.register ...@@ -444,18 +522,22 @@ register_metric_aggregation = metric_agg_registry.register
def get_metric_aggregation(metric_name: str): def get_metric_aggregation(metric_name: str):
"""Get the aggregation function for a metric.""" """Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation) # First try to get from metric registry (for metrics registered with aggregation)
if metric_name in metric_registry._objects: try:
metric_spec = metric_registry._objects[metric_name] metric_spec = metric_registry.get(metric_name)
if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate: if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate:
return metric_spec.aggregate return metric_spec.aggregate
except KeyError:
pass # Try next registry
# Fall back to metric_agg_registry (for standalone aggregations) # Fall back to metric_agg_registry (for standalone aggregations)
if metric_name in metric_agg_registry._objects: try:
return metric_agg_registry._objects[metric_name] return metric_agg_registry.get(metric_name)
except KeyError:
pass
# If not found, raise error # If not found, raise error
raise KeyError( raise KeyError(
f"Unknown metric aggregation '{metric_name}'. Available: {list(AGGREGATION_REGISTRY.keys())}" f"Unknown metric aggregation '{metric_name}'. Available: {list(metric_agg_registry)}"
) )
...@@ -468,20 +550,30 @@ get_filter = filter_registry.get ...@@ -468,20 +550,30 @@ get_filter = filter_registry.get
# Special handling for AGGREGATION_REGISTRY which works differently # Special handling for AGGREGATION_REGISTRY which works differently
def register_aggregation(name: str): def register_aggregation(name: str):
"""@deprecated Use metric_agg_registry.register() instead."""
warnings.warn(
"register_aggregation() is deprecated. Use metric_agg_registry.register() instead.",
DeprecationWarning,
stacklevel=2,
)
def decorate(fn): def decorate(fn):
if name in AGGREGATION_REGISTRY: # Use the canonical registry as single source of truth
if name in metric_agg_registry:
raise ValueError( raise ValueError(
f"aggregation named '{name}' conflicts with existing registered aggregation!" f"aggregation named '{name}' conflicts with existing registered aggregation!"
) )
AGGREGATION_REGISTRY[name] = fn metric_agg_registry.register(name)(fn)
return fn return fn
return decorate return decorate
def get_aggregation(name: str) -> Callable[[], dict[str, Callable]]: def get_aggregation(name: str) -> Callable[[Iterable[Any]], Mapping[str, float]] | None:
"""@deprecated Use metric_agg_registry.get() instead."""
try: try:
return AGGREGATION_REGISTRY[name] # Use the canonical registry
return metric_agg_registry.get(name)
except KeyError: except KeyError:
import logging import logging
...@@ -526,15 +618,17 @@ def freeze_all() -> None: # pragma: no cover ...@@ -526,15 +618,17 @@ def freeze_all() -> None: # pragma: no cover
# Backwards‑compatibility read‑only globals # Backwards‑compatibility read‑only globals
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
MODEL_REGISTRY: Mapping[str, type[LM]] = MappingProxyType(model_registry._objects) # type: ignore[attr-defined] # These are direct aliases to the registries themselves, which already implement
TASK_REGISTRY: Mapping[str, Callable[..., Any]] = MappingProxyType( # the Mapping protocol and provide read-only access to users (since _objects is private).
task_registry._objects # This ensures they always reflect the current state of the registries, including
) # type: ignore[attr-defined] # items registered after module import.
METRIC_REGISTRY: Mapping[str, MetricSpec] = MappingProxyType(metric_registry._objects) # type: ignore[attr-defined] #
METRIC_AGGREGATION_REGISTRY: Mapping[str, Callable] = MappingProxyType( # Note: We use type: ignore because Registry doesn't formally inherit from Mapping,
metric_agg_registry._objects # but it implements all required methods (__getitem__, __iter__, __len__, items)
) # type: ignore[attr-defined]
HIGHER_IS_BETTER_REGISTRY: Mapping[str, bool] = MappingProxyType( MODEL_REGISTRY: Mapping[str, LM] = model_registry # type: ignore[assignment]
higher_is_better_registry._objects TASK_REGISTRY: Mapping[str, Callable[..., Any]] = task_registry # type: ignore[assignment]
) # type: ignore[attr-defined] METRIC_REGISTRY: Mapping[str, MetricSpec] = metric_registry # type: ignore[assignment]
FILTER_REGISTRY: Mapping[str, Callable] = MappingProxyType(filter_registry._objects) # type: ignore[attr-defined] METRIC_AGGREGATION_REGISTRY: Mapping[str, Callable] = metric_agg_registry # type: ignore[assignment]
HIGHER_IS_BETTER_REGISTRY: Mapping[str, bool] = higher_is_better_registry # type: ignore[assignment]
FILTER_REGISTRY: Mapping[str, Callable] = filter_registry # type: ignore[assignment]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment