registry.py 13.5 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
from __future__ import annotations

import importlib
import inspect
import threading
Baber's avatar
Baber committed
6
from collections.abc import Iterable
Baber's avatar
Baber committed
7
8
9
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
Baber's avatar
Baber committed
10
from typing import Any, Callable, Generic, TypeVar, Union, cast
Baber's avatar
Baber committed
11
12


13
14
15
try:
    import importlib.metadata as md  # Python ≥3.10
except ImportError:  # pragma: no cover – fallback for 3.8/3.9
Baber's avatar
Baber committed
16
17
    import importlib_metadata as md  # type: ignore

Baber's avatar
Baber committed
18
# Names kept importable for code written against the older registry API.
# The list is spliced into ``__all__`` further down.
LEGACY_EXPORTS = [
    # registry objects under their historical default names
    "DEFAULT_METRIC_REGISTRY",
    "AGGREGATION_REGISTRY",
    # function-style register/get helpers
    "register_model",
    "get_model",
    "register_task",
    "get_task",
    "register_metric",
    "get_metric",
    "register_metric_aggregation",
    "get_metric_aggregation",
    "register_higher_is_better",
    "is_higher_better",
    "register_filter",
    "get_filter",
    "register_aggregation",
    "get_aggregation",
    # upper-case registry aliases
    "MODEL_REGISTRY",
    "TASK_REGISTRY",
    "METRIC_REGISTRY",
    "METRIC_AGGREGATION_REGISTRY",
    "HIGHER_IS_BETTER_REGISTRY",
    "FILTER_REGISTRY",
]

Baber's avatar
Baber committed
43
44
45
46
47
48
49
50
51
52
53
54
# Public API of the module: the canonical names first, followed by every
# legacy name listed in LEGACY_EXPORTS.
__all__ = [
    # canonical
    "Registry",
    "MetricSpec",
    "model_registry",
    "task_registry",
    "metric_registry",
    "metric_agg_registry",
    "higher_is_better_registry",
    "filter_registry",
    "freeze_all",
    # legacy
    *LEGACY_EXPORTS,
]  # type: ignore
Baber's avatar
Baber committed
56

Baber's avatar
Baber committed
57
# Type of the objects a given registry stores.
T = TypeVar("T")

# A registry entry that has not been imported yet: either a "module:object"
# path string or an entry point.
Placeholder = Union[str, md.EntryPoint]  # light‑weight lazy token
Baber's avatar
Baber committed
59
60


Baber's avatar
Baber committed
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# ────────────────────────────────────────────────────────────────────────
# Module-level cache for materializing placeholders (prevents memory leak)
# ────────────────────────────────────────────────────────────────────────


@lru_cache(maxsize=16)
def _materialise_placeholder(ph: Placeholder) -> Any:
    """Import and return the object a lazy placeholder refers to.

    Kept at module level so that ``lru_cache`` keys only on the placeholder;
    caching an instance method would also key on ``self`` and keep every
    registry instance alive for the lifetime of the cache.

    Args:
        ph: Either a ``"module:object"`` string or an object exposing
            ``load()`` (an entry point). The object part of a string may be
            dotted (``"pkg.mod:Class.attr"``); it is then resolved one
            attribute at a time.

    Returns:
        The imported object.

    Raises:
        ValueError: If a string placeholder has no ``:object`` part.
        ImportError: If the module cannot be imported.
        AttributeError: If the object path does not exist in the module.
    """
    if not isinstance(ph, str):
        return ph.load()

    module_name, _, attr_path = ph.partition(":")
    if not attr_path:
        raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")

    obj = importlib.import_module(module_name)
    # Walk dotted paths attribute by attribute; a single getattr() with the
    # whole "a.b" string can never succeed.
    for part in attr_path.split("."):
        obj = getattr(obj, part)
    return obj


# ────────────────────────────────────────────────────────────────────────
# Metric-specific metadata storage
# ────────────────────────────────────────────────────────────────────────


# Maps a metric name to the raw keyword arguments it was registered with
# (filled in by ``register_metric``).
_metric_meta: dict[str, dict[str, Any]] = {}


Baber's avatar
Baber committed
88
89
90
91
92
93
# ────────────────────────────────────────────────────────────────────────
# Generic Registry
# ────────────────────────────────────────────────────────────────────────


class Registry(Generic[T]):
    """Name → object registry with optional lazy placeholders.

    Every alias maps either to a concrete object or to a *placeholder* (a
    ``"module:object"`` string or an entry point) that is imported the first
    time it is looked up. Writes go through a re-entrant lock, and
    :meth:`freeze` replaces the storage with a read-only mapping.
    """

    def __init__(
        self,
        name: str,
        *,
        base_cls: type[T] | None = None,
    ) -> None:
        """Create an empty registry.

        Args:
            name: Label for this registry, used in error messages.
            base_cls: When given, classes that are registered or resolved
                must be subclasses of it.
        """
        self._name = name
        self._base_cls = base_cls
        self._objs: dict[str, T | Placeholder] = {}
        self._lock = threading.RLock()

    # ------------------------------------------------------------------
    # Registration (decorator or direct call)
    # ------------------------------------------------------------------

    def register(
        self,
        *aliases: str,
        lazy: T | Placeholder | None = None,
    ) -> Callable[[T], T]:
        """``@reg.register('foo')`` or ``reg.register('foo', lazy='pkg.mod:Obj')``.

        Args:
            *aliases: Names to register under. In decorator form, an empty
                list falls back to the object's ``__name__`` (or ``str(obj)``).
            lazy: Value to store immediately under the single given alias
                instead of returning a registering decorator.

        Returns:
            A decorator that registers its argument and returns it unchanged.
            With ``lazy=`` the returned decorator is a no-op.

        Raises:
            ValueError: If ``lazy`` is used with anything but exactly one
                alias, or an alias is already bound to a different object.
            TypeError: If a class is registered that does not inherit from
                the registry's ``base_cls``.
        """

        def _store(alias: str, target: T | Placeholder) -> None:
            current = self._objs.get(alias)
            # ─── collision handling ────────────────────────────────────
            if current is not None and current != target:
                # The only permitted overwrite is upgrading a
                # "module:Class" placeholder to the very class it names.
                is_upgrade = (
                    isinstance(current, str)
                    and isinstance(target, type)
                    and current == f"{target.__module__}:{target.__name__}"
                )
                if not is_upgrade:
                    raise ValueError(
                        f"{self._name!r} alias '{alias}' already registered ("
                        f"existing={current}, new={target})"
                    )
            # ─── type check for concrete classes ───────────────────────
            # Runs for placeholder upgrades too, so a mismatching class is
            # rejected at registration time rather than on first lookup.
            if (
                self._base_cls is not None
                and isinstance(target, type)
                and not issubclass(target, self._base_cls)  # type: ignore[arg-type]
            ):
                raise TypeError(
                    f"{target} must inherit from {self._base_cls} to be a {self._name}"
                )
            self._objs[alias] = target

        def decorator(obj: T) -> T:  # type: ignore[valid-type]
            names = aliases or (getattr(obj, "__name__", str(obj)),)
            with self._lock:
                for name in names:
                    _store(name, obj)
            return obj

        # Direct call with *lazy* placeholder
        if lazy is not None:
            if len(aliases) != 1:
                raise ValueError("Exactly one alias required when using 'lazy='")
            with self._lock:
                _store(aliases[0], lazy)  # type: ignore[arg-type]
            # return no‑op decorator for accidental use
            return lambda x: x  # type: ignore[return-value]

        return decorator

    # ------------------------------------------------------------------
    # Lookup & materialisation
    # ------------------------------------------------------------------

    def _materialise(self, ph: Placeholder) -> T:
        """Materialize a placeholder using the module-level cached function."""
        return cast(T, _materialise_placeholder(ph))

    def get(self, alias: str) -> T:
        """Return the object registered under *alias*, importing it if lazy.

        A resolved placeholder is written back into the registry unless the
        registry has been frozen.

        Raises:
            KeyError: If *alias* is unknown; the message lists known aliases.
            TypeError: If the registry has a ``base_cls`` and the resolved
                object is not a subclass of it.
        """
        try:
            target = self._objs[alias]
        except KeyError as exc:
            raise KeyError(
                f"Unknown {self._name} '{alias}'. Available: {', '.join(self._objs)}"
            ) from exc

        if isinstance(target, (str, md.EntryPoint)):
            with self._lock:
                # Re‑check under lock (another thread might have resolved it)
                fresh = self._objs[alias]
                if isinstance(fresh, (str, md.EntryPoint)):
                    concrete = self._materialise(fresh)
                    # Only update if not frozen (MappingProxyType)
                    if not isinstance(self._objs, MappingProxyType):
                        self._objs[alias] = concrete
                else:
                    concrete = fresh  # another thread did the job
            target = concrete

        # Late type/validator checks. The isinstance() guard keeps a
        # non-class target from crashing inside issubclass() with an
        # unrelated "arg 1 must be a class" message.
        if self._base_cls is not None and not (
            isinstance(target, type) and issubclass(target, self._base_cls)  # type: ignore[arg-type]
        ):
            raise TypeError(
                f"{target} does not inherit from {self._base_cls} (alias '{alias}')"
            )
        return target

    # ------------------------------------------------------------------
    # Mapping helpers
    # ------------------------------------------------------------------

    def __getitem__(self, alias: str) -> T:
        """``registry[alias]`` is the same as ``registry.get(alias)``."""
        return self.get(alias)

    def __iter__(self):
        """Iterate over the registered aliases."""
        return iter(self._objs)

    def __len__(self):
        """Return the number of registered aliases."""
        return len(self._objs)

    def items(self):
        """Return ``(alias, value)`` pairs of the raw storage.

        Values are returned as stored, so unresolved placeholders are not
        imported.
        """
        return self._objs.items()

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    def origin(self, alias: str) -> str | None:
        """Return ``"<file>:<line>"`` of the object registered under *alias*.

        Returns ``None`` when the entry is still a placeholder or when the
        source location cannot be determined.
        """
        obj = self._objs.get(alias)
        if isinstance(obj, (str, md.EntryPoint)):
            return None
        try:
            path = inspect.getfile(obj)  # type: ignore[arg-type]
            line = inspect.getsourcelines(obj)[1]  # type: ignore[arg-type]
            return f"{path}:{line}"
        except Exception:  # pragma: no cover – best‑effort only
            return None

    def freeze(self):
        """Replace the storage with a read-only snapshot of itself."""
        with self._lock:
            self._objs = MappingProxyType(dict(self._objs))  # type: ignore[assignment]

    # Test helper -------------------------------------------------------------

    def _clear(self):  # pragma: no cover
        """Erase registry (for isolated tests).

        Also works on a frozen registry, which becomes writable again, and
        clears the shared placeholder cache.
        """
        with self._lock:
            if isinstance(self._objs, MappingProxyType):
                # A mappingproxy has no .clear(); start over with a new dict.
                self._objs = {}
            else:
                self._objs.clear()
        _materialise_placeholder.cache_clear()
235
236


Baber's avatar
Baber committed
237
# ────────────────────────────────────────────────────────────────────────
238
# Structured object for metrics
Baber's avatar
Baber committed
239
# ────────────────────────────────────────────────────────────────────────
240
241


Baber's avatar
Baber committed
242
243
244
@dataclass(frozen=True)
class MetricSpec:
    """Immutable record describing one registered metric."""

    # Function registered as the metric itself.
    compute: Callable[[Any, Any], Any]
    # Function that reduces a collection of values to a single float.
    aggregate: Callable[[Iterable[Any]], float]
    # Whether a larger aggregated value means a better result.
    higher_is_better: bool = True
    # Optional metadata passed through from registration.
    output_type: str | None = None
    requires: list[str] | None = None
249
250


Baber's avatar
Baber committed
251
# ────────────────────────────────────────────────────────────────────────
252
# Canonical registries
Baber's avatar
Baber committed
253
# ────────────────────────────────────────────────────────────────────────
254

Baber's avatar
Baber committed
255
from lm_eval.api.model import LM  # noqa: E402
256
257


Baber's avatar
Baber committed
258
259
260
# One registry per kind of pluggable component.
model_registry: Registry[type[LM]] = cast(
    Registry[type[LM]],
    Registry("model", base_cls=LM),
)
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
    "metric aggregation"
)
higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag")
filter_registry: Registry[Callable] = Registry("filter")

# Public helper aliases ------------------------------------------------------
# Function-style access: each pair is the bound ``register`` / ``get`` of the
# corresponding registry.
register_model, get_model = model_registry.register, model_registry.get
register_task, get_task = task_registry.register, task_registry.get
register_filter, get_filter = filter_registry.register, filter_registry.get

# Metric helpers need thin wrappers to build MetricSpec ----------------------
Baber's avatar
Baber committed
281

Baber's avatar
Baber committed
282

Baber's avatar
Baber committed
283
284
285
286
287
288
289
290
def _no_aggregation_fn(values: Iterable[Any]) -> float:
    """Stand-in aggregation for metrics registered without one.

    Always raises ``NotImplementedError``; *values* is never inspected.
    """
    message = (
        "No aggregation function specified for this metric. "
        "Please specify 'aggregation' parameter in @register_metric."
    )
    raise NotImplementedError(message)


291
292
def register_metric(**kw):
    """Return a decorator that registers a function as a metric.

    Keyword arguments read here: ``metric`` (required, the metric name),
    ``aggregation`` (name looked up in the aggregation registry when the
    decorator is applied), ``higher_is_better`` (default ``True``),
    ``output_type`` and ``requires``. The full keyword dict is also recorded
    in ``_metric_meta`` under the metric name.

    The decorated function is returned unchanged.
    """
    name = kw["metric"]

    def deco(fn):
        # Pick the aggregation first: a named one from the registry, or the
        # stand-in that raises when called.
        if "aggregation" in kw:
            aggregate = metric_agg_registry.get(kw["aggregation"])
        else:
            aggregate = _no_aggregation_fn

        spec = MetricSpec(
            compute=fn,
            aggregate=aggregate,
            higher_is_better=kw.get("higher_is_better", True),
            output_type=kw.get("output_type"),
            requires=kw.get("requires"),
        )

        # ``lazy=`` is used here simply to store the value directly.
        metric_registry.register(name, lazy=spec)
        _metric_meta[name] = kw
        higher_is_better_registry.register(name, lazy=spec.higher_is_better)
        return fn

    return deco
312
313


314
315
316
317
318
319
def get_metric(name, hf_evaluate_metric=False):
    """Return the compute function for metric *name*.

    The metric registry is consulted first. If the name is not registered,
    the ``evaluate`` package is tried via ``evaluate.load(name).compute``.

    Args:
        name: Metric name.
        hf_evaluate_metric: When false, a warning is logged before falling
            back to ``evaluate``; when true the fallback happens silently.

    Raises:
        KeyError: If the metric is in neither place. The failure from the
            ``evaluate`` attempt is attached as the cause.
    """
    # Keep the try body to the one call that signals "not registered".
    try:
        spec = metric_registry.get(name)
    except KeyError:
        if not hf_evaluate_metric:
            import logging

            logging.getLogger(__name__).warning(
                f"Metric '{name}' not in registry; trying HF evaluate…"
            )
    else:
        return spec.compute  # type: ignore[attr-defined]

    try:
        import evaluate as hf

        return hf.load(name).compute  # type: ignore[attr-defined]
    except Exception as exc:
        # Chain explicitly so the real reason (missing package, unknown
        # metric, network error, ...) stays visible in the traceback.
        raise KeyError(f"Metric '{name}' not found anywhere") from exc
331

haileyschoelkopf's avatar
haileyschoelkopf committed
332

333
334
# Aggregation helpers; the shorter ``*_aggregation`` names are the legacy
# spellings of the same registry methods.
register_metric_aggregation = register_aggregation = metric_agg_registry.register
get_metric_aggregation = get_aggregation = metric_agg_registry.get

register_higher_is_better = higher_is_better_registry.register
is_higher_better = higher_is_better_registry.get

# Legacy compatibility
DEFAULT_METRIC_REGISTRY = metric_registry
AGGREGATION_REGISTRY = metric_agg_registry
Baber's avatar
Baber committed
344

345
# Convenience ----------------------------------------------------------------
Baber's avatar
Baber committed
346
347


348
349
def freeze_all():
    """Call ``freeze()`` on each of the six canonical registries."""
    registries = (
        model_registry,
        task_registry,
        metric_registry,
        metric_agg_registry,
        higher_is_better_registry,
        filter_registry,
    )
    for registry in registries:
        registry.freeze()
Baber's avatar
Baber committed
358
359


360
# Backwards‑compat read‑only aliases ----------------------------------------
Baber's avatar
Baber committed
361

362
363
364
365
366
367
# Each upper-case name is bound to the same object as the corresponding
# lower-case registry; no copy is made.
MODEL_REGISTRY = model_registry  # type: ignore
TASK_REGISTRY = task_registry  # type: ignore
METRIC_REGISTRY = metric_registry  # type: ignore
METRIC_AGGREGATION_REGISTRY = metric_agg_registry  # type: ignore
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry  # type: ignore
FILTER_REGISTRY = filter_registry  # type: ignore