Commit 907f5f28 authored by Baber's avatar Baber
Browse files

refactor registry to simplify API and improve clarity

parent e9451269
from __future__ import annotations
import functools
import importlib
import inspect
import threading
......@@ -99,26 +98,6 @@ class Registry(Generic[T]):
# Registration helpers (decorator or direct call)
# ------------------------------------------------------------------
# @overload
# def register(
# self,
# *aliases: str,
# lazy: None = None,
# metadata: dict[str, Any] | None = None,
# ) -> Callable[[T], T]:
# """Register as decorator: @registry.register("foo")."""
# ...
#
# @overload
# def register(
# self,
# *aliases: str,
# lazy: str | md.EntryPoint,
# metadata: dict[str, Any] | None = None,
# ) -> Callable[[Any], Any]:
# """Register lazy: registry.register("foo", lazy="a.b:C")"""
# ...
def _resolve_aliases(
self, target: T | str | md.EntryPoint, aliases: tuple[str, ...]
) -> tuple[str, ...]:
......@@ -188,27 +167,58 @@ class Registry(Generic[T]):
)
def register(
self,
alias: str,
target: T | str | md.EntryPoint,
metadata: dict[str, Any] | None = None,
) -> T | str | md.EntryPoint:
"""Register a target (object or lazy placeholder) under the given alias.
Args:
alias: Name to register under
target: Object to register (can be concrete object or lazy string "module:Class")
metadata: Optional metadata to associate with this registration
Returns:
The target that was registered
Examples:
# Direct registration of concrete object
registry.register("mymodel", MyModelClass)
# Lazy registration with module path
registry.register("mymodel", "mypackage.models:MyModelClass")
"""
self._check_and_store(alias, target, metadata)
return target
def decorator(
self,
*aliases: str,
obj: T | None = None,
lazy: str | md.EntryPoint | None = None,
metadata: dict[str, Any] | None = None,
):
if obj and lazy:
raise ValueError("pass obj *or* lazy")
) -> Callable[[T], T]:
"""Create a decorator for registering objects.
@functools.wraps(self.register)
def _impl(target: T | str | md.EntryPoint):
for a in aliases or (getattr(target, "__name__", str(target)),):
self._check_and_store(a, target, metadata)
return target
Args:
*aliases: Names to register under (if empty, uses object's __name__)
metadata: Optional metadata to associate with this registration
Returns:
Decorator function that registers its target
Example:
@registry.decorator("mymodel", "model-v2")
class MyModel:
pass
"""
# imperative call → immediately registers and returns the target
if obj is not None or lazy is not None:
return _impl(obj if obj is not None else lazy) # type: ignore[arg-type]
def wrapper(obj: T) -> T:
resolved_aliases = aliases or (getattr(obj, "__name__", str(obj)),)
for alias in resolved_aliases:
self.register(alias, obj, metadata)
return obj
# decorator call → return function that will later receive the object
return _impl
return wrapper
# ------------------------------------------------------------------
# Lookup & materialisation
......@@ -411,10 +421,10 @@ AGGREGATION_REGISTRY = metric_agg_registry # The registry itself is dict-like
# Public helper aliases (legacy API)
# ────────────────────────────────────────────────────────────────────────
register_model = model_registry.register
register_model = model_registry.decorator
get_model = model_registry.get
register_task = task_registry.register
register_task = task_registry.decorator
get_task = task_registry.get
......@@ -458,11 +468,11 @@ def register_metric(**kwargs):
)
# Use a proper registry API with metadata
metric_registry.register(metric_name, metadata=kwargs)(spec) # type: ignore[misc]
metric_registry.register(metric_name, spec, metadata=kwargs)
# Also register in higher_is_better registry if specified
if "higher_is_better" in kwargs:
higher_is_better_registry.register(metric_name)(kwargs["higher_is_better"]) # type: ignore[misc]
higher_is_better_registry.register(metric_name, kwargs["higher_is_better"])
return fn
......@@ -499,7 +509,7 @@ def get_metric(name: str, hf_evaluate_metric=False):
return None
register_metric_aggregation = metric_agg_registry.register
register_metric_aggregation = metric_agg_registry.decorator
def get_metric_aggregation(
......@@ -526,10 +536,10 @@ def get_metric_aggregation(
)
register_higher_is_better = higher_is_better_registry.register
register_higher_is_better = higher_is_better_registry.decorator
is_higher_better = higher_is_better_registry.get
register_filter = filter_registry.register
register_filter = filter_registry.decorator
get_filter = filter_registry.get
......@@ -548,7 +558,7 @@ def register_aggregation(name: str):
raise ValueError(
f"aggregation named '{name}' conflicts with existing registered aggregation!"
)
metric_agg_registry.register(name)(fn) # type: ignore[misc]
metric_agg_registry.register(name, fn)
return fn
return decorate
......
......@@ -41,8 +41,8 @@ def _register_all_models():
for name, path in MODEL_MAPPING.items():
# Only register if not already present (avoids conflicts when modules are imported)
if name not in model_registry:
# Call register with the lazy parameter, returns a decorator
model_registry.register(name, lazy=path)
# Register the lazy placeholder directly
model_registry.register(name, path)
# Call registration on module import
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment