Commit 48eabc04 authored by Baber's avatar Baber
Browse files

add better type safety

parent 93b2ab37
......@@ -3,6 +3,7 @@ from __future__ import annotations
import importlib
import inspect
import threading
import warnings
from collections.abc import Iterable, Mapping, MutableMapping
from dataclasses import dataclass
from functools import lru_cache
......@@ -12,6 +13,7 @@ from typing import (
Callable,
Generic,
TypeVar,
overload,
)
......@@ -92,104 +94,138 @@ class Registry(Generic[T]):
# Registration helpers (decorator or direct call)
# ------------------------------------------------------------------
@overload
def register(
self,
*aliases: str,
lazy: str | md.EntryPoint | None = None,
lazy: None = None,
metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
def _do_register(target: T | str | md.EntryPoint) -> None:
if not aliases:
_aliases = (getattr(target, "__name__", str(target)),)
else:
_aliases = aliases
with self._lock:
for alias in _aliases:
if alias in self._objects:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing = self._objects[alias]
if isinstance(existing, (str, md.EntryPoint)) and isinstance(
target, type
):
# Allow replacing lazy placeholder with concrete class
pass
else:
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"({self._objects[alias]})"
)
# Eager type check only when we have a concrete class
if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} must inherit from {self._base_cls} "
f"to be registered as a {self._name}"
)
self._objects[alias] = target
# Store metadata if provided
if metadata:
self._metadata[alias] = metadata
"""Register as decorator: @registry.register("foo")."""
...
# ─── decorator path ───
def decorator(obj: T) -> T: # type: ignore[valid-type]
_do_register(obj)
return obj
# ─── direct‑call path with lazy placeholder ───
if lazy is not None:
_do_register(lazy)
return lambda x: x # no‑op decorator for accidental use
return decorator
def register_bulk(
@overload
def register(
self,
items: dict[str, T | str | md.EntryPoint],
metadata: dict[str, dict[str, Any]] | None = None,
*aliases: str,
lazy: str | md.EntryPoint,
metadata: dict[str, Any] | None = None,
) -> Callable[[Any], Any]:
"""Register lazy: registry.register("foo", lazy="a.b:C")(None)."""
...
def _resolve_aliases(
self, target: T | str | md.EntryPoint, aliases: tuple[str, ...]
) -> tuple[str, ...]:
"""Resolve aliases for registration."""
if not aliases:
return (getattr(target, "__name__", str(target)),)
return aliases
def _check_and_store(
self,
alias: str,
target: T | str | md.EntryPoint,
metadata: dict[str, Any] | None,
) -> None:
"""Register multiple items at once.
"""Check constraints and store the target with optional metadata.
Args:
items: Dictionary mapping aliases to objects/lazy paths
metadata: Optional dictionary mapping aliases to metadata
Collision policy:
1. If alias doesn't exist → store it
2. If identical value → silently succeed (idempotent)
3. If lazy placeholder + matching concrete class → replace with concrete
4. Otherwise → raise ValueError
Type checking:
- Eager for concrete classes at registration time
- Deferred for lazy placeholders until materialization
"""
with self._lock:
for alias, target in items.items():
if alias in self._objects:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing = self._objects[alias]
if isinstance(existing, (str, md.EntryPoint)) and isinstance(
target, type
):
# Allow replacing lazy placeholder with concrete class
pass
else:
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"({self._objects[alias]})"
)
# Eager type check only when we have a concrete class
# Case 1: New alias
if alias not in self._objects:
# Type check concrete classes before storing
if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} must inherit from {self._base_cls} "
f"to be registered as a {self._name}"
)
self._objects[alias] = target
if metadata:
self._metadata[alias] = metadata
return
existing = self._objects[alias]
# Case 2: Identical value - idempotent
if existing == target:
return
# Case 3: Lazy placeholder being replaced by its concrete class
if isinstance(existing, str) and isinstance(target, type):
mod_path, _, cls_name = existing.partition(":")
if (
cls_name
and hasattr(target, "__module__")
and hasattr(target, "__name__")
):
expected_path = f"{target.__module__}:{target.__name__}"
if existing == expected_path:
self._objects[alias] = target
if metadata:
self._metadata[alias] = metadata
return
# Case 4: Collision - different values
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"(existing: {existing}, new: {target})"
)
def register(
self,
*aliases: str,
lazy: str | md.EntryPoint | None = None,
metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
# ─── direct‑call path with lazy placeholder ───
if lazy is not None:
for alias in self._resolve_aliases(lazy, aliases):
self._check_and_store(alias, lazy, metadata)
return lambda x: x # no‑op decorator for accidental use
# Store metadata if provided
if metadata and alias in metadata:
self._metadata[alias] = metadata[alias]
# ─── decorator path ───
def decorator(obj: T) -> T: # type: ignore[valid-type]
for alias in self._resolve_aliases(obj, aliases):
self._check_and_store(alias, obj, metadata)
return obj
return decorator
# def register_bulk(
# self,
# items: dict[str, T | str | md.EntryPoint],
# metadata: dict[str, dict[str, Any]] | None = None,
# ) -> None:
# """Register multiple items at once.
#
# Args:
# items: Dictionary mapping aliases to objects/lazy paths
# metadata: Optional dictionary mapping aliases to metadata
# """
# for alias, target in items.items():
# meta = metadata.get(alias, {}) if metadata else {}
# # For lazy registration, check if it's a string or EntryPoint
# if isinstance(target, (str, md.EntryPoint)):
# self.register(alias, lazy=target, metadata=meta)(None)
# else:
# self.register(alias, metadata=meta)(target)
# ------------------------------------------------------------------
# Lookup & materialisation
......@@ -211,6 +247,13 @@ class Registry(Generic[T]):
return target # concrete already
def get(self, alias: str) -> T:
# Fast path: check if already materialized without lock
target = self._objects.get(alias)
if target is not None and not isinstance(target, (str, md.EntryPoint)):
# Already materialized and validated, return immediately
return target
# Slow path: acquire lock for materialization
with self._lock:
try:
target = self._objects[alias]
......@@ -220,15 +263,23 @@ class Registry(Generic[T]):
f"{', '.join(self._objects)}"
) from exc
# Only materialize if it's a string or EntryPoint (lazy placeholder)
if isinstance(target, (str, md.EntryPoint)):
concrete: T = self._materialise(target)
# First‑touch: swap placeholder with concrete obj for future calls
if concrete is not target:
# Double-check after acquiring lock (may have been materialized by another thread)
if not isinstance(target, (str, md.EntryPoint)):
return target
# Materialize the lazy placeholder
concrete: T = self._materialise(target)
# Swap placeholder with concrete object (with race condition check)
if concrete is not target:
# Final check: another thread might have materialized while we were working
current = self._objects.get(alias)
if isinstance(current, (str, md.EntryPoint)):
# Still a placeholder, safe to replace
self._objects[alias] = concrete
else:
# Already materialized, just return it
concrete = target
else:
# Another thread already materialized it, use their result
concrete = current # type: ignore[assignment]
# Late type check (for placeholders)
if self._base_cls is not None and not issubclass(concrete, self._base_cls): # type: ignore[arg-type]
......@@ -237,8 +288,8 @@ class Registry(Generic[T]):
f"(registered under alias '{alias}')"
)
# Custom validation
if self._validator is not None and not self._validator(concrete):
# Custom validation - run on materialization
if self._validator and not self._validator(concrete):
raise ValueError(
f"{concrete} failed custom validation for {self._name} registry "
f"(registered under alias '{alias}')"
......@@ -301,7 +352,7 @@ class Registry(Generic[T]):
raise RuntimeError("Cannot clear a frozen registry")
self._objects.clear()
self._metadata.clear()
self._materialise.cache_clear() # type: ignore[attr-defined] # Added by lru_cache
self._materialise.cache_clear() # type: ignore[attr-defined]
# ────────────────────────────────────────────────────────────────────────
......@@ -327,7 +378,7 @@ class MetricSpec:
from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[type[LM]] = Registry("model", base_cls=LM)
model_registry: Registry[LM] = Registry("model", base_cls=LM)
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], Mapping[str, float]]] = (
......@@ -347,8 +398,31 @@ DEFAULT_METRIC_REGISTRY = {
"generate_until": ["exact_match"],
}
# Aggregation registry (will be populated by register_aggregation)
AGGREGATION_REGISTRY: dict[str, Callable] = {}
def default_metrics_for(output_type: str) -> list[str]:
"""Get default metrics for a given output type dynamically.
This walks the metric registry to find metrics that match the output type.
Falls back to DEFAULT_METRIC_REGISTRY if no dynamic matches found.
"""
# First check static defaults
if output_type in DEFAULT_METRIC_REGISTRY:
return DEFAULT_METRIC_REGISTRY[output_type]
# Walk metric registry for matching output types
matching_metrics = []
for name, metric_spec in metric_registry.items():
if (
isinstance(metric_spec, MetricSpec)
and metric_spec.output_type == output_type
):
matching_metrics.append(name)
return matching_metrics if matching_metrics else []
# Aggregation registry - alias to the canonical registry for backward compatibility
AGGREGATION_REGISTRY = metric_agg_registry # The registry itself is dict-like
# ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API)
......@@ -373,35 +447,39 @@ def register_metric(**kwargs):
if not metric_name:
raise ValueError("metric name is required")
# Determine aggregation function
aggregate_fn = None
if "aggregation" in kwargs:
agg_name = kwargs["aggregation"]
try:
aggregate_fn = metric_agg_registry.get(agg_name)
except KeyError:
raise ValueError(f"Unknown aggregation: {agg_name}")
else:
# No aggregation specified - use a function that raises NotImplementedError
def not_implemented_agg(values):
raise NotImplementedError(
f"No aggregation function specified for metric '{metric_name}'. "
"Please specify an 'aggregation' parameter."
)
aggregate_fn = not_implemented_agg
# Create MetricSpec with the function and metadata
spec = MetricSpec(
compute=fn,
aggregate=lambda x: {}, # Default aggregation returns empty dict
aggregate=aggregate_fn,
higher_is_better=kwargs.get("higher_is_better", True),
output_type=kwargs.get("output_type"),
requires=kwargs.get("requires"),
)
# Register in metric registry
metric_registry._objects[metric_name] = spec
# Use proper registry API with metadata
metric_registry.register(metric_name, metadata=kwargs)(spec)
# Also handle aggregation if specified
if "aggregation" in kwargs:
agg_name = kwargs["aggregation"]
# Try to get aggregation from AGGREGATION_REGISTRY
if agg_name in AGGREGATION_REGISTRY:
spec = MetricSpec(
compute=fn,
aggregate=AGGREGATION_REGISTRY[agg_name],
higher_is_better=kwargs.get("higher_is_better", True),
output_type=kwargs.get("output_type"),
requires=kwargs.get("requires"),
)
metric_registry._objects[metric_name] = spec
# Handle higher_is_better registry
# Also register in higher_is_better registry if specified
if "higher_is_better" in kwargs:
higher_is_better_registry._objects[metric_name] = kwargs["higher_is_better"]
higher_is_better_registry.register(metric_name)(kwargs["higher_is_better"])
return fn
......@@ -444,18 +522,22 @@ register_metric_aggregation = metric_agg_registry.register
def get_metric_aggregation(metric_name: str):
"""Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation)
if metric_name in metric_registry._objects:
metric_spec = metric_registry._objects[metric_name]
try:
metric_spec = metric_registry.get(metric_name)
if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate:
return metric_spec.aggregate
except KeyError:
pass # Try next registry
# Fall back to metric_agg_registry (for standalone aggregations)
if metric_name in metric_agg_registry._objects:
return metric_agg_registry._objects[metric_name]
try:
return metric_agg_registry.get(metric_name)
except KeyError:
pass
# If not found, raise error
raise KeyError(
f"Unknown metric aggregation '{metric_name}'. Available: {list(AGGREGATION_REGISTRY.keys())}"
f"Unknown metric aggregation '{metric_name}'. Available: {list(metric_agg_registry)}"
)
......@@ -468,20 +550,30 @@ get_filter = filter_registry.get
# Special handling for AGGREGATION_REGISTRY which works differently
def register_aggregation(name: str):
"""@deprecated Use metric_agg_registry.register() instead."""
warnings.warn(
"register_aggregation() is deprecated. Use metric_agg_registry.register() instead.",
DeprecationWarning,
stacklevel=2,
)
def decorate(fn):
if name in AGGREGATION_REGISTRY:
# Use the canonical registry as single source of truth
if name in metric_agg_registry:
raise ValueError(
f"aggregation named '{name}' conflicts with existing registered aggregation!"
)
AGGREGATION_REGISTRY[name] = fn
metric_agg_registry.register(name)(fn)
return fn
return decorate
def get_aggregation(name: str) -> Callable[[], dict[str, Callable]]:
def get_aggregation(name: str) -> Callable[[Iterable[Any]], Mapping[str, float]] | None:
"""@deprecated Use metric_agg_registry.get() instead."""
try:
return AGGREGATION_REGISTRY[name]
# Use the canonical registry
return metric_agg_registry.get(name)
except KeyError:
import logging
......@@ -526,15 +618,17 @@ def freeze_all() -> None: # pragma: no cover
# Backwards‑compatibility read‑only globals
# ────────────────────────────────────────────────────────────────────────
MODEL_REGISTRY: Mapping[str, type[LM]] = MappingProxyType(model_registry._objects) # type: ignore[attr-defined]
TASK_REGISTRY: Mapping[str, Callable[..., Any]] = MappingProxyType(
task_registry._objects
) # type: ignore[attr-defined]
METRIC_REGISTRY: Mapping[str, MetricSpec] = MappingProxyType(metric_registry._objects) # type: ignore[attr-defined]
METRIC_AGGREGATION_REGISTRY: Mapping[str, Callable] = MappingProxyType(
metric_agg_registry._objects
) # type: ignore[attr-defined]
HIGHER_IS_BETTER_REGISTRY: Mapping[str, bool] = MappingProxyType(
higher_is_better_registry._objects
) # type: ignore[attr-defined]
FILTER_REGISTRY: Mapping[str, Callable] = MappingProxyType(filter_registry._objects) # type: ignore[attr-defined]
# These are direct aliases to the registries themselves, which already implement
# the Mapping protocol and provide read-only access to users (since _objects is private).
# This ensures they always reflect the current state of the registries, including
# items registered after module import.
#
# Note: We use type: ignore because Registry doesn't formally inherit from Mapping,
# but it implements all required methods (__getitem__, __iter__, __len__, items)
MODEL_REGISTRY: Mapping[str, LM] = model_registry # type: ignore[assignment]
TASK_REGISTRY: Mapping[str, Callable[..., Any]] = task_registry # type: ignore[assignment]
METRIC_REGISTRY: Mapping[str, MetricSpec] = metric_registry # type: ignore[assignment]
METRIC_AGGREGATION_REGISTRY: Mapping[str, Callable] = metric_agg_registry # type: ignore[assignment]
HIGHER_IS_BETTER_REGISTRY: Mapping[str, bool] = higher_is_better_registry # type: ignore[assignment]
FILTER_REGISTRY: Mapping[str, Callable] = filter_registry # type: ignore[assignment]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment