Commit 48eabc04 authored by Baber
Browse files

add better type safety

parent 93b2ab37
...@@ -3,6 +3,7 @@ from __future__ import annotations ...@@ -3,6 +3,7 @@ from __future__ import annotations
import importlib import importlib
import inspect import inspect
import threading import threading
import warnings
from collections.abc import Iterable, Mapping, MutableMapping from collections.abc import Iterable, Mapping, MutableMapping
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
...@@ -12,6 +13,7 @@ from typing import ( ...@@ -12,6 +13,7 @@ from typing import (
Callable, Callable,
Generic, Generic,
TypeVar, TypeVar,
overload,
) )
...@@ -92,104 +94,138 @@ class Registry(Generic[T]): ...@@ -92,104 +94,138 @@ class Registry(Generic[T]):
# Registration helpers (decorator or direct call) # Registration helpers (decorator or direct call)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@overload
def register( def register(
self, self,
*aliases: str, *aliases: str,
lazy: str | md.EntryPoint | None = None, lazy: None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]: ) -> Callable[[T], T]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``. """Register as decorator: @registry.register("foo")."""
...
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
def _do_register(target: T | str | md.EntryPoint) -> None:
if not aliases:
_aliases = (getattr(target, "__name__", str(target)),)
else:
_aliases = aliases
with self._lock:
for alias in _aliases:
if alias in self._objects:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing = self._objects[alias]
if isinstance(existing, (str, md.EntryPoint)) and isinstance(
target, type
):
# Allow replacing lazy placeholder with concrete class
pass
else:
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"({self._objects[alias]})"
)
# Eager type check only when we have a concrete class
if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} must inherit from {self._base_cls} "
f"to be registered as a {self._name}"
)
self._objects[alias] = target
# Store metadata if provided
if metadata:
self._metadata[alias] = metadata
# ─── decorator path ─── @overload
def decorator(obj: T) -> T: # type: ignore[valid-type] def register(
_do_register(obj)
return obj
# ─── direct‑call path with lazy placeholder ───
if lazy is not None:
_do_register(lazy)
return lambda x: x # no‑op decorator for accidental use
return decorator
def register_bulk(
self, self,
items: dict[str, T | str | md.EntryPoint], *aliases: str,
metadata: dict[str, dict[str, Any]] | None = None, lazy: str | md.EntryPoint,
metadata: dict[str, Any] | None = None,
) -> Callable[[Any], Any]:
"""Register lazy: registry.register("foo", lazy="a.b:C")(None)."""
...
def _resolve_aliases(
self, target: T | str | md.EntryPoint, aliases: tuple[str, ...]
) -> tuple[str, ...]:
"""Resolve aliases for registration."""
if not aliases:
return (getattr(target, "__name__", str(target)),)
return aliases
def _check_and_store(
self,
alias: str,
target: T | str | md.EntryPoint,
metadata: dict[str, Any] | None,
) -> None: ) -> None:
"""Register multiple items at once. """Check constraints and store the target with optional metadata.
Args: Collision policy:
items: Dictionary mapping aliases to objects/lazy paths 1. If alias doesn't exist → store it
metadata: Optional dictionary mapping aliases to metadata 2. If identical value → silently succeed (idempotent)
3. If lazy placeholder + matching concrete class → replace with concrete
4. Otherwise → raise ValueError
Type checking:
- Eager for concrete classes at registration time
- Deferred for lazy placeholders until materialization
""" """
with self._lock: with self._lock:
for alias, target in items.items(): # Case 1: New alias
if alias in self._objects: if alias not in self._objects:
# If it's a lazy placeholder being replaced by the concrete object, allow it # Type check concrete classes before storing
existing = self._objects[alias]
if isinstance(existing, (str, md.EntryPoint)) and isinstance(
target, type
):
# Allow replacing lazy placeholder with concrete class
pass
else:
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"({self._objects[alias]})"
)
# Eager type check only when we have a concrete class
if self._base_cls is not None and isinstance(target, type): if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type] if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError( raise TypeError(
f"{target} must inherit from {self._base_cls} " f"{target} must inherit from {self._base_cls} "
f"to be registered as a {self._name}" f"to be registered as a {self._name}"
) )
self._objects[alias] = target self._objects[alias] = target
if metadata:
self._metadata[alias] = metadata
return
existing = self._objects[alias]
# Case 2: Identical value - idempotent
if existing == target:
return
# Case 3: Lazy placeholder being replaced by its concrete class
if isinstance(existing, str) and isinstance(target, type):
mod_path, _, cls_name = existing.partition(":")
if (
cls_name
and hasattr(target, "__module__")
and hasattr(target, "__name__")
):
expected_path = f"{target.__module__}:{target.__name__}"
if existing == expected_path:
self._objects[alias] = target
if metadata:
self._metadata[alias] = metadata
return
# Case 4: Collision - different values
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"(existing: {existing}, new: {target})"
)
def register(
    self,
    *aliases: str,
    lazy: str | md.EntryPoint | None = None,
    metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]:
    """Register an object eagerly (as a decorator) or lazily (via ``lazy=``).

    Usage: ``@registry.register("foo")`` or
    ``registry.register("foo", lazy="a.b:C")``.

    * Without ``lazy``, nothing is stored yet; the returned decorator stores
      the object it is applied to under every resolved alias and returns
      that object unchanged.
    * With ``lazy``, the placeholder is stored immediately under every
      resolved alias and an identity function is returned.

    Aliases come from ``self._resolve_aliases``; each one is stored through
    ``self._check_and_store`` together with ``metadata``.
    """

    def _store_under_all_aliases(target: T | str | md.EntryPoint) -> None:
        for name in self._resolve_aliases(target, aliases):
            self._check_and_store(name, target, metadata)

    if lazy is None:
        # Decorator path: registration is deferred until an object is supplied.
        def decorator(obj: T) -> T:  # type: ignore[valid-type]
            _store_under_all_aliases(obj)
            return obj

        return decorator

    # Direct-call path: store the lazy placeholder right away.
    _store_under_all_aliases(lazy)
    return lambda x: x  # no‑op decorator for accidental use
# def register_bulk(
# self,
# items: dict[str, T | str | md.EntryPoint],
# metadata: dict[str, dict[str, Any]] | None = None,
# ) -> None:
# """Register multiple items at once.
#
# Args:
# items: Dictionary mapping aliases to objects/lazy paths
# metadata: Optional dictionary mapping aliases to metadata
# """
# for alias, target in items.items():
# meta = metadata.get(alias, {}) if metadata else {}
# # For lazy registration, check if it's a string or EntryPoint
# if isinstance(target, (str, md.EntryPoint)):
# self.register(alias, lazy=target, metadata=meta)(None)
# else:
# self.register(alias, metadata=meta)(target)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Lookup & materialisation # Lookup & materialisation
...@@ -211,6 +247,13 @@ class Registry(Generic[T]): ...@@ -211,6 +247,13 @@ class Registry(Generic[T]):
return target # concrete already return target # concrete already
def get(self, alias: str) -> T: def get(self, alias: str) -> T:
# Fast path: check if already materialized without lock
target = self._objects.get(alias)
if target is not None and not isinstance(target, (str, md.EntryPoint)):
# Already materialized and validated, return immediately
return target
# Slow path: acquire lock for materialization
with self._lock: with self._lock:
try: try:
target = self._objects[alias] target = self._objects[alias]
...@@ -220,15 +263,23 @@ class Registry(Generic[T]): ...@@ -220,15 +263,23 @@ class Registry(Generic[T]):
f"{', '.join(self._objects)}" f"{', '.join(self._objects)}"
) from exc ) from exc
# Only materialize if it's a string or EntryPoint (lazy placeholder) # Double-check after acquiring lock (may have been materialized by another thread)
if isinstance(target, (str, md.EntryPoint)): if not isinstance(target, (str, md.EntryPoint)):
concrete: T = self._materialise(target) return target
# First‑touch: swap placeholder with concrete obj for future calls
if concrete is not target: # Materialize the lazy placeholder
concrete: T = self._materialise(target)
# Swap placeholder with concrete object (with race condition check)
if concrete is not target:
# Final check: another thread might have materialized while we were working
current = self._objects.get(alias)
if isinstance(current, (str, md.EntryPoint)):
# Still a placeholder, safe to replace
self._objects[alias] = concrete self._objects[alias] = concrete
else: else:
# Already materialized, just return it # Another thread already materialized it, use their result
concrete = target concrete = current # type: ignore[assignment]
# Late type check (for placeholders) # Late type check (for placeholders)
if self._base_cls is not None and not issubclass(concrete, self._base_cls): # type: ignore[arg-type] if self._base_cls is not None and not issubclass(concrete, self._base_cls): # type: ignore[arg-type]
...@@ -237,8 +288,8 @@ class Registry(Generic[T]): ...@@ -237,8 +288,8 @@ class Registry(Generic[T]):
f"(registered under alias '{alias}')" f"(registered under alias '{alias}')"
) )
# Custom validation # Custom validation - run on materialization
if self._validator is not None and not self._validator(concrete): if self._validator and not self._validator(concrete):
raise ValueError( raise ValueError(
f"{concrete} failed custom validation for {self._name} registry " f"{concrete} failed custom validation for {self._name} registry "
f"(registered under alias '{alias}')" f"(registered under alias '{alias}')"
...@@ -301,7 +352,7 @@ class Registry(Generic[T]): ...@@ -301,7 +352,7 @@ class Registry(Generic[T]):
raise RuntimeError("Cannot clear a frozen registry") raise RuntimeError("Cannot clear a frozen registry")
self._objects.clear() self._objects.clear()
self._metadata.clear() self._metadata.clear()
self._materialise.cache_clear() # type: ignore[attr-defined] # Added by lru_cache self._materialise.cache_clear() # type: ignore[attr-defined]
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
...@@ -327,7 +378,7 @@ class MetricSpec: ...@@ -327,7 +378,7 @@ class MetricSpec:
from lm_eval.api.model import LM # noqa: E402 from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[type[LM]] = Registry("model", base_cls=LM) model_registry: Registry[LM] = Registry("model", base_cls=LM)
task_registry: Registry[Callable[..., Any]] = Registry("task") task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric") metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], Mapping[str, float]]] = ( metric_agg_registry: Registry[Callable[[Iterable[Any]], Mapping[str, float]]] = (
...@@ -347,8 +398,31 @@ DEFAULT_METRIC_REGISTRY = { ...@@ -347,8 +398,31 @@ DEFAULT_METRIC_REGISTRY = {
"generate_until": ["exact_match"], "generate_until": ["exact_match"],
} }
# Aggregation registry (will be populated by register_aggregation)
def default_metrics_for(output_type: str) -> list[str]:
    """Return the default metric names for ``output_type``.

    Static defaults win: when ``output_type`` is a key of
    ``DEFAULT_METRIC_REGISTRY`` its entry is returned and the metric registry
    is not consulted. Otherwise ``metric_registry`` is scanned and the names
    of all ``MetricSpec`` entries whose ``output_type`` matches are returned.

    Returns:
        A new list on every call (callers may mutate it without affecting
        ``DEFAULT_METRIC_REGISTRY``); empty when nothing matches.
    """
    if output_type in DEFAULT_METRIC_REGISTRY:
        # Copy so the module-level defaults cannot be modified through the
        # returned list.
        return list(DEFAULT_METRIC_REGISTRY[output_type])

    # No static entry: collect registered metrics declaring this output type.
    # Entries that are not MetricSpec instances are skipped.
    return [
        name
        for name, metric_spec in metric_registry.items()
        if isinstance(metric_spec, MetricSpec)
        and metric_spec.output_type == output_type
    ]
# Aggregation registry - alias to the canonical registry for backward compatibility
AGGREGATION_REGISTRY = metric_agg_registry # The registry itself is dict-like
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API) # Public helper aliases (legacy API)
...@@ -373,35 +447,39 @@ def register_metric(**kwargs): ...@@ -373,35 +447,39 @@ def register_metric(**kwargs):
if not metric_name: if not metric_name:
raise ValueError("metric name is required") raise ValueError("metric name is required")
# Determine aggregation function
aggregate_fn = None
if "aggregation" in kwargs:
agg_name = kwargs["aggregation"]
try:
aggregate_fn = metric_agg_registry.get(agg_name)
except KeyError:
raise ValueError(f"Unknown aggregation: {agg_name}")
else:
# No aggregation specified - use a function that raises NotImplementedError
def not_implemented_agg(values):
raise NotImplementedError(
f"No aggregation function specified for metric '{metric_name}'. "
"Please specify an 'aggregation' parameter."
)
aggregate_fn = not_implemented_agg
# Create MetricSpec with the function and metadata # Create MetricSpec with the function and metadata
spec = MetricSpec( spec = MetricSpec(
compute=fn, compute=fn,
aggregate=lambda x: {}, # Default aggregation returns empty dict aggregate=aggregate_fn,
higher_is_better=kwargs.get("higher_is_better", True), higher_is_better=kwargs.get("higher_is_better", True),
output_type=kwargs.get("output_type"), output_type=kwargs.get("output_type"),
requires=kwargs.get("requires"), requires=kwargs.get("requires"),
) )
# Register in metric registry # Use proper registry API with metadata
metric_registry._objects[metric_name] = spec metric_registry.register(metric_name, metadata=kwargs)(spec)
# Also handle aggregation if specified # Also register in higher_is_better registry if specified
if "aggregation" in kwargs:
agg_name = kwargs["aggregation"]
# Try to get aggregation from AGGREGATION_REGISTRY
if agg_name in AGGREGATION_REGISTRY:
spec = MetricSpec(
compute=fn,
aggregate=AGGREGATION_REGISTRY[agg_name],
higher_is_better=kwargs.get("higher_is_better", True),
output_type=kwargs.get("output_type"),
requires=kwargs.get("requires"),
)
metric_registry._objects[metric_name] = spec
# Handle higher_is_better registry
if "higher_is_better" in kwargs: if "higher_is_better" in kwargs:
higher_is_better_registry._objects[metric_name] = kwargs["higher_is_better"] higher_is_better_registry.register(metric_name)(kwargs["higher_is_better"])
return fn return fn
...@@ -444,18 +522,22 @@ register_metric_aggregation = metric_agg_registry.register ...@@ -444,18 +522,22 @@ register_metric_aggregation = metric_agg_registry.register
def get_metric_aggregation(metric_name: str): def get_metric_aggregation(metric_name: str):
"""Get the aggregation function for a metric.""" """Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation) # First try to get from metric registry (for metrics registered with aggregation)
if metric_name in metric_registry._objects: try:
metric_spec = metric_registry._objects[metric_name] metric_spec = metric_registry.get(metric_name)
if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate: if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate:
return metric_spec.aggregate return metric_spec.aggregate
except KeyError:
pass # Try next registry
# Fall back to metric_agg_registry (for standalone aggregations) # Fall back to metric_agg_registry (for standalone aggregations)
if metric_name in metric_agg_registry._objects: try:
return metric_agg_registry._objects[metric_name] return metric_agg_registry.get(metric_name)
except KeyError:
pass
# If not found, raise error # If not found, raise error
raise KeyError( raise KeyError(
f"Unknown metric aggregation '{metric_name}'. Available: {list(AGGREGATION_REGISTRY.keys())}" f"Unknown metric aggregation '{metric_name}'. Available: {list(metric_agg_registry)}"
) )
...@@ -468,20 +550,30 @@ get_filter = filter_registry.get ...@@ -468,20 +550,30 @@ get_filter = filter_registry.get
# Special handling for AGGREGATION_REGISTRY which works differently # Special handling for AGGREGATION_REGISTRY which works differently
def register_aggregation(name: str): def register_aggregation(name: str):
"""@deprecated Use metric_agg_registry.register() instead."""
warnings.warn(
"register_aggregation() is deprecated. Use metric_agg_registry.register() instead.",
DeprecationWarning,
stacklevel=2,
)
def decorate(fn): def decorate(fn):
if name in AGGREGATION_REGISTRY: # Use the canonical registry as single source of truth
if name in metric_agg_registry:
raise ValueError( raise ValueError(
f"aggregation named '{name}' conflicts with existing registered aggregation!" f"aggregation named '{name}' conflicts with existing registered aggregation!"
) )
AGGREGATION_REGISTRY[name] = fn metric_agg_registry.register(name)(fn)
return fn return fn
return decorate return decorate
def get_aggregation(name: str) -> Callable[[], dict[str, Callable]]: def get_aggregation(name: str) -> Callable[[Iterable[Any]], Mapping[str, float]] | None:
"""@deprecated Use metric_agg_registry.get() instead."""
try: try:
return AGGREGATION_REGISTRY[name] # Use the canonical registry
return metric_agg_registry.get(name)
except KeyError: except KeyError:
import logging import logging
...@@ -526,15 +618,17 @@ def freeze_all() -> None: # pragma: no cover ...@@ -526,15 +618,17 @@ def freeze_all() -> None: # pragma: no cover
# Backwards‑compatibility read‑only globals # Backwards‑compatibility read‑only globals
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
MODEL_REGISTRY: Mapping[str, type[LM]] = MappingProxyType(model_registry._objects) # type: ignore[attr-defined] # These are direct aliases to the registries themselves, which already implement
TASK_REGISTRY: Mapping[str, Callable[..., Any]] = MappingProxyType( # the Mapping protocol and provide read-only access to users (since _objects is private).
task_registry._objects # This ensures they always reflect the current state of the registries, including
) # type: ignore[attr-defined] # items registered after module import.
METRIC_REGISTRY: Mapping[str, MetricSpec] = MappingProxyType(metric_registry._objects) # type: ignore[attr-defined] #
METRIC_AGGREGATION_REGISTRY: Mapping[str, Callable] = MappingProxyType( # Note: We use type: ignore because Registry doesn't formally inherit from Mapping,
metric_agg_registry._objects # but it implements all required methods (__getitem__, __iter__, __len__, items)
) # type: ignore[attr-defined]
HIGHER_IS_BETTER_REGISTRY: Mapping[str, bool] = MappingProxyType( MODEL_REGISTRY: Mapping[str, LM] = model_registry # type: ignore[assignment]
higher_is_better_registry._objects TASK_REGISTRY: Mapping[str, Callable[..., Any]] = task_registry # type: ignore[assignment]
) # type: ignore[attr-defined] METRIC_REGISTRY: Mapping[str, MetricSpec] = metric_registry # type: ignore[assignment]
FILTER_REGISTRY: Mapping[str, Callable] = MappingProxyType(filter_registry._objects) # type: ignore[attr-defined] METRIC_AGGREGATION_REGISTRY: Mapping[str, Callable] = metric_agg_registry # type: ignore[assignment]
HIGHER_IS_BETTER_REGISTRY: Mapping[str, bool] = higher_is_better_registry # type: ignore[assignment]
FILTER_REGISTRY: Mapping[str, Callable] = filter_registry # type: ignore[assignment]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment