Commit 124d3049 authored by Baber
Browse files

better placeholder materialization

parent 9af24b7e
......@@ -5,7 +5,7 @@ import random
import re
import string
from collections.abc import Iterable, Sequence
from typing import Callable, List, Optional, TypeVar
from typing import Callable, Generic, List, Optional, TypeVar
import numpy as np
import sacrebleu
......@@ -451,7 +451,7 @@ def _sacreformat(refs, preds):
# stderr stuff
class _bootstrap_internal:
class _bootstrap_internal(Generic[T]):
"""
Pool worker: `(i, xs)` → `n` bootstrap replicates
of `f(xs)` using an RNG seeded with `i`.
......
......@@ -3,11 +3,11 @@ from __future__ import annotations
import importlib
import inspect
import threading
from collections.abc import Iterable, Mapping
from collections.abc import Iterable
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from typing import Any, Callable, Generic, Type, TypeVar, Union, cast
from typing import Any, Callable, Generic, TypeVar, Union, cast
try:
......@@ -15,7 +15,6 @@ try:
except ImportError: # pragma: no cover – fallback for 3.8/3.9
import importlib_metadata as md # type: ignore
# Legacy exports (keep for one release, then drop)
LEGACY_EXPORTS = [
"DEFAULT_METRIC_REGISTRY",
"AGGREGATION_REGISTRY",
......@@ -52,14 +51,40 @@ __all__ = [
"higher_is_better_registry",
"filter_registry",
"freeze_all",
# legacy
*LEGACY_EXPORTS,
]
] # type: ignore
T = TypeVar("T")
Placeholder = Union[str, md.EntryPoint] # light‑weight lazy token
# ────────────────────────────────────────────────────────────────────────
# Module-level cache for materializing placeholders (prevents memory leak)
# ────────────────────────────────────────────────────────────────────────
@lru_cache(maxsize=16)
def _materialise_placeholder(ph: Placeholder) -> Any:
    """Resolve a lazy placeholder into the object it stands for.

    Kept at module level (rather than as a cached method) so the
    ``lru_cache`` never holds a reference to a registry instance.

    Args:
        ph: Either a ``"module:object"`` string or an entry-point-like
            object exposing ``load()``. The object part of a string may be
            dotted (``"pkg.mod:Outer.inner"``); each segment is resolved
            with ``getattr`` in turn.

    Returns:
        The imported attribute for string placeholders, otherwise whatever
        ``ph.load()`` returns.

    Raises:
        ValueError: If a string placeholder lacks the module part or the
            object part.
    """
    if not isinstance(ph, str):
        return ph.load()

    mod, _, attr = ph.partition(":")
    # Reject both "module" (no object) and ":object" (no module) up front so
    # the caller gets this message rather than an importlib internal one.
    if not mod or not attr:
        raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")

    obj = importlib.import_module(mod)
    # Walk dotted object references one attribute at a time.
    for part in attr.split("."):
        obj = getattr(obj, part)
    return obj
# ────────────────────────────────────────────────────────────────────────
# Metric-specific metadata storage
# ────────────────────────────────────────────────────────────────────────
_metric_meta: dict[str, dict[str, Any]] = {}
# ────────────────────────────────────────────────────────────────────────
# Generic Registry
# ────────────────────────────────────────────────────────────────────────
......@@ -72,12 +97,11 @@ class Registry(Generic[T]):
self,
name: str,
*,
base_cls: Union[Type[T], None] = None,
base_cls: type[T] | None = None,
) -> None:
self._name = name
self._base_cls = base_cls
self._objs: dict[str, Union[T, Placeholder]] = {}
self._meta: dict[str, dict[str, Any]] = {}
self._objs: dict[str, T | Placeholder] = {}
self._lock = threading.RLock()
# ------------------------------------------------------------------
......@@ -87,24 +111,22 @@ class Registry(Generic[T]):
def register(
self,
*aliases: str,
lazy: Union[T, Placeholder, None] = None,
metadata: dict[str, Any] | None = None,
lazy: T | Placeholder | None = None,
) -> Callable[[T], T]:
"""``@reg.register('foo')`` or ``reg.register('foo', lazy='pkg.mod:Obj')``."""
def _store(alias: str, target: Union[T, Placeholder]) -> None:
def _store(alias: str, target: T | Placeholder) -> None:
current = self._objs.get(alias)
# ─── collision handling ────────────────────────────────────
if current is not None and current != target:
# allow placeholder → real object upgrade
if isinstance(current, str) and isinstance(target, type):
mod, _, cls = current.partition(":")
# mod, _, cls = current.partition(":")
if current == f"{target.__module__}:{target.__name__}":
self._objs[alias] = target
self._meta[alias] = metadata or {}
return
raise ValueError(
f"{self._name!r} alias '{alias}' already registered (" # noqa: B950
f"{self._name!r} alias '{alias}' already registered ("
f"existing={current}, new={target})"
)
# ─── type check for concrete classes ───────────────────────
......@@ -114,8 +136,6 @@ class Registry(Generic[T]):
f"{target} must inherit from {self._base_cls} to be a {self._name}"
)
self._objs[alias] = target
if metadata:
self._meta[alias] = metadata
def decorator(obj: T) -> T: # type: ignore[valid-type]
names = aliases or (getattr(obj, "__name__", str(obj)),)
......@@ -139,14 +159,9 @@ class Registry(Generic[T]):
# Lookup & materialisation
# ------------------------------------------------------------------
@lru_cache(maxsize=256)
def _materialise(self, ph: Placeholder) -> T:
if isinstance(ph, str):
mod, _, attr = ph.partition(":")
if not attr:
raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")
return cast(T, getattr(importlib.import_module(mod), attr))
return cast(T, ph.load())
"""Materialize a placeholder using the module-level cached function."""
return cast(T, _materialise_placeholder(ph))
def get(self, alias: str) -> T:
try:
......@@ -162,7 +177,9 @@ class Registry(Generic[T]):
fresh = self._objs[alias]
if isinstance(fresh, (str, md.EntryPoint)):
concrete = self._materialise(fresh)
self._objs[alias] = concrete
# Only update if not frozen (MappingProxyType)
if not isinstance(self._objs, MappingProxyType):
self._objs[alias] = concrete
else:
concrete = fresh # another thread did the job
target = concrete
......@@ -178,26 +195,23 @@ class Registry(Generic[T]):
# Mapping helpers
# ------------------------------------------------------------------
def __getitem__(self, alias: str) -> T: # noqa: DunderImplemented
def __getitem__(self, alias: str) -> T:
return self.get(alias)
def __iter__(self): # noqa: DunderImplemented
def __iter__(self):
return iter(self._objs)
def __len__(self): # noqa: DunderImplemented
def __len__(self):
return len(self._objs)
def items(self): # noqa: DunderImplemented
def items(self):
return self._objs.items()
# ------------------------------------------------------------------
# Utilities
# ------------------------------------------------------------------
def metadata(self, alias: str) -> Union[Mapping[str, Any], None]:
return self._meta.get(alias)
def origin(self, alias: str) -> Union[str, None]:
def origin(self, alias: str) -> str | None:
obj = self._objs.get(alias)
if isinstance(obj, (str, md.EntryPoint)):
return None
......@@ -211,15 +225,13 @@ class Registry(Generic[T]):
def freeze(self):
with self._lock:
self._objs = MappingProxyType(dict(self._objs)) # type: ignore[assignment]
self._meta = MappingProxyType(dict(self._meta)) # type: ignore[assignment]
# Test helper -------------------------------------------------------------
def _clear(self): # pragma: no cover
"""Erase registry (for isolated tests)."""
self._objs.clear()
self._meta.clear()
self._materialise.cache_clear()
_materialise_placeholder.cache_clear()
# ────────────────────────────────────────────────────────────────────────
......@@ -232,8 +244,8 @@ class MetricSpec:
compute: Callable[[Any, Any], Any]
aggregate: Callable[[Iterable[Any]], float]
higher_is_better: bool = True
output_type: Union[str, None] = None
requires: Union[list[str], None] = None
output_type: str | None = None
requires: list[str] | None = None
# ────────────────────────────────────────────────────────────────────────
......@@ -243,7 +255,9 @@ class MetricSpec:
from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[type[LM]] = Registry("model", base_cls=LM)
model_registry: Registry[type[LM]] = cast(
Registry[type[LM]], Registry("model", base_cls=LM)
)
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
......@@ -266,6 +280,14 @@ get_filter = filter_registry.get
# Metric helpers need thin wrappers to build MetricSpec ----------------------
def _no_aggregation_fn(values: Iterable[Any]) -> float:
    """Placeholder aggregation used when a metric declares none.

    Never returns: calling it always raises ``NotImplementedError``.
    ``values`` is accepted only to satisfy the aggregation signature and
    is not read.
    """
    message = (
        "No aggregation function specified for this metric. "
        "Please specify 'aggregation' parameter in @register_metric."
    )
    raise NotImplementedError(message)
def register_metric(**kw):
name = kw["metric"]
......@@ -275,13 +297,14 @@ def register_metric(**kw):
aggregate=(
metric_agg_registry.get(kw["aggregation"])
if "aggregation" in kw
else lambda _: {}
else _no_aggregation_fn
),
higher_is_better=kw.get("higher_is_better", True),
output_type=kw.get("output_type"),
requires=kw.get("requires"),
)
metric_registry.register(name, lazy=spec, metadata=kw)
metric_registry.register(name, lazy=spec)
_metric_meta[name] = kw
higher_is_better_registry.register(name, lazy=spec.higher_is_better)
return fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment