Commit 9af24b7e authored by Baber's avatar Baber
Browse files

refactor registry for simplicity and improved maintainability

parent 907f5f28
...@@ -3,23 +3,16 @@ from __future__ import annotations ...@@ -3,23 +3,16 @@ from __future__ import annotations
import importlib import importlib
import inspect import inspect
import threading import threading
import warnings from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Mapping, MutableMapping
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from types import MappingProxyType from types import MappingProxyType
from typing import ( from typing import Any, Callable, Generic, Type, TypeVar, Union, cast
Any,
Callable,
Generic,
TypeVar,
cast,
)
try: # Python≥3.10 try:
import importlib.metadata as md import importlib.metadata as md # Python ≥3.10
except ImportError: # pragma: no cover - fallback for 3.8/3.9 runtimes except ImportError: # pragma: no cover fallback for 3.8/3.9
import importlib_metadata as md # type: ignore import importlib_metadata as md # type: ignore
# Legacy exports (keep for one release, then drop) # Legacy exports (keep for one release, then drop)
...@@ -64,6 +57,7 @@ __all__ = [ ...@@ -64,6 +57,7 @@ __all__ = [
] ]
T = TypeVar("T") T = TypeVar("T")
Placeholder = Union[str, md.EntryPoint] # light‑weight lazy token
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
...@@ -72,533 +66,264 @@ T = TypeVar("T") ...@@ -72,533 +66,264 @@ T = TypeVar("T")
class Registry(Generic[T]): class Registry(Generic[T]):
"""Name -> object mapping with decorator helpers and **lazy import** support.""" """Name → object registry with optional lazy placeholders."""
#: The underlying mutable mapping (might turn into MappingProxy on freeze)
_objects: MutableMapping[str, T | str | md.EntryPoint]
def __init__( def __init__(
self, self,
name: str, name: str,
*, *,
base_cls: type[T] | None = None, base_cls: Union[Type[T], None] = None,
store: MutableMapping[str, T | str | md.EntryPoint] | None = None,
validator: Callable[[T], bool] | None = None,
) -> None: ) -> None:
self._name: str = name self._name = name
self._base_cls: type[T] | None = base_cls self._base_cls = base_cls
self._objects = store if store is not None else {} self._objs: dict[str, Union[T, Placeholder]] = {}
self._metadata: dict[ self._meta: dict[str, dict[str, Any]] = {}
str, dict[str, Any]
] = {} # Store metadata for each registered item
self._validator = validator # Custom validation function
self._lock = threading.RLock() self._lock = threading.RLock()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Registration helpers (decorator or direct call) # Registration (decorator or direct call)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _resolve_aliases( def register(
self, target: T | str | md.EntryPoint, aliases: tuple[str, ...]
) -> tuple[str, ...]:
"""Resolve aliases for registration."""
if not aliases:
return (getattr(target, "__name__", str(target)),)
return aliases
def _check_and_store(
self, self,
alias: str, *aliases: str,
target: T | str | md.EntryPoint, lazy: Union[T, Placeholder, None] = None,
metadata: dict[str, Any] | None, metadata: dict[str, Any] | None = None,
) -> None: ) -> Callable[[T], T]:
"""Check constraints and store the target with optional metadata. """``@reg.register('foo')`` or ``reg.register('foo', lazy='pkg.mod:Obj')``."""
Collision policy: def _store(alias: str, target: Union[T, Placeholder]) -> None:
1. If alias doesn't exist → store it current = self._objs.get(alias)
2. If identical value → silently succeed (idempotent) # ─── collision handling ────────────────────────────────────
3. If lazy placeholder + matching concrete class → replace with concrete if current is not None and current != target:
4. Otherwise → raise ValueError # allow placeholder → real object upgrade
if isinstance(current, str) and isinstance(target, type):
Type checking: mod, _, cls = current.partition(":")
- Eager for concrete classes at registration time if current == f"{target.__module__}:{target.__name__}":
- Deferred for lazy placeholders until materialization self._objs[alias] = target
""" self._meta[alias] = metadata or {}
with self._lock: return
# Case 1: New alias raise ValueError(
if alias not in self._objects: f"{self._name!r} alias '{alias}' already registered (" # noqa: B950
# Type check concrete classes before storing f"existing={current}, new={target})"
)
# ─── type check for concrete classes ───────────────────────
if self._base_cls is not None and isinstance(target, type): if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type] if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError( raise TypeError(
f"{target} must inherit from {self._base_cls} " f"{target} must inherit from {self._base_cls} to be a {self._name}"
f"to be registered as a {self._name}"
) )
self._objects[alias] = target self._objs[alias] = target
if metadata: if metadata:
self._metadata[alias] = metadata self._meta[alias] = metadata
return
existing = self._objects[alias]
# Case 2: Identical value - idempotent
if existing == target:
return
# Case 3: Lazy placeholder being replaced by its concrete class
if isinstance(existing, str) and isinstance(target, type):
mod_path, _, cls_name = existing.partition(":")
if (
cls_name
and hasattr(target, "__module__")
and hasattr(target, "__name__")
):
expected_path = f"{target.__module__}:{target.__name__}"
if existing == expected_path:
self._objects[alias] = target
if metadata:
self._metadata[alias] = metadata
return
# Case 4: Collision - different values
raise ValueError(
f"{self._name!r} '{alias}' already registered "
f"(existing: {existing}, new: {target})"
)
def register(
self,
alias: str,
target: T | str | md.EntryPoint,
metadata: dict[str, Any] | None = None,
) -> T | str | md.EntryPoint:
"""Register a target (object or lazy placeholder) under the given alias.
Args:
alias: Name to register under
target: Object to register (can be concrete object or lazy string "module:Class")
metadata: Optional metadata to associate with this registration
Returns:
The target that was registered
Examples:
# Direct registration of concrete object
registry.register("mymodel", MyModelClass)
# Lazy registration with module path
registry.register("mymodel", "mypackage.models:MyModelClass")
"""
self._check_and_store(alias, target, metadata)
return target
def decorator(
self,
*aliases: str,
metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]:
"""Create a decorator for registering objects.
Args:
*aliases: Names to register under (if empty, uses object's __name__)
metadata: Optional metadata to associate with this registration
Returns: def decorator(obj: T) -> T: # type: ignore[valid-type]
Decorator function that registers its target names = aliases or (getattr(obj, "__name__", str(obj)),)
with self._lock:
Example: for name in names:
@registry.decorator("mymodel", "model-v2") _store(name, obj)
class MyModel:
pass
"""
def wrapper(obj: T) -> T:
resolved_aliases = aliases or (getattr(obj, "__name__", str(obj)),)
for alias in resolved_aliases:
self.register(alias, obj, metadata)
return obj return obj
return wrapper # Direct call with *lazy* placeholder
if lazy is not None:
if len(aliases) != 1:
raise ValueError("Exactly one alias required when using 'lazy='")
with self._lock:
_store(aliases[0], lazy) # type: ignore[arg-type]
# return no‑op decorator for accidental use
return lambda x: x # type: ignore[return-value]
return decorator
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Lookup & materialisation # Lookup & materialisation
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@lru_cache(maxsize=256) # Bounded cache to prevent memory growth @lru_cache(maxsize=256)
def _materialise(self, target: T | str | md.EntryPoint) -> T: def _materialise(self, ph: Placeholder) -> T:
"""Import *target* if it is a dotted‑path string or EntryPoint.""" if isinstance(ph, str):
if isinstance(target, str): mod, _, attr = ph.partition(":")
mod, _, obj_name = target.partition(":") if not attr:
if not _: raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")
raise ValueError( return cast(T, getattr(importlib.import_module(mod), attr))
f"Lazy path '{target}' must be in 'module:object' form" return cast(T, ph.load())
)
module = importlib.import_module(mod)
return cast(T, getattr(module, obj_name))
if isinstance(target, md.EntryPoint):
return cast(T, target.load())
return target # concrete already
def get(self, alias: str) -> T: def get(self, alias: str) -> T:
# Fast path: check if already materialized without lock
target = self._objects.get(alias)
if target is not None and not isinstance(target, (str, md.EntryPoint)):
# Already materialized and validated, return immediately
return target
# Slow path: acquire lock for materialization
with self._lock:
try: try:
target = self._objects[alias] target = self._objs[alias]
except KeyError as exc: except KeyError as exc:
raise KeyError( raise KeyError(
f"Unknown {self._name} '{alias}'. Available: " f"Unknown {self._name} '{alias}'. Available: {', '.join(self._objs)}"
f"{', '.join(self._objects)}"
) from exc ) from exc
# Double-check after acquiring a lock (may have been materialized by another thread) if isinstance(target, (str, md.EntryPoint)):
if not isinstance(target, (str, md.EntryPoint)): with self._lock:
return target # Re‑check under lock (another thread might have resolved it)
fresh = self._objs[alias]
# Materialize the lazy placeholder if isinstance(fresh, (str, md.EntryPoint)):
concrete: T = self._materialise(target) concrete = self._materialise(fresh)
self._objs[alias] = concrete
# Swap placeholder with a concrete object (with race condition check)
if concrete is not target:
# Final check: another thread might have materialized while we were working
current = self._objects.get(alias)
if isinstance(current, (str, md.EntryPoint)):
# Still a placeholder, safe to replace
self._objects[alias] = concrete
else: else:
# Another thread already materialized it, use their result concrete = fresh # another thread did the job
concrete = current # type: ignore[assignment] target = concrete
# Late type check (for placeholders) # Late type/validator checks
if self._base_cls is not None and not issubclass(concrete, self._base_cls): # type: ignore[arg-type] if self._base_cls is not None and not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError( raise TypeError(
f"{concrete} does not inherit from {self._base_cls} " f"{target} does not inherit from {self._base_cls} (alias '{alias}')"
f"(registered under alias '{alias}')"
)
# Custom validation - run on materialization
if self._validator and not self._validator(concrete):
raise ValueError(
f"{concrete} failed custom validation for {self._name} registry "
f"(registered under alias '{alias}')"
) )
return target
return concrete # ------------------------------------------------------------------
# Mapping helpers
# Mapping / dunder helpers ------------------------------------------------- # ------------------------------------------------------------------
def __getitem__(self, alias: str) -> T: # noqa def __getitem__(self, alias: str) -> T: # noqa: DunderImplemented
return self.get(alias) return self.get(alias)
def __iter__(self): # noqa def __iter__(self): # noqa: DunderImplemented
return iter(self._objects) return iter(self._objs)
def __len__(self): # noqa: DunderImplemented
return len(self._objs)
def __len__(self) -> int: # noqa def items(self): # noqa: DunderImplemented
return len(self._objects) return self._objs.items()
def items(self): # noqa # ------------------------------------------------------------------
return self._objects.items() # Utilities
# ------------------------------------------------------------------
# Introspection ----------------------------------------------------------- def metadata(self, alias: str) -> Union[Mapping[str, Any], None]:
return self._meta.get(alias)
def origin(self, alias: str) -> str | None: def origin(self, alias: str) -> Union[str, None]:
obj = self._objects.get(alias) obj = self._objs.get(alias)
if isinstance(obj, (str, md.EntryPoint)):
return None
try: try:
if isinstance(obj, str) or isinstance(obj, md.EntryPoint): path = inspect.getfile(obj) # type: ignore[arg-type]
return None # placeholder - unknown until imported
file = inspect.getfile(obj) # type: ignore[arg-type]
line = inspect.getsourcelines(obj)[1] # type: ignore[arg-type] line = inspect.getsourcelines(obj)[1] # type: ignore[arg-type]
return f"{file}:{line}" return f"{path}:{line}"
except ( except Exception: # pragma: no cover – best‑effort only
TypeError,
OSError,
AttributeError,
): # pragma: no cover - best-effort only
# TypeError: object not suitable for inspect
# OSError: file not found or accessible
# AttributeError: object lacks expected attributes
return None return None
def get_metadata(self, alias: str) -> dict[str, Any] | None:
"""Get metadata for a registered item."""
with self._lock:
return self._metadata.get(alias)
# Mutability --------------------------------------------------------------
def freeze(self): def freeze(self):
"""Make the registry *names* immutable (materialisation still works)."""
with self._lock: with self._lock:
if isinstance(self._objects, MappingProxyType): self._objs = MappingProxyType(dict(self._objs)) # type: ignore[assignment]
return # already frozen self._meta = MappingProxyType(dict(self._meta)) # type: ignore[assignment]
self._objects = MappingProxyType(dict(self._objects)) # type: ignore[assignment]
def clear(self): # Test helper -------------------------------------------------------------
"""Clear the registry (useful for tests). Cannot be called on frozen registries."""
with self._lock: def _clear(self): # pragma: no cover
if isinstance(self._objects, MappingProxyType): """Erase registry (for isolated tests)."""
raise RuntimeError("Cannot clear a frozen registry") self._objs.clear()
self._objects.clear() self._meta.clear()
self._metadata.clear() self._materialise.cache_clear()
self._materialise.cache_clear() # type: ignore[attr-defined]
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
# Structured objects stored in registries # Structured object for metrics
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
@dataclass(frozen=True) @dataclass(frozen=True)
class MetricSpec: class MetricSpec:
"""Bundle compute fn, aggregator, and *higher‑is‑better* flag."""
compute: Callable[[Any, Any], Any] compute: Callable[[Any, Any], Any]
aggregate: Callable[[Iterable[Any]], Mapping[str, float]] aggregate: Callable[[Iterable[Any]], float]
higher_is_better: bool = True higher_is_better: bool = True
output_type: str | None = None # e.g., "probability", "string", "numeric" output_type: Union[str, None] = None
requires: list[str] | None = None # Dependencies on other metrics/data requires: Union[list[str], None] = None
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
# Concrete registries used by lm_eval # Canonical registries
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
from lm_eval.api.model import LM # noqa: E402 from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[LM] = Registry("model", base_cls=LM) model_registry: Registry[type[LM]] = Registry("model", base_cls=LM)
task_registry: Registry[Callable[..., Any]] = Registry("task") task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric") metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], Mapping[str, float]]] = ( metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
Registry("metric aggregation") "metric aggregation"
) )
higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag") higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag")
filter_registry: Registry[Callable] = Registry("filter") filter_registry: Registry[Callable] = Registry("filter")
# Default metric registry for output types # Public helper aliases ------------------------------------------------------
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [
"perplexity",
"acc",
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"generate_until": ["exact_match"],
}
def default_metrics_for(output_type: str) -> list[str]:
"""Get default metrics for a given output type dynamically.
This walks the metric registry to find metrics that match the output type.
Falls back to DEFAULT_METRIC_REGISTRY if no dynamic matches found.
"""
# First, check static defaults
if output_type in DEFAULT_METRIC_REGISTRY:
return DEFAULT_METRIC_REGISTRY[output_type]
# Walk metric registry for matching output types
matching_metrics = []
for name, metric_spec in metric_registry.items():
if (
isinstance(metric_spec, MetricSpec)
and metric_spec.output_type == output_type
):
matching_metrics.append(name)
return matching_metrics if matching_metrics else []
register_model = model_registry.register
# Aggregation registry - alias to the canonical registry for backward compatibility
AGGREGATION_REGISTRY = metric_agg_registry # The registry itself is dict-like
# ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API)
# ────────────────────────────────────────────────────────────────────────
register_model = model_registry.decorator
get_model = model_registry.get get_model = model_registry.get
register_task = task_registry.decorator register_task = task_registry.register
get_task = task_registry.get get_task = task_registry.get
register_filter = filter_registry.register
get_filter = filter_registry.get
# Special handling for metric registration which uses different API # Metric helpers need thin wrappers to build MetricSpec ----------------------
def register_metric(**kwargs):
"""Register a metric with metadata.
Compatible with old registry API that used keyword arguments.
"""
def decorate(fn):
metric_name = kwargs.get("metric")
if not metric_name:
raise ValueError("metric name is required")
# Determine aggregation function
aggregate_fn: Callable[[Iterable[Any]], Mapping[str, float]] | None = None
if "aggregation" in kwargs:
agg_name = kwargs["aggregation"]
try:
aggregate_fn = metric_agg_registry.get(agg_name)
except KeyError:
raise ValueError(f"Unknown aggregation: {agg_name}")
else:
# No aggregation specified - use a function that raises NotImplementedError
def not_implemented_agg(values):
raise NotImplementedError(
f"No aggregation function specified for metric '{metric_name}'. "
"Please specify an 'aggregation' parameter."
)
aggregate_fn = not_implemented_agg def register_metric(**kw):
name = kw["metric"]
# Create MetricSpec with the function and metadata def deco(fn):
spec = MetricSpec( spec = MetricSpec(
compute=fn, compute=fn,
aggregate=aggregate_fn, aggregate=(
higher_is_better=kwargs.get("higher_is_better", True), metric_agg_registry.get(kw["aggregation"])
output_type=kwargs.get("output_type"), if "aggregation" in kw
requires=kwargs.get("requires"), else lambda _: {}
),
higher_is_better=kw.get("higher_is_better", True),
output_type=kw.get("output_type"),
requires=kw.get("requires"),
) )
metric_registry.register(name, lazy=spec, metadata=kw)
# Use a proper registry API with metadata higher_is_better_registry.register(name, lazy=spec.higher_is_better)
metric_registry.register(metric_name, spec, metadata=kwargs)
# Also register in higher_is_better registry if specified
if "higher_is_better" in kwargs:
higher_is_better_registry.register(metric_name, kwargs["higher_is_better"])
return fn return fn
return decorate return deco
def get_metric(name: str, hf_evaluate_metric=False): def get_metric(name, hf_evaluate_metric=False):
"""Get a metric by name, with fallback to HF evaluate."""
if not hf_evaluate_metric:
try: try:
spec = metric_registry.get(name) spec = metric_registry.get(name)
if isinstance(spec, MetricSpec): return spec.compute # type: ignore[attr-defined]
return spec.compute
return spec
except KeyError: except KeyError:
if not hf_evaluate_metric:
import logging import logging
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..." f"Metric '{name}' not in registry; trying HF evaluate"
) )
# Fallback to HF evaluate
try: try:
import evaluate as hf_evaluate import evaluate as hf
metric_object = hf_evaluate.load(name) return hf.load(name).compute # type: ignore[attr-defined]
return metric_object.compute
except Exception: except Exception:
import logging raise KeyError(f"Metric '{name}' not found anywhere")
logging.getLogger(__name__).error(
f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
)
return None
register_metric_aggregation = metric_agg_registry.decorator
def get_metric_aggregation(
metric_name: str,
) -> Callable[[Iterable[Any]], Mapping[str, float]]:
"""Get the aggregation function for a metric."""
# First, try to get from the metric registry (for metrics registered with aggregation)
try:
metric_spec = metric_registry.get(metric_name)
if isinstance(metric_spec, MetricSpec) and metric_spec.aggregate:
return metric_spec.aggregate
except KeyError:
pass # Try the next registry
# Fall back to metric_agg_registry (for standalone aggregations)
try:
return metric_agg_registry.get(metric_name)
except KeyError:
pass
# If not found, raise an error
raise KeyError(
f"Unknown metric aggregation '{metric_name}'. Available: {list(metric_agg_registry)}"
)
register_metric_aggregation = metric_agg_registry.register
get_metric_aggregation = metric_agg_registry.get
register_higher_is_better = higher_is_better_registry.decorator register_higher_is_better = higher_is_better_registry.register
is_higher_better = higher_is_better_registry.get is_higher_better = higher_is_better_registry.get
register_filter = filter_registry.decorator # Legacy compatibility
get_filter = filter_registry.get register_aggregation = metric_agg_registry.register
get_aggregation = metric_agg_registry.get
DEFAULT_METRIC_REGISTRY = metric_registry
AGGREGATION_REGISTRY = metric_agg_registry
# Special handling for AGGREGATION_REGISTRY which works differently # Convenience ----------------------------------------------------------------
def register_aggregation(name: str):
"""@deprecated Use metric_agg_registry.register() instead."""
warnings.warn(
"register_aggregation() is deprecated. Use metric_agg_registry.register() instead.",
DeprecationWarning,
stacklevel=2,
)
def decorate(fn):
# Use the canonical registry as a single source of truth
if name in metric_agg_registry:
raise ValueError(
f"aggregation named '{name}' conflicts with existing registered aggregation!"
)
metric_agg_registry.register(name, fn)
return fn
return decorate
def get_aggregation(name: str) -> Callable[[Iterable[Any]], Mapping[str, float]] | None:
"""@deprecated Use metric_agg_registry.get() instead."""
try:
# Use the canonical registry
return metric_agg_registry.get(name)
except KeyError:
import logging
logging.getLogger(__name__).warning( def freeze_all():
f"{name} not a registered aggregation metric!" for r in (
)
return None
# ────────────────────────────────────────────────────────────────────────
# Optional PyPI entry‑point discovery - uncomment if desired
# ────────────────────────────────────────────────────────────────────────
# for _group, _reg in {
# "lm_eval.models": model_registry,
# "lm_eval.tasks": task_registry,
# "lm_eval.metrics": metric_registry,
# }.items():
# for _ep in md.entry_points(group=_group):
# _reg.register(_ep.name, lazy=_ep)
# ────────────────────────────────────────────────────────────────────────
# Convenience
# ────────────────────────────────────────────────────────────────────────
def freeze_all() -> None: # pragma: no cover
"""Freeze every global registry (idempotent)."""
for _reg in (
model_registry, model_registry,
task_registry, task_registry,
metric_registry, metric_registry,
...@@ -606,24 +331,14 @@ def freeze_all() -> None: # pragma: no cover ...@@ -606,24 +331,14 @@ def freeze_all() -> None: # pragma: no cover
higher_is_better_registry, higher_is_better_registry,
filter_registry, filter_registry,
): ):
_reg.freeze() r.freeze()
# ──────────────────────────────────────────────────────────────────────── # Backwards‑compat read‑only aliases ----------------------------------------
# Backwards‑compatibility read‑only globals
# ────────────────────────────────────────────────────────────────────────
# These are direct aliases to the registries themselves, which already implement MODEL_REGISTRY = model_registry # type: ignore
# the Mapping protocol and provide read-only access to users (since _objects is private). TASK_REGISTRY = task_registry # type: ignore
# This ensures they always reflect the current state of the registries, including METRIC_REGISTRY = metric_registry # type: ignore
# items registered after module import. METRIC_AGGREGATION_REGISTRY = metric_agg_registry # type: ignore
# HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry # type: ignore
# Note: We use type: ignore because Registry doesn't formally inherit from Mapping, FILTER_REGISTRY = filter_registry # type: ignore
# but it implements all required methods (__getitem__, __iter__, __len__, items)
MODEL_REGISTRY: Mapping[str, LM] = model_registry # type: ignore[assignment]
TASK_REGISTRY: Mapping[str, Callable[..., Any]] = task_registry # type: ignore[assignment]
METRIC_REGISTRY: Mapping[str, MetricSpec] = metric_registry # type: ignore[assignment]
METRIC_AGGREGATION_REGISTRY: Mapping[str, Callable] = metric_agg_registry # type: ignore[assignment]
HIGHER_IS_BETTER_REGISTRY: Mapping[str, bool] = higher_is_better_registry # type: ignore[assignment]
FILTER_REGISTRY: Mapping[str, Callable] = filter_registry # type: ignore[assignment]
...@@ -41,8 +41,8 @@ def _register_all_models(): ...@@ -41,8 +41,8 @@ def _register_all_models():
for name, path in MODEL_MAPPING.items(): for name, path in MODEL_MAPPING.items():
# Only register if not already present (avoids conflicts when modules are imported) # Only register if not already present (avoids conflicts when modules are imported)
if name not in model_registry: if name not in model_registry:
# Register the lazy placeholder directly # Register the lazy placeholder using lazy parameter
model_registry.register(name, path) model_registry.register(name, lazy=path)
# Call registration on module import # Call registration on module import
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment