Commit 124d3049 authored by Baber's avatar Baber
Browse files

better placeholder materialization

parent 9af24b7e
...@@ -5,7 +5,7 @@ import random ...@@ -5,7 +5,7 @@ import random
import re import re
import string import string
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from typing import Callable, List, Optional, TypeVar from typing import Callable, Generic, List, Optional, TypeVar
import numpy as np import numpy as np
import sacrebleu import sacrebleu
...@@ -451,7 +451,7 @@ def _sacreformat(refs, preds): ...@@ -451,7 +451,7 @@ def _sacreformat(refs, preds):
# stderr stuff # stderr stuff
class _bootstrap_internal: class _bootstrap_internal(Generic[T]):
""" """
Pool worker: `(i, xs)` → `n` bootstrap replicates Pool worker: `(i, xs)` → `n` bootstrap replicates
of `f(xs)`using a RNG seeded with `i`. of `f(xs)`using a RNG seeded with `i`.
......
...@@ -3,11 +3,11 @@ from __future__ import annotations ...@@ -3,11 +3,11 @@ from __future__ import annotations
import importlib import importlib
import inspect import inspect
import threading import threading
from collections.abc import Iterable, Mapping from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Callable, Generic, Type, TypeVar, Union, cast from typing import Any, Callable, Generic, TypeVar, Union, cast
try: try:
...@@ -15,7 +15,6 @@ try: ...@@ -15,7 +15,6 @@ try:
except ImportError: # pragma: no cover – fallback for 3.8/3.9 except ImportError: # pragma: no cover – fallback for 3.8/3.9
import importlib_metadata as md # type: ignore import importlib_metadata as md # type: ignore
# Legacy exports (keep for one release, then drop)
LEGACY_EXPORTS = [ LEGACY_EXPORTS = [
"DEFAULT_METRIC_REGISTRY", "DEFAULT_METRIC_REGISTRY",
"AGGREGATION_REGISTRY", "AGGREGATION_REGISTRY",
...@@ -52,14 +51,40 @@ __all__ = [ ...@@ -52,14 +51,40 @@ __all__ = [
"higher_is_better_registry", "higher_is_better_registry",
"filter_registry", "filter_registry",
"freeze_all", "freeze_all",
# legacy
*LEGACY_EXPORTS, *LEGACY_EXPORTS,
] ] # type: ignore
T = TypeVar("T") T = TypeVar("T")
Placeholder = Union[str, md.EntryPoint] # light‑weight lazy token Placeholder = Union[str, md.EntryPoint] # light‑weight lazy token
# ────────────────────────────────────────────────────────────────────────
# Module-level cache for materializing placeholders (prevents memory leak)
# ────────────────────────────────────────────────────────────────────────
@lru_cache(maxsize=16)
def _materialise_placeholder(ph: Placeholder) -> Any:
"""Materialize a lazy placeholder into the actual object.
This is at module level to avoid memory leaks from lru_cache on instance methods.
"""
if isinstance(ph, str):
mod, _, attr = ph.partition(":")
if not attr:
raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")
return getattr(importlib.import_module(mod), attr)
return ph.load()
# ────────────────────────────────────────────────────────────────────────
# Metric-specific metadata storage
# ────────────────────────────────────────────────────────────────────────
_metric_meta: dict[str, dict[str, Any]] = {}
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
# Generic Registry # Generic Registry
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
...@@ -72,12 +97,11 @@ class Registry(Generic[T]): ...@@ -72,12 +97,11 @@ class Registry(Generic[T]):
self, self,
name: str, name: str,
*, *,
base_cls: Union[Type[T], None] = None, base_cls: type[T] | None = None,
) -> None: ) -> None:
self._name = name self._name = name
self._base_cls = base_cls self._base_cls = base_cls
self._objs: dict[str, Union[T, Placeholder]] = {} self._objs: dict[str, T | Placeholder] = {}
self._meta: dict[str, dict[str, Any]] = {}
self._lock = threading.RLock() self._lock = threading.RLock()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
...@@ -87,24 +111,22 @@ class Registry(Generic[T]): ...@@ -87,24 +111,22 @@ class Registry(Generic[T]):
def register( def register(
self, self,
*aliases: str, *aliases: str,
lazy: Union[T, Placeholder, None] = None, lazy: T | Placeholder | None = None,
metadata: dict[str, Any] | None = None,
) -> Callable[[T], T]: ) -> Callable[[T], T]:
"""``@reg.register('foo')`` or ``reg.register('foo', lazy='pkg.mod:Obj')``.""" """``@reg.register('foo')`` or ``reg.register('foo', lazy='pkg.mod:Obj')``."""
def _store(alias: str, target: Union[T, Placeholder]) -> None: def _store(alias: str, target: T | Placeholder) -> None:
current = self._objs.get(alias) current = self._objs.get(alias)
# ─── collision handling ──────────────────────────────────── # ─── collision handling ────────────────────────────────────
if current is not None and current != target: if current is not None and current != target:
# allow placeholder → real object upgrade # allow placeholder → real object upgrade
if isinstance(current, str) and isinstance(target, type): if isinstance(current, str) and isinstance(target, type):
mod, _, cls = current.partition(":") # mod, _, cls = current.partition(":")
if current == f"{target.__module__}:{target.__name__}": if current == f"{target.__module__}:{target.__name__}":
self._objs[alias] = target self._objs[alias] = target
self._meta[alias] = metadata or {}
return return
raise ValueError( raise ValueError(
f"{self._name!r} alias '{alias}' already registered (" # noqa: B950 f"{self._name!r} alias '{alias}' already registered ("
f"existing={current}, new={target})" f"existing={current}, new={target})"
) )
# ─── type check for concrete classes ─────────────────────── # ─── type check for concrete classes ───────────────────────
...@@ -114,8 +136,6 @@ class Registry(Generic[T]): ...@@ -114,8 +136,6 @@ class Registry(Generic[T]):
f"{target} must inherit from {self._base_cls} to be a {self._name}" f"{target} must inherit from {self._base_cls} to be a {self._name}"
) )
self._objs[alias] = target self._objs[alias] = target
if metadata:
self._meta[alias] = metadata
def decorator(obj: T) -> T: # type: ignore[valid-type] def decorator(obj: T) -> T: # type: ignore[valid-type]
names = aliases or (getattr(obj, "__name__", str(obj)),) names = aliases or (getattr(obj, "__name__", str(obj)),)
...@@ -139,14 +159,9 @@ class Registry(Generic[T]): ...@@ -139,14 +159,9 @@ class Registry(Generic[T]):
# Lookup & materialisation # Lookup & materialisation
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@lru_cache(maxsize=256)
def _materialise(self, ph: Placeholder) -> T: def _materialise(self, ph: Placeholder) -> T:
if isinstance(ph, str): """Materialize a placeholder using the module-level cached function."""
mod, _, attr = ph.partition(":") return cast(T, _materialise_placeholder(ph))
if not attr:
raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")
return cast(T, getattr(importlib.import_module(mod), attr))
return cast(T, ph.load())
def get(self, alias: str) -> T: def get(self, alias: str) -> T:
try: try:
...@@ -162,6 +177,8 @@ class Registry(Generic[T]): ...@@ -162,6 +177,8 @@ class Registry(Generic[T]):
fresh = self._objs[alias] fresh = self._objs[alias]
if isinstance(fresh, (str, md.EntryPoint)): if isinstance(fresh, (str, md.EntryPoint)):
concrete = self._materialise(fresh) concrete = self._materialise(fresh)
# Only update if not frozen (MappingProxyType)
if not isinstance(self._objs, MappingProxyType):
self._objs[alias] = concrete self._objs[alias] = concrete
else: else:
concrete = fresh # another thread did the job concrete = fresh # another thread did the job
...@@ -178,26 +195,23 @@ class Registry(Generic[T]): ...@@ -178,26 +195,23 @@ class Registry(Generic[T]):
# Mapping helpers # Mapping helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def __getitem__(self, alias: str) -> T: # noqa: DunderImplemented def __getitem__(self, alias: str) -> T:
return self.get(alias) return self.get(alias)
def __iter__(self): # noqa: DunderImplemented def __iter__(self):
return iter(self._objs) return iter(self._objs)
def __len__(self): # noqa: DunderImplemented def __len__(self):
return len(self._objs) return len(self._objs)
def items(self): # noqa: DunderImplemented def items(self):
return self._objs.items() return self._objs.items()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Utilities # Utilities
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def metadata(self, alias: str) -> Union[Mapping[str, Any], None]: def origin(self, alias: str) -> str | None:
return self._meta.get(alias)
def origin(self, alias: str) -> Union[str, None]:
obj = self._objs.get(alias) obj = self._objs.get(alias)
if isinstance(obj, (str, md.EntryPoint)): if isinstance(obj, (str, md.EntryPoint)):
return None return None
...@@ -211,15 +225,13 @@ class Registry(Generic[T]): ...@@ -211,15 +225,13 @@ class Registry(Generic[T]):
def freeze(self): def freeze(self):
with self._lock: with self._lock:
self._objs = MappingProxyType(dict(self._objs)) # type: ignore[assignment] self._objs = MappingProxyType(dict(self._objs)) # type: ignore[assignment]
self._meta = MappingProxyType(dict(self._meta)) # type: ignore[assignment]
# Test helper ------------------------------------------------------------- # Test helper -------------------------------------------------------------
def _clear(self): # pragma: no cover def _clear(self): # pragma: no cover
"""Erase registry (for isolated tests).""" """Erase registry (for isolated tests)."""
self._objs.clear() self._objs.clear()
self._meta.clear() _materialise_placeholder.cache_clear()
self._materialise.cache_clear()
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
...@@ -232,8 +244,8 @@ class MetricSpec: ...@@ -232,8 +244,8 @@ class MetricSpec:
compute: Callable[[Any, Any], Any] compute: Callable[[Any, Any], Any]
aggregate: Callable[[Iterable[Any]], float] aggregate: Callable[[Iterable[Any]], float]
higher_is_better: bool = True higher_is_better: bool = True
output_type: Union[str, None] = None output_type: str | None = None
requires: Union[list[str], None] = None requires: list[str] | None = None
# ──────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────
...@@ -243,7 +255,9 @@ class MetricSpec: ...@@ -243,7 +255,9 @@ class MetricSpec:
from lm_eval.api.model import LM # noqa: E402 from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[type[LM]] = Registry("model", base_cls=LM) model_registry: Registry[type[LM]] = cast(
Registry[type[LM]], Registry("model", base_cls=LM)
)
task_registry: Registry[Callable[..., Any]] = Registry("task") task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric") metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry( metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
...@@ -266,6 +280,14 @@ get_filter = filter_registry.get ...@@ -266,6 +280,14 @@ get_filter = filter_registry.get
# Metric helpers need thin wrappers to build MetricSpec ---------------------- # Metric helpers need thin wrappers to build MetricSpec ----------------------
def _no_aggregation_fn(values: Iterable[Any]) -> float:
"""Default aggregation that raises NotImplementedError."""
raise NotImplementedError(
"No aggregation function specified for this metric. "
"Please specify 'aggregation' parameter in @register_metric."
)
def register_metric(**kw): def register_metric(**kw):
name = kw["metric"] name = kw["metric"]
...@@ -275,13 +297,14 @@ def register_metric(**kw): ...@@ -275,13 +297,14 @@ def register_metric(**kw):
aggregate=( aggregate=(
metric_agg_registry.get(kw["aggregation"]) metric_agg_registry.get(kw["aggregation"])
if "aggregation" in kw if "aggregation" in kw
else lambda _: {} else _no_aggregation_fn
), ),
higher_is_better=kw.get("higher_is_better", True), higher_is_better=kw.get("higher_is_better", True),
output_type=kw.get("output_type"), output_type=kw.get("output_type"),
requires=kw.get("requires"), requires=kw.get("requires"),
) )
metric_registry.register(name, lazy=spec, metadata=kw) metric_registry.register(name, lazy=spec)
_metric_meta[name] = kw
higher_is_better_registry.register(name, lazy=spec.higher_is_better) higher_is_better_registry.register(name, lazy=spec.higher_is_better)
return fn return fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment