Commit 9ec26934 authored by Baber's avatar Baber
Browse files

add docs

parent 2b32f7be
"""Registry system for lm_eval components.
This module provides a centralized registration system for models, tasks, metrics,
filters, and other components in the lm_eval framework. The registry supports:
- Lazy loading with placeholders to improve startup time
- Type checking and validation
- Thread-safe registration and lookup
- Plugin discovery via entry points
- Backwards compatibility with legacy registration patterns
## Usage Examples
### Registering a Model
```python
from lm_eval.api.registry import register_model
from lm_eval.api.model import LM
@register_model("my-model")
class MyModel(LM):
def __init__(self, **kwargs):
...
```
### Registering a Metric
```python
from lm_eval.api.registry import register_metric
@register_metric(
metric="my_accuracy",
aggregation="mean",
higher_is_better=True
)
def my_accuracy_fn(items):
...
```
### Registering with Lazy Loading
```python
# Register without importing the actual implementation
model_registry.register("lazy-model", lazy="my_package.models:LazyModel")
```
### Looking up Components
```python
from lm_eval.api.registry import get_model, get_metric
# Get a model class
model_cls = get_model("gpt-j")
model = model_cls(**config)
# Get a metric function
metric_fn = get_metric("accuracy")
```
"""
from __future__ import annotations from __future__ import annotations
import importlib import importlib
...@@ -9,6 +65,8 @@ from functools import lru_cache ...@@ -9,6 +65,8 @@ from functools import lru_cache
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Callable, Generic, TypeVar, Union, cast from typing import Any, Callable, Generic, TypeVar, Union, cast
from lm_eval.api.filter import Filter
try: try:
import importlib.metadata as md # Python ≥3.10 import importlib.metadata as md # Python ≥3.10
...@@ -55,12 +113,7 @@ __all__ = [ ...@@ -55,12 +113,7 @@ __all__ = [
] # type: ignore ] # type: ignore
T = TypeVar("T") T = TypeVar("T")
Placeholder = Union[str, md.EntryPoint] # light‑weight lazy token Placeholder = Union[str, md.EntryPoint]
# ────────────────────────────────────────────────────────────────────────
# Module-level cache for materializing placeholders (prevents memory leak)
# ────────────────────────────────────────────────────────────────────────
@lru_cache(maxsize=16) @lru_cache(maxsize=16)
...@@ -68,6 +121,17 @@ def _materialise_placeholder(ph: Placeholder) -> Any: ...@@ -68,6 +121,17 @@ def _materialise_placeholder(ph: Placeholder) -> Any:
"""Materialize a lazy placeholder into the actual object. """Materialize a lazy placeholder into the actual object.
This is at module level to avoid memory leaks from lru_cache on instance methods. This is at module level to avoid memory leaks from lru_cache on instance methods.
Args:
ph: Either a string path "module:object" or an EntryPoint instance
Returns:
The loaded object
Raises:
ValueError: If the string format is invalid
ImportError: If the module cannot be imported
AttributeError: If the object doesn't exist in the module
""" """
if isinstance(ph, str): if isinstance(ph, str):
mod, _, attr = ph.partition(":") mod, _, attr = ph.partition(":")
...@@ -77,21 +141,39 @@ def _materialise_placeholder(ph: Placeholder) -> Any: ...@@ -77,21 +141,39 @@ def _materialise_placeholder(ph: Placeholder) -> Any:
return ph.load() return ph.load()
# ──────────────────────────────────────────────────────────────────────── # Metric-specific metadata storage --------------------------------------------
# Metric-specific metadata storage
# ────────────────────────────────────────────────────────────────────────
_metric_meta: dict[str, dict[str, Any]] = {} _metric_meta: dict[str, dict[str, Any]] = {}
# ────────────────────────────────────────────────────────────────────────
# Generic Registry
# ────────────────────────────────────────────────────────────────────────
class Registry(Generic[T]): class Registry(Generic[T]):
"""Name → object registry with optional lazy placeholders.""" """A thread-safe registry for named objects with lazy loading support.
The Registry provides a central location for registering and retrieving
components by name. It supports:
- Direct registration of objects
- Lazy registration with placeholders (strings or entry points)
- Type checking against a base class
- Thread-safe operations
- Freezing to prevent further modifications
Example:
>>> from lm_eval.api.model import LM
>>> registry = Registry("models", base_cls=LM)
>>>
>>> # Direct registration
>>> @registry.register("my-model")
>>> class MyModel(LM):
... pass
>>>
>>> # Lazy registration
>>> registry.register("lazy-model", lazy="mypackage:LazyModel")
>>>
>>> # Retrieval (triggers lazy loading if needed)
>>> model_cls = registry.get("my-model")
>>> model = model_cls()
"""
def __init__( def __init__(
self, self,
...@@ -99,25 +181,52 @@ class Registry(Generic[T]): ...@@ -99,25 +181,52 @@ class Registry(Generic[T]):
*, *,
base_cls: type[T] | None = None, base_cls: type[T] | None = None,
) -> None: ) -> None:
"""Initialize a new registry.
Args:
name: Human-readable name for error messages (e.g., "model", "metric")
base_cls: Optional base class that all registered objects must inherit from
"""
self._name = name self._name = name
self._base_cls = base_cls self._base_cls = base_cls
self._objs: dict[str, T | Placeholder] = {} self._objs: dict[str, T | Placeholder] = {}
self._lock = threading.RLock() self._lock = threading.RLock()
# ------------------------------------------------------------------ # Registration (decorator or direct call) --------------------------------------
# Registration (decorator or direct call)
# ------------------------------------------------------------------
def register( def register(
self, self,
*aliases: str, *aliases: str,
lazy: T | Placeholder | None = None, lazy: T | Placeholder | None = None,
) -> Callable[[T], T]: ) -> Callable[[T], T]:
"""``@reg.register('foo')`` or ``reg.register('foo', lazy='pkg.mod:Obj')``.""" """Register an object under one or more aliases.
Can be used as a decorator or called directly for lazy registration.
Args:
*aliases: Names to register the object under. If empty, uses object's __name__
lazy: For direct calls only - a placeholder string "module:object" or EntryPoint
Returns:
Decorator function (or no-op if lazy registration)
Examples:
>>> # As decorator
>>> @model_registry.register("name1", "name2")
>>> class MyModel(LM):
... pass
>>>
>>> # Direct lazy registration
>>> model_registry.register("lazy-name", lazy="mymodule:MyModel")
Raises:
ValueError: If alias already registered with different target
TypeError: If object doesn't inherit from base_cls (when specified)
"""
def _store(alias: str, target: T | Placeholder) -> None: def _store(alias: str, target: T | Placeholder) -> None:
current = self._objs.get(alias) current = self._objs.get(alias)
# ─── collision handling ──────────────────────────────────── # collision handling ------------------------------------------
if current is not None and current != target: if current is not None and current != target:
# allow placeholder → real object upgrade # allow placeholder → real object upgrade
if isinstance(current, str) and isinstance(target, type): if isinstance(current, str) and isinstance(target, type):
...@@ -129,7 +238,7 @@ class Registry(Generic[T]): ...@@ -129,7 +238,7 @@ class Registry(Generic[T]):
f"{self._name!r} alias '{alias}' already registered (" f"{self._name!r} alias '{alias}' already registered ("
f"existing={current}, new={target})" f"existing={current}, new={target})"
) )
# ─── type check for concrete classes ─────────────────────── # type check for concrete classes ----------------------------------------------
if self._base_cls is not None and isinstance(target, type): if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type] if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError( raise TypeError(
...@@ -155,15 +264,36 @@ class Registry(Generic[T]): ...@@ -155,15 +264,36 @@ class Registry(Generic[T]):
return decorator return decorator
# ------------------------------------------------------------------ # Lookup & materialisation --------------------------------------------------
# Lookup & materialisation
# ------------------------------------------------------------------
def _materialise(self, ph: Placeholder) -> T: def _materialise(self, ph: Placeholder) -> T:
"""Materialize a placeholder using the module-level cached function.""" """Materialize a placeholder using the module-level cached function.
Args:
ph: Placeholder to materialize
Returns:
The materialized object, cast to type T
"""
return cast(T, _materialise_placeholder(ph)) return cast(T, _materialise_placeholder(ph))
def get(self, alias: str) -> T: def get(self, alias: str) -> T:
"""Retrieve an object by alias, materializing if needed.
Thread-safe lazy loading: if the alias points to a placeholder,
it will be loaded and cached before returning.
Args:
alias: The registered name to look up
Returns:
The registered object
Raises:
KeyError: If alias not found
TypeError: If materialized object doesn't match base_cls
ImportError/AttributeError: If lazy loading fails
"""
try: try:
target = self._objs[alias] target = self._objs[alias]
except KeyError as exc: except KeyError as exc:
...@@ -191,27 +321,36 @@ class Registry(Generic[T]): ...@@ -191,27 +321,36 @@ class Registry(Generic[T]):
) )
return target return target
# ------------------------------------------------------------------
# Mapping helpers
# ------------------------------------------------------------------
def __getitem__(self, alias: str) -> T: def __getitem__(self, alias: str) -> T:
"""Allow dict-style access: registry[alias]."""
return self.get(alias) return self.get(alias)
def __iter__(self): def __iter__(self):
"""Iterate over registered aliases."""
return iter(self._objs) return iter(self._objs)
def __len__(self): def __len__(self):
"""Return number of registered aliases."""
return len(self._objs) return len(self._objs)
def items(self): def items(self):
"""Return (alias, object) pairs.
Note: Objects may be placeholders that haven't been materialized yet.
"""
return self._objs.items() return self._objs.items()
# ------------------------------------------------------------------ # Utilities -------------------------------------------------------------
# Utilities
# ------------------------------------------------------------------
def origin(self, alias: str) -> str | None: def origin(self, alias: str) -> str | None:
"""Get the source location of a registered object.
Args:
alias: The registered name
Returns:
"path/to/file.py:line_number" or None if not available
"""
obj = self._objs.get(alias) obj = self._objs.get(alias)
if isinstance(obj, (str, md.EntryPoint)): if isinstance(obj, (str, md.EntryPoint)):
return None return None
...@@ -223,24 +362,41 @@ class Registry(Generic[T]): ...@@ -223,24 +362,41 @@ class Registry(Generic[T]):
return None return None
def freeze(self): def freeze(self):
"""Make the registry read-only to prevent further modifications.
After freezing, attempts to register new objects will fail.
This is useful for ensuring registry contents don't change after
initialization.
"""
with self._lock: with self._lock:
self._objs = MappingProxyType(dict(self._objs)) # type: ignore[assignment] self._objs = MappingProxyType(dict(self._objs)) # type: ignore[assignment]
# Test helper ------------------------------------------------------------- # Test helper --------------------------------
def _clear(self): # pragma: no cover def _clear(self): # pragma: no cover
"""Erase registry (for isolated tests).""" """Erase registry (for isolated tests).
Clears both the registry contents and the materialization cache.
Only use this in test code to ensure clean state between tests.
"""
self._objs.clear() self._objs.clear()
_materialise_placeholder.cache_clear() _materialise_placeholder.cache_clear()
# ──────────────────────────────────────────────────────────────────────── # Structured object for metrics ------------------
# Structured object for metrics
# ────────────────────────────────────────────────────────────────────────
@dataclass(frozen=True) @dataclass(frozen=True)
class MetricSpec: class MetricSpec:
"""Specification for a metric including computation and aggregation functions.
Attributes:
compute: Function to compute metric on individual items
aggregate: Function to aggregate multiple metric values into a single score
higher_is_better: Whether higher values indicate better performance
output_type: Optional type hint for the output (e.g., "generate_until" for perplexity)
requires: Optional list of other metrics this one depends on
"""
compute: Callable[[Any, Any], Any] compute: Callable[[Any, Any], Any]
aggregate: Callable[[Iterable[Any]], float] aggregate: Callable[[Iterable[Any]], float]
higher_is_better: bool = True higher_is_better: bool = True
...@@ -248,9 +404,7 @@ class MetricSpec: ...@@ -248,9 +404,7 @@ class MetricSpec:
requires: list[str] | None = None requires: list[str] | None = None
# ──────────────────────────────────────────────────────────────────────── # Canonical registries aliases ---------------------
# Canonical registries
# ────────────────────────────────────────────────────────────────────────
from lm_eval.api.model import LM # noqa: E402 from lm_eval.api.model import LM # noqa: E402
...@@ -264,7 +418,7 @@ metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry( ...@@ -264,7 +418,7 @@ metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
"metric aggregation" "metric aggregation"
) )
higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag") higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag")
filter_registry: Registry[Callable] = Registry("filter") filter_registry: Registry[type[Filter]] = Registry("filter")
# Public helper aliases ------------------------------------------------------ # Public helper aliases ------------------------------------------------------
...@@ -281,7 +435,15 @@ get_filter = filter_registry.get ...@@ -281,7 +435,15 @@ get_filter = filter_registry.get
def _no_aggregation_fn(values: Iterable[Any]) -> float: def _no_aggregation_fn(values: Iterable[Any]) -> float:
"""Default aggregation that raises NotImplementedError.""" """Default aggregation that raises NotImplementedError.
Args:
values: Metric values to aggregate (unused)
Raises:
NotImplementedError: Always - this is a placeholder for metrics
that haven't specified an aggregation function
"""
raise NotImplementedError( raise NotImplementedError(
"No aggregation function specified for this metric. " "No aggregation function specified for this metric. "
"Please specify 'aggregation' parameter in @register_metric." "Please specify 'aggregation' parameter in @register_metric."
...@@ -289,6 +451,31 @@ def _no_aggregation_fn(values: Iterable[Any]) -> float: ...@@ -289,6 +451,31 @@ def _no_aggregation_fn(values: Iterable[Any]) -> float:
def register_metric(**kw): def register_metric(**kw):
"""Decorator for registering metric functions.
Creates a MetricSpec from the decorated function and keyword arguments,
then registers it in the metric registry.
Args:
**kw: Keyword arguments including:
- metric: Name to register the metric under (required)
- aggregation: Name of aggregation function in metric_agg_registry
- higher_is_better: Whether higher scores are better (default: True)
- output_type: Optional output type hint
- requires: Optional list of required metrics
Returns:
Decorator function that registers the metric
Example:
>>> @register_metric(
... metric="my_accuracy",
... aggregation="mean",
... higher_is_better=True
... )
... def compute_accuracy(items):
... return sum(item["correct"] for item in items) / len(items)
"""
name = kw["metric"] name = kw["metric"]
def deco(fn): def deco(fn):
...@@ -312,6 +499,21 @@ def register_metric(**kw): ...@@ -312,6 +499,21 @@ def register_metric(**kw):
def get_metric(name, hf_evaluate_metric=False): def get_metric(name, hf_evaluate_metric=False):
"""Get a metric compute function by name.
First checks the local metric registry, then optionally falls back
to HuggingFace evaluate library.
Args:
name: Metric name to retrieve
hf_evaluate_metric: If True, suppress warning when falling back to HF
Returns:
The metric's compute function
Raises:
KeyError: If metric not found in registry or HF evaluate
"""
try: try:
spec = metric_registry.get(name) spec = metric_registry.get(name)
return spec.compute # type: ignore[attr-defined] return spec.compute # type: ignore[attr-defined]
...@@ -342,10 +544,13 @@ get_aggregation = metric_agg_registry.get ...@@ -342,10 +544,13 @@ get_aggregation = metric_agg_registry.get
DEFAULT_METRIC_REGISTRY = metric_registry DEFAULT_METRIC_REGISTRY = metric_registry
AGGREGATION_REGISTRY = metric_agg_registry AGGREGATION_REGISTRY = metric_agg_registry
# Convenience ----------------------------------------------------------------
def freeze_all(): def freeze_all():
"""Freeze all registries to prevent further modifications.
This is useful for ensuring registry contents are immutable after
initialization, preventing accidental modifications during runtime.
"""
for r in ( for r in (
model_registry, model_registry,
task_registry, task_registry,
...@@ -357,11 +562,11 @@ def freeze_all(): ...@@ -357,11 +562,11 @@ def freeze_all():
r.freeze() r.freeze()
# Backwards‑compat read‑only aliases ---------------------------------------- # Backwards‑compat aliases ----------------------------------------
MODEL_REGISTRY = model_registry # type: ignore MODEL_REGISTRY = model_registry
TASK_REGISTRY = task_registry # type: ignore TASK_REGISTRY = task_registry
METRIC_REGISTRY = metric_registry # type: ignore METRIC_REGISTRY = metric_registry
METRIC_AGGREGATION_REGISTRY = metric_agg_registry # type: ignore METRIC_AGGREGATION_REGISTRY = metric_agg_registry
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry # type: ignore HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry
FILTER_REGISTRY = filter_registry # type: ignore FILTER_REGISTRY = filter_registry
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment