Commit e9451269 authored by Baber
Browse files

cleanup and add types

parent 48eabc04
from __future__ import annotations
import functools
import importlib
import inspect
import threading
......@@ -13,7 +14,7 @@ from typing import (
Callable,
Generic,
TypeVar,
overload,
cast,
)
......@@ -22,19 +23,8 @@ try: # Python≥3.10
except ImportError: # pragma: no cover - fallback for 3.8/3.9 runtimes
import importlib_metadata as md # type: ignore
__all__ = [
"Registry",
"MetricSpec",
# concrete registries
"model_registry",
"task_registry",
"metric_registry",
"metric_agg_registry",
"higher_is_better_registry",
"filter_registry",
# helper
"freeze_all",
# Legacy compatibility
# Legacy exports (keep for one release, then drop)
LEGACY_EXPORTS = [
"DEFAULT_METRIC_REGISTRY",
"AGGREGATION_REGISTRY",
"register_model",
......@@ -59,6 +49,21 @@ __all__ = [
"FILTER_REGISTRY",
]
# Names exported by a star-import of this module: the canonical registry API
# first, followed by every name listed in LEGACY_EXPORTS.
__all__ = [
# canonical
"Registry",
"MetricSpec",
"model_registry",
"task_registry",
"metric_registry",
"metric_agg_registry",
"higher_is_better_registry",
"filter_registry",
"freeze_all",
# legacy (unpacked from LEGACY_EXPORTS so the legacy names are listed only once)
*LEGACY_EXPORTS,
]
T = TypeVar("T")
......@@ -94,25 +99,25 @@ class Registry(Generic[T]):
# Registration helpers (decorator or direct call)
# ------------------------------------------------------------------
@overload
def register(
self,
*aliases: str,
lazy: None = None,
metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]:
"""Register as decorator: @registry.register("foo")."""
...
@overload
def register(
self,
*aliases: str,
lazy: str | md.EntryPoint,
metadata: dict[str, Any] | None = None,
) -> Callable[[Any], Any]:
"""Register lazy: registry.register("foo", lazy="a.b:C")(None)."""
...
# @overload
# def register(
# self,
# *aliases: str,
# lazy: None = None,
# metadata: dict[str, Any] | None = None,
# ) -> Callable[[T], T]:
# """Register as decorator: @registry.register("foo")."""
# ...
#
# @overload
# def register(
# self,
# *aliases: str,
# lazy: str | md.EntryPoint,
# metadata: dict[str, Any] | None = None,
# ) -> Callable[[Any], Any]:
# """Register lazy: registry.register("foo", lazy="a.b:C")"""
# ...
def _resolve_aliases(
self, target: T | str | md.EntryPoint, aliases: tuple[str, ...]
......@@ -185,47 +190,25 @@ class Registry(Generic[T]):
def register(
self,
*aliases: str,
obj: T | None = None,
lazy: str | md.EntryPoint | None = None,
metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
# ─── direct‑call path with lazy placeholder ───
if lazy is not None:
for alias in self._resolve_aliases(lazy, aliases):
self._check_and_store(alias, lazy, metadata)
return lambda x: x # no‑op decorator for accidental use
):
if obj and lazy:
raise ValueError("pass obj *or* lazy")
# ─── decorator path ───
def decorator(obj: T) -> T: # type: ignore[valid-type]
for alias in self._resolve_aliases(obj, aliases):
self._check_and_store(alias, obj, metadata)
return obj
@functools.wraps(self.register)
def _impl(target: T | str | md.EntryPoint):
for a in aliases or (getattr(target, "__name__", str(target)),):
self._check_and_store(a, target, metadata)
return target
return decorator
# imperative call → immediately registers and returns the target
if obj is not None or lazy is not None:
return _impl(obj if obj is not None else lazy) # type: ignore[arg-type]
# def register_bulk(
# self,
# items: dict[str, T | str | md.EntryPoint],
# metadata: dict[str, dict[str, Any]] | None = None,
# ) -> None:
# """Register multiple items at once.
#
# Args:
# items: Dictionary mapping aliases to objects/lazy paths
# metadata: Optional dictionary mapping aliases to metadata
# """
# for alias, target in items.items():
# meta = metadata.get(alias, {}) if metadata else {}
# # For lazy registration, check if it's a string or EntryPoint
# if isinstance(target, (str, md.EntryPoint)):
# self.register(alias, lazy=target, metadata=meta)(None)
# else:
# self.register(alias, metadata=meta)(target)
# decorator call → return function that will later receive the object
return _impl
# ------------------------------------------------------------------
# Lookup & materialisation
......@@ -241,9 +224,9 @@ class Registry(Generic[T]):
f"Lazy path '{target}' must be in 'module:object' form"
)
module = importlib.import_module(mod)
return getattr(module, obj_name)
return cast(T, getattr(module, obj_name))
if isinstance(target, md.EntryPoint):
return target.load()
return cast(T, target.load())
return target # concrete already
def get(self, alias: str) -> T:
......@@ -263,14 +246,14 @@ class Registry(Generic[T]):
f"{', '.join(self._objects)}"
) from exc
# Double-check after acquiring lock (may have been materialized by another thread)
# Double-check after acquiring a lock (may have been materialized by another thread)
if not isinstance(target, (str, md.EntryPoint)):
return target
# Materialize the lazy placeholder
concrete: T = self._materialise(target)
# Swap placeholder with concrete object (with race condition check)
# Swap placeholder with a concrete object (with race condition check)
if concrete is not target:
# Final check: another thread might have materialized while we were working
current = self._objects.get(alias)
......@@ -405,7 +388,7 @@ def default_metrics_for(output_type: str) -> list[str]:
This walks the metric registry to find metrics that match the output type.
Falls back to DEFAULT_METRIC_REGISTRY if no dynamic matches found.
"""
# First check static defaults
# First, check static defaults
if output_type in DEFAULT_METRIC_REGISTRY:
return DEFAULT_METRIC_REGISTRY[output_type]
......@@ -448,7 +431,7 @@ def register_metric(**kwargs):
raise ValueError("metric name is required")
# Determine aggregation function
aggregate_fn = None
aggregate_fn: Callable[[Iterable[Any]], Mapping[str, float]] | None = None
if "aggregation" in kwargs:
agg_name = kwargs["aggregation"]
try:
......@@ -474,12 +457,12 @@ def register_metric(**kwargs):
requires=kwargs.get("requires"),
)
# Use proper registry API with metadata
metric_registry.register(metric_name, metadata=kwargs)(spec)
# Use a proper registry API with metadata
metric_registry.register(metric_name, metadata=kwargs)(spec) # type: ignore[misc]
# Also register in higher_is_better registry if specified
if "higher_is_better" in kwargs:
higher_is_better_registry.register(metric_name)(kwargs["higher_is_better"])
higher_is_better_registry.register(metric_name)(kwargs["higher_is_better"]) # type: ignore[misc]
return fn
......@@ -519,15 +502,17 @@ def get_metric(name: str, hf_evaluate_metric=False):
register_metric_aggregation = metric_agg_registry.register
def get_metric_aggregation(metric_name: str):
def get_metric_aggregation(
metric_name: str,
) -> Callable[[Iterable[Any]], Mapping[str, float]]:
"""Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation)
# First, try to get from the metric registry (for metrics registered with aggregation)
try:
metric_spec = metric_registry.get(metric_name)
if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate:
return metric_spec.aggregate
except KeyError:
pass # Try next registry
pass # Try the next registry
# Fall back to metric_agg_registry (for standalone aggregations)
try:
......@@ -535,7 +520,7 @@ def get_metric_aggregation(metric_name: str):
except KeyError:
pass
# If not found, raise error
# If not found, raise an error
raise KeyError(
f"Unknown metric aggregation '{metric_name}'. Available: {list(metric_agg_registry)}"
)
......@@ -558,12 +543,12 @@ def register_aggregation(name: str):
)
def decorate(fn):
# Use the canonical registry as single source of truth
# Use the canonical registry as a single source of truth
if name in metric_agg_registry:
raise ValueError(
f"aggregation named '{name}' conflicts with existing registered aggregation!"
)
metric_agg_registry.register(name)(fn)
metric_agg_registry.register(name)(fn) # type: ignore[misc]
return fn
return decorate
......
......@@ -42,7 +42,7 @@ def _register_all_models():
# Only register if not already present (avoids conflicts when modules are imported)
if name not in model_registry:
# Call register with the lazy parameter, returns a decorator
model_registry.register(name, lazy=path)(None)
model_registry.register(name, lazy=path)
# Call registration on module import
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment