Commit 907f5f28 authored by Baber's avatar Baber
Browse files

refactor registry to simplify API and improve clarity

parent e9451269
from __future__ import annotations from __future__ import annotations
import functools
import importlib import importlib
import inspect import inspect
import threading import threading
...@@ -99,26 +98,6 @@ class Registry(Generic[T]): ...@@ -99,26 +98,6 @@ class Registry(Generic[T]):
# Registration helpers (decorator or direct call) # Registration helpers (decorator or direct call)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# @overload
# def register(
# self,
# *aliases: str,
# lazy: None = None,
# metadata: dict[str, Any] | None = None,
# ) -> Callable[[T], T]:
# """Register as decorator: @registry.register("foo")."""
# ...
#
# @overload
# def register(
# self,
# *aliases: str,
# lazy: str | md.EntryPoint,
# metadata: dict[str, Any] | None = None,
# ) -> Callable[[Any], Any]:
# """Register lazy: registry.register("foo", lazy="a.b:C")"""
# ...
def _resolve_aliases( def _resolve_aliases(
self, target: T | str | md.EntryPoint, aliases: tuple[str, ...] self, target: T | str | md.EntryPoint, aliases: tuple[str, ...]
) -> tuple[str, ...]: ) -> tuple[str, ...]:
...@@ -189,26 +168,57 @@ class Registry(Generic[T]): ...@@ -189,26 +168,57 @@ class Registry(Generic[T]):
def register( def register(
self, self,
*aliases: str, alias: str,
obj: T | None = None, target: T | str | md.EntryPoint,
lazy: str | md.EntryPoint | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
): ) -> T | str | md.EntryPoint:
if obj and lazy: """Register a target (object or lazy placeholder) under the given alias.
raise ValueError("pass obj *or* lazy")
@functools.wraps(self.register) Args:
def _impl(target: T | str | md.EntryPoint): alias: Name to register under
for a in aliases or (getattr(target, "__name__", str(target)),): target: Object to register (can be concrete object or lazy string "module:Class")
self._check_and_store(a, target, metadata) metadata: Optional metadata to associate with this registration
Returns:
The target that was registered
Examples:
# Direct registration of concrete object
registry.register("mymodel", MyModelClass)
# Lazy registration with module path
registry.register("mymodel", "mypackage.models:MyModelClass")
"""
self._check_and_store(alias, target, metadata)
return target return target
# imperative call → immediately registers and returns the target def decorator(
if obj is not None or lazy is not None: self,
return _impl(obj if obj is not None else lazy) # type: ignore[arg-type] *aliases: str,
metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]:
"""Create a decorator for registering objects.
Args:
*aliases: Names to register under (if empty, uses object's __name__)
metadata: Optional metadata to associate with this registration
Returns:
Decorator function that registers its target
Example:
@registry.decorator("mymodel", "model-v2")
class MyModel:
pass
"""
def wrapper(obj: T) -> T:
resolved_aliases = aliases or (getattr(obj, "__name__", str(obj)),)
for alias in resolved_aliases:
self.register(alias, obj, metadata)
return obj
# decorator call → return function that will later receive the object return wrapper
return _impl
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Lookup & materialisation # Lookup & materialisation
...@@ -411,10 +421,10 @@ AGGREGATION_REGISTRY = metric_agg_registry # The registry itself is dict-like ...@@ -411,10 +421,10 @@ AGGREGATION_REGISTRY = metric_agg_registry # The registry itself is dict-like
# Public helper aliases (legacy API) # Public helper aliases (legacy API)
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
register_model = model_registry.register register_model = model_registry.decorator
get_model = model_registry.get get_model = model_registry.get
register_task = task_registry.register register_task = task_registry.decorator
get_task = task_registry.get get_task = task_registry.get
...@@ -458,11 +468,11 @@ def register_metric(**kwargs): ...@@ -458,11 +468,11 @@ def register_metric(**kwargs):
) )
# Use a proper registry API with metadata # Use a proper registry API with metadata
metric_registry.register(metric_name, metadata=kwargs)(spec) # type: ignore[misc] metric_registry.register(metric_name, spec, metadata=kwargs)
# Also register in higher_is_better registry if specified # Also register in higher_is_better registry if specified
if "higher_is_better" in kwargs: if "higher_is_better" in kwargs:
higher_is_better_registry.register(metric_name)(kwargs["higher_is_better"]) # type: ignore[misc] higher_is_better_registry.register(metric_name, kwargs["higher_is_better"])
return fn return fn
...@@ -499,7 +509,7 @@ def get_metric(name: str, hf_evaluate_metric=False): ...@@ -499,7 +509,7 @@ def get_metric(name: str, hf_evaluate_metric=False):
return None return None
register_metric_aggregation = metric_agg_registry.register register_metric_aggregation = metric_agg_registry.decorator
def get_metric_aggregation( def get_metric_aggregation(
...@@ -526,10 +536,10 @@ def get_metric_aggregation( ...@@ -526,10 +536,10 @@ def get_metric_aggregation(
) )
register_higher_is_better = higher_is_better_registry.register register_higher_is_better = higher_is_better_registry.decorator
is_higher_better = higher_is_better_registry.get is_higher_better = higher_is_better_registry.get
register_filter = filter_registry.register register_filter = filter_registry.decorator
get_filter = filter_registry.get get_filter = filter_registry.get
...@@ -548,7 +558,7 @@ def register_aggregation(name: str): ...@@ -548,7 +558,7 @@ def register_aggregation(name: str):
raise ValueError( raise ValueError(
f"aggregation named '{name}' conflicts with existing registered aggregation!" f"aggregation named '{name}' conflicts with existing registered aggregation!"
) )
metric_agg_registry.register(name)(fn) # type: ignore[misc] metric_agg_registry.register(name, fn)
return fn return fn
return decorate return decorate
......
...@@ -41,8 +41,8 @@ def _register_all_models(): ...@@ -41,8 +41,8 @@ def _register_all_models():
for name, path in MODEL_MAPPING.items(): for name, path in MODEL_MAPPING.items():
# Only register if not already present (avoids conflicts when modules are imported) # Only register if not already present (avoids conflicts when modules are imported)
if name not in model_registry: if name not in model_registry:
# Call register with the lazy parameter, returns a decorator # Register the lazy placeholder directly
model_registry.register(name, lazy=path) model_registry.register(name, path)
# Call registration on module import # Call registration on module import
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment