Commit e9451269 authored by Baber
Browse files

cleanup and add types

parent 48eabc04
from __future__ import annotations from __future__ import annotations
import functools
import importlib import importlib
import inspect import inspect
import threading import threading
...@@ -13,7 +14,7 @@ from typing import ( ...@@ -13,7 +14,7 @@ from typing import (
Callable, Callable,
Generic, Generic,
TypeVar, TypeVar,
overload, cast,
) )
...@@ -22,19 +23,8 @@ try: # Python≥3.10 ...@@ -22,19 +23,8 @@ try: # Python≥3.10
except ImportError: # pragma: no cover - fallback for 3.8/3.9 runtimes except ImportError: # pragma: no cover - fallback for 3.8/3.9 runtimes
import importlib_metadata as md # type: ignore import importlib_metadata as md # type: ignore
__all__ = [ # Legacy exports (keep for one release, then drop)
"Registry", LEGACY_EXPORTS = [
"MetricSpec",
# concrete registries
"model_registry",
"task_registry",
"metric_registry",
"metric_agg_registry",
"higher_is_better_registry",
"filter_registry",
# helper
"freeze_all",
# Legacy compatibility
"DEFAULT_METRIC_REGISTRY", "DEFAULT_METRIC_REGISTRY",
"AGGREGATION_REGISTRY", "AGGREGATION_REGISTRY",
"register_model", "register_model",
...@@ -59,6 +49,21 @@ __all__ = [ ...@@ -59,6 +49,21 @@ __all__ = [
"FILTER_REGISTRY", "FILTER_REGISTRY",
] ]
__all__ = [
# canonical
"Registry",
"MetricSpec",
"model_registry",
"task_registry",
"metric_registry",
"metric_agg_registry",
"higher_is_better_registry",
"filter_registry",
"freeze_all",
# legacy
*LEGACY_EXPORTS,
]
T = TypeVar("T") T = TypeVar("T")
...@@ -94,25 +99,25 @@ class Registry(Generic[T]): ...@@ -94,25 +99,25 @@ class Registry(Generic[T]):
# Registration helpers (decorator or direct call) # Registration helpers (decorator or direct call)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@overload # @overload
def register( # def register(
self, # self,
*aliases: str, # *aliases: str,
lazy: None = None, # lazy: None = None,
metadata: dict[str, Any] | None = None, # metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]: # ) -> Callable[[T], T]:
"""Register as decorator: @registry.register("foo").""" # """Register as decorator: @registry.register("foo")."""
... # ...
#
@overload # @overload
def register( # def register(
self, # self,
*aliases: str, # *aliases: str,
lazy: str | md.EntryPoint, # lazy: str | md.EntryPoint,
metadata: dict[str, Any] | None = None, # metadata: dict[str, Any] | None = None,
) -> Callable[[Any], Any]: # ) -> Callable[[Any], Any]:
"""Register lazy: registry.register("foo", lazy="a.b:C")(None).""" # """Register lazy: registry.register("foo", lazy="a.b:C")"""
... # ...
def _resolve_aliases( def _resolve_aliases(
self, target: T | str | md.EntryPoint, aliases: tuple[str, ...] self, target: T | str | md.EntryPoint, aliases: tuple[str, ...]
...@@ -185,47 +190,25 @@ class Registry(Generic[T]): ...@@ -185,47 +190,25 @@ class Registry(Generic[T]):
def register( def register(
self, self,
*aliases: str, *aliases: str,
obj: T | None = None,
lazy: str | md.EntryPoint | None = None, lazy: str | md.EntryPoint | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]: ):
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``. if obj and lazy:
raise ValueError("pass obj *or* lazy")
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
# ─── direct‑call path with lazy placeholder ───
if lazy is not None:
for alias in self._resolve_aliases(lazy, aliases):
self._check_and_store(alias, lazy, metadata)
return lambda x: x # no‑op decorator for accidental use
# ─── decorator path ─── @functools.wraps(self.register)
def decorator(obj: T) -> T: # type: ignore[valid-type] def _impl(target: T | str | md.EntryPoint):
for alias in self._resolve_aliases(obj, aliases): for a in aliases or (getattr(target, "__name__", str(target)),):
self._check_and_store(alias, obj, metadata) self._check_and_store(a, target, metadata)
return obj return target
return decorator # imperative call → immediately registers and returns the target
if obj is not None or lazy is not None:
return _impl(obj if obj is not None else lazy) # type: ignore[arg-type]
# def register_bulk( # decorator call → return function that will later receive the object
# self, return _impl
# items: dict[str, T | str | md.EntryPoint],
# metadata: dict[str, dict[str, Any]] | None = None,
# ) -> None:
# """Register multiple items at once.
#
# Args:
# items: Dictionary mapping aliases to objects/lazy paths
# metadata: Optional dictionary mapping aliases to metadata
# """
# for alias, target in items.items():
# meta = metadata.get(alias, {}) if metadata else {}
# # For lazy registration, check if it's a string or EntryPoint
# if isinstance(target, (str, md.EntryPoint)):
# self.register(alias, lazy=target, metadata=meta)(None)
# else:
# self.register(alias, metadata=meta)(target)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Lookup & materialisation # Lookup & materialisation
...@@ -241,9 +224,9 @@ class Registry(Generic[T]): ...@@ -241,9 +224,9 @@ class Registry(Generic[T]):
f"Lazy path '{target}' must be in 'module:object' form" f"Lazy path '{target}' must be in 'module:object' form"
) )
module = importlib.import_module(mod) module = importlib.import_module(mod)
return getattr(module, obj_name) return cast(T, getattr(module, obj_name))
if isinstance(target, md.EntryPoint): if isinstance(target, md.EntryPoint):
return target.load() return cast(T, target.load())
return target # concrete already return target # concrete already
def get(self, alias: str) -> T: def get(self, alias: str) -> T:
...@@ -263,14 +246,14 @@ class Registry(Generic[T]): ...@@ -263,14 +246,14 @@ class Registry(Generic[T]):
f"{', '.join(self._objects)}" f"{', '.join(self._objects)}"
) from exc ) from exc
# Double-check after acquiring lock (may have been materialized by another thread) # Double-check after acquiring a lock (may have been materialized by another thread)
if not isinstance(target, (str, md.EntryPoint)): if not isinstance(target, (str, md.EntryPoint)):
return target return target
# Materialize the lazy placeholder # Materialize the lazy placeholder
concrete: T = self._materialise(target) concrete: T = self._materialise(target)
# Swap placeholder with concrete object (with race condition check) # Swap placeholder with a concrete object (with race condition check)
if concrete is not target: if concrete is not target:
# Final check: another thread might have materialized while we were working # Final check: another thread might have materialized while we were working
current = self._objects.get(alias) current = self._objects.get(alias)
...@@ -405,7 +388,7 @@ def default_metrics_for(output_type: str) -> list[str]: ...@@ -405,7 +388,7 @@ def default_metrics_for(output_type: str) -> list[str]:
This walks the metric registry to find metrics that match the output type. This walks the metric registry to find metrics that match the output type.
Falls back to DEFAULT_METRIC_REGISTRY if no dynamic matches found. Falls back to DEFAULT_METRIC_REGISTRY if no dynamic matches found.
""" """
# First check static defaults # First, check static defaults
if output_type in DEFAULT_METRIC_REGISTRY: if output_type in DEFAULT_METRIC_REGISTRY:
return DEFAULT_METRIC_REGISTRY[output_type] return DEFAULT_METRIC_REGISTRY[output_type]
...@@ -448,7 +431,7 @@ def register_metric(**kwargs): ...@@ -448,7 +431,7 @@ def register_metric(**kwargs):
raise ValueError("metric name is required") raise ValueError("metric name is required")
# Determine aggregation function # Determine aggregation function
aggregate_fn = None aggregate_fn: Callable[[Iterable[Any]], Mapping[str, float]] | None = None
if "aggregation" in kwargs: if "aggregation" in kwargs:
agg_name = kwargs["aggregation"] agg_name = kwargs["aggregation"]
try: try:
...@@ -474,12 +457,12 @@ def register_metric(**kwargs): ...@@ -474,12 +457,12 @@ def register_metric(**kwargs):
requires=kwargs.get("requires"), requires=kwargs.get("requires"),
) )
# Use proper registry API with metadata # Use a proper registry API with metadata
metric_registry.register(metric_name, metadata=kwargs)(spec) metric_registry.register(metric_name, metadata=kwargs)(spec) # type: ignore[misc]
# Also register in higher_is_better registry if specified # Also register in higher_is_better registry if specified
if "higher_is_better" in kwargs: if "higher_is_better" in kwargs:
higher_is_better_registry.register(metric_name)(kwargs["higher_is_better"]) higher_is_better_registry.register(metric_name)(kwargs["higher_is_better"]) # type: ignore[misc]
return fn return fn
...@@ -519,15 +502,17 @@ def get_metric(name: str, hf_evaluate_metric=False): ...@@ -519,15 +502,17 @@ def get_metric(name: str, hf_evaluate_metric=False):
register_metric_aggregation = metric_agg_registry.register register_metric_aggregation = metric_agg_registry.register
def get_metric_aggregation(metric_name: str): def get_metric_aggregation(
metric_name: str,
) -> Callable[[Iterable[Any]], Mapping[str, float]]:
"""Get the aggregation function for a metric.""" """Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation) # First, try to get from the metric registry (for metrics registered with aggregation)
try: try:
metric_spec = metric_registry.get(metric_name) metric_spec = metric_registry.get(metric_name)
if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate: if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate:
return metric_spec.aggregate return metric_spec.aggregate
except KeyError: except KeyError:
pass # Try next registry pass # Try the next registry
# Fall back to metric_agg_registry (for standalone aggregations) # Fall back to metric_agg_registry (for standalone aggregations)
try: try:
...@@ -535,7 +520,7 @@ def get_metric_aggregation(metric_name: str): ...@@ -535,7 +520,7 @@ def get_metric_aggregation(metric_name: str):
except KeyError: except KeyError:
pass pass
# If not found, raise error # If not found, raise an error
raise KeyError( raise KeyError(
f"Unknown metric aggregation '{metric_name}'. Available: {list(metric_agg_registry)}" f"Unknown metric aggregation '{metric_name}'. Available: {list(metric_agg_registry)}"
) )
...@@ -558,12 +543,12 @@ def register_aggregation(name: str): ...@@ -558,12 +543,12 @@ def register_aggregation(name: str):
) )
def decorate(fn): def decorate(fn):
# Use the canonical registry as single source of truth # Use the canonical registry as a single source of truth
if name in metric_agg_registry: if name in metric_agg_registry:
raise ValueError( raise ValueError(
f"aggregation named '{name}' conflicts with existing registered aggregation!" f"aggregation named '{name}' conflicts with existing registered aggregation!"
) )
metric_agg_registry.register(name)(fn) metric_agg_registry.register(name)(fn) # type: ignore[misc]
return fn return fn
return decorate return decorate
......
...@@ -42,7 +42,7 @@ def _register_all_models(): ...@@ -42,7 +42,7 @@ def _register_all_models():
# Only register if not already present (avoids conflicts when modules are imported) # Only register if not already present (avoids conflicts when modules are imported)
if name not in model_registry: if name not in model_registry:
# Call register with the lazy parameter, returns a decorator # Call register with the lazy parameter, returns a decorator
model_registry.register(name, lazy=path)(None) model_registry.register(name, lazy=path)
# Call registration on module import # Call registration on module import
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment