from __future__ import annotations

import importlib
import inspect
import threading
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from typing import Any, Callable, Generic, Type, TypeVar, Union, cast
try:
    import importlib.metadata as md  # Python ≥3.10
except ImportError:  # pragma: no cover – fallback for 3.8/3.9
    import importlib_metadata as md  # type: ignore

# Legacy exports (keep for one release, then drop)
LEGACY_EXPORTS = [
    "DEFAULT_METRIC_REGISTRY",
    "AGGREGATION_REGISTRY",
    "register_model",
    "get_model",
    "register_task",
    "get_task",
    "register_metric",
    "get_metric",
    "register_metric_aggregation",
    "get_metric_aggregation",
    "register_higher_is_better",
    "is_higher_better",
    "register_filter",
    "get_filter",
    "register_aggregation",
    "get_aggregation",
    "MODEL_REGISTRY",
    "TASK_REGISTRY",
    "METRIC_REGISTRY",
    "METRIC_AGGREGATION_REGISTRY",
    "HIGHER_IS_BETTER_REGISTRY",
    "FILTER_REGISTRY",
]

__all__ = [
    # canonical
    "Registry",
    "MetricSpec",
    "model_registry",
    "task_registry",
    "metric_registry",
    "metric_agg_registry",
    "higher_is_better_registry",
    "filter_registry",
    "freeze_all",
    # legacy
    *LEGACY_EXPORTS,
]

T = TypeVar("T")
# Lazy token stored in a Registry instead of a real object: either a
# "module:attr" import path or an importlib.metadata EntryPoint.  Registry
# imports it the first time the alias is looked up.
Placeholder = Union[str, md.EntryPoint]  # light-weight lazy token


# ────────────────────────────────────────────────────────────────────────
# Generic Registry
# ────────────────────────────────────────────────────────────────────────


class Registry(Generic[T]):
    """Name → object registry with optional lazy placeholders.

    Each alias maps to either a concrete object or a *placeholder*: a
    ``"module:attr"`` string or an ``importlib.metadata.EntryPoint``.  A
    placeholder is imported on the first :meth:`get` of its alias and, unless
    the registry is frozen, replaced in place by the resolved object.

    If ``base_cls`` is given, classes registered directly are checked with
    ``issubclass`` at registration time, and every object returned by
    :meth:`get` must be a subclass of it.
    """

    def __init__(
        self,
        name: str,
        *,
        base_cls: Union[Type[T], None] = None,
    ) -> None:
        self._name = name
        self._base_cls = base_cls
        self._objs: dict[str, Union[T, Placeholder]] = {}
        self._meta: dict[str, dict[str, Any]] = {}
        # Re-entrant: resolving a placeholder imports a module while this lock
        # is held, and that module may itself call ``register`` on this registry.
        self._lock = threading.RLock()
        # Set by ``freeze``; afterwards ``_objs``/``_meta`` are read-only proxies.
        self._frozen = False

    # ------------------------------------------------------------------
    # Registration (decorator or direct call)
    # ------------------------------------------------------------------

    def register(
        self,
        *aliases: str,
        lazy: Union[T, Placeholder, None] = None,
        metadata: dict[str, Any] | None = None,
    ) -> Callable[[T], T]:
        """``@reg.register('foo')`` or ``reg.register('foo', lazy='pkg.mod:Obj')``.

        Args:
            *aliases: Names to register under.  In decorator form they default
                to the object's ``__name__`` (or ``str(obj)``); the ``lazy=``
                form requires exactly one.
            lazy: Store this value immediately instead of returning a real
                decorator.  Strings / entry points are resolved on first
                :meth:`get`; any other non-``None`` value is stored as-is.
            metadata: Dict stored alongside each alias (only when non-empty).

        Returns:
            A decorator that registers and returns its argument, or, in the
            ``lazy=`` form, a no-op identity function.

        Raises:
            ValueError: An alias is already bound to a different object, or
                the ``lazy=`` form was not given exactly one alias.
            TypeError: A class does not inherit from ``base_cls``, or the
                registry has been frozen.
        """

        def decorator(obj: T) -> T:  # type: ignore[valid-type]
            names = aliases or (getattr(obj, "__name__", str(obj)),)
            with self._lock:
                # Validate every alias before storing any, so a collision on a
                # later alias does not leave earlier ones half-registered.
                for alias in names:
                    self._check(alias, obj)
                for alias in names:
                    self._store(alias, obj, metadata)
            return obj

        # Direct call with *lazy* placeholder
        if lazy is not None:
            if len(aliases) != 1:
                raise ValueError("Exactly one alias required when using 'lazy='")
            with self._lock:
                self._check(aliases[0], lazy)
                self._store(aliases[0], lazy, metadata)
            # return no-op decorator for accidental use
            return lambda x: x  # type: ignore[return-value]

        return decorator

    def _check(self, alias: str, target: Any) -> None:
        """Raise if *target* may not be stored under *alias* (no side effects)."""
        if alias in self._objs:
            current = self._objs[alias]
            if current != target and not self._is_upgrade(current, target):
                raise ValueError(
                    f"{self._name!r} alias '{alias}' already registered ("  # noqa: B950
                    f"existing={current}, new={target})"
                )
        if self._base_cls is not None and isinstance(target, type):
            if not issubclass(target, self._base_cls):  # type: ignore[arg-type]
                raise TypeError(
                    f"{target} must inherit from {self._base_cls} to be a {self._name}"
                )
        if self._frozen:
            raise TypeError(
                f"{self._name!r} registry is frozen; cannot register '{alias}'"
            )

    def _store(
        self, alias: str, target: Any, metadata: dict[str, Any] | None
    ) -> None:
        """Bind *alias* to *target*; record *metadata* only when non-empty."""
        self._objs[alias] = target
        if metadata:
            self._meta[alias] = metadata

    @staticmethod
    def _is_upgrade(current: Any, target: Any) -> bool:
        """True if *current* is a placeholder naming exactly the object *target*.

        This lets a lazy registration later be replaced by the real object it
        points to (class or function) without a collision error.
        """
        if isinstance(current, str):
            path = current
        elif isinstance(current, md.EntryPoint):
            path = current.value
        else:
            return False
        if isinstance(target, (str, md.EntryPoint)):
            return False
        module = getattr(target, "__module__", None)
        if module is None:
            return False
        names = {getattr(target, attr, None) for attr in ("__qualname__", "__name__")}
        return any(path == f"{module}:{n}" for n in names if n)

    # ------------------------------------------------------------------
    # Lookup & materialisation
    # ------------------------------------------------------------------

    @staticmethod
    def _is_placeholder(obj: Any) -> bool:
        """True for lazy tokens: import-path strings and entry points."""
        return isinstance(obj, (str, md.EntryPoint))

    @staticmethod
    @lru_cache(maxsize=256)
    def _materialise(ph: Placeholder) -> Any:
        """Import and return the object a placeholder points to.

        Strings use ``"module:attr"`` syntax; ``attr`` may be dotted
        (``"pkg.mod:Outer.Inner"``), the same way ``EntryPoint.load()`` walks
        attributes.  The cache is keyed on the placeholder and shared by all
        registries (a staticmethod, so it does not pin ``self``).

        Raises:
            ValueError: The string is not of the form ``"module:attr"``.
        """
        if isinstance(ph, str):
            mod, _, attr = ph.partition(":")
            if not mod or not attr:
                raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")
            obj: Any = importlib.import_module(mod)
            for part in attr.split("."):
                obj = getattr(obj, part)
            return obj
        return ph.load()

    def _check_base(self, obj: Any, alias: str) -> None:
        """Raise TypeError unless *obj* is a subclass of ``base_cls`` (if set)."""
        if self._base_cls is None:
            return
        # isinstance guard: issubclass() itself raises on non-class objects.
        if not (isinstance(obj, type) and issubclass(obj, self._base_cls)):
            raise TypeError(
                f"{obj} does not inherit from {self._base_cls} (alias '{alias}')"
            )

    def get(self, alias: str) -> T:
        """Return the object registered under *alias*, resolving placeholders.

        Raises:
            KeyError: *alias* is not registered (the message lists known aliases).
            ValueError: A string placeholder is not ``"module:attr"``.
            TypeError: ``base_cls`` is set and the object is not a subclass.
            ImportError / AttributeError: The placeholder cannot be imported.
        """
        try:
            target = self._objs[alias]
        except KeyError as exc:
            raise KeyError(
                f"Unknown {self._name} '{alias}'. Available: {', '.join(self._objs)}"
            ) from exc

        if self._is_placeholder(target):
            with self._lock:
                # Re-check under the lock: another thread may have resolved it.
                current = self._objs[alias]
                if not self._is_placeholder(current):
                    target = current
                else:
                    resolved = self._materialise(current)
                    # Validate before caching so a bad object is never stored.
                    self._check_base(resolved, alias)
                    # A frozen registry cannot be written; the lru_cache on
                    # _materialise keeps repeat lookups cheap instead.
                    if not self._frozen:
                        self._objs[alias] = resolved
                    return resolved

        self._check_base(target, alias)
        return target  # type: ignore[return-value]

    # ------------------------------------------------------------------
    # Mapping helpers
    # ------------------------------------------------------------------

    def __getitem__(self, alias: str) -> T:  # noqa: DunderImplemented
        return self.get(alias)

    def __contains__(self, alias: object) -> bool:
        # Direct lookup instead of the default linear scan via __iter__.
        return alias in self._objs

    def __iter__(self):  # noqa: DunderImplemented
        return iter(self._objs)

    def __len__(self):  # noqa: DunderImplemented
        return len(self._objs)

    def items(self):  # noqa: DunderImplemented
        """Raw ``(alias, entry)`` pairs; entries may be unresolved placeholders."""
        return self._objs.items()

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    def metadata(self, alias: str) -> Union[Mapping[str, Any], None]:
        """Metadata recorded for *alias*, or ``None`` if none was given."""
        return self._meta.get(alias)

    def origin(self, alias: str) -> Union[str, None]:
        """Return ``"path:line"`` where *alias*'s object is defined, or ``None``.

        ``None`` for unknown aliases, unresolved placeholders, and objects
        ``inspect`` cannot locate (best-effort only).
        """
        obj = self._objs.get(alias)
        if obj is None or self._is_placeholder(obj):
            return None
        try:
            path = inspect.getfile(obj)  # type: ignore[arg-type]
            line = inspect.getsourcelines(obj)[1]  # type: ignore[arg-type]
        except Exception:  # pragma: no cover – best-effort only
            return None
        return f"{path}:{line}"

    def freeze(self) -> None:
        """Make the registry read-only; later ``register`` calls raise TypeError.

        Lazy entries can still be resolved by :meth:`get`; the resolved
        object is then returned without being written back.
        """
        with self._lock:
            self._objs = MappingProxyType(dict(self._objs))  # type: ignore[assignment]
            self._meta = MappingProxyType(dict(self._meta))  # type: ignore[assignment]
            self._frozen = True

    # Test helper -------------------------------------------------------------

    def _clear(self):  # pragma: no cover
        """Erase (and un-freeze) the registry, for isolated tests.

        Also clears the placeholder cache, which is shared by all registries.
        """
        with self._lock:
            # Rebind rather than .clear(): a frozen registry holds read-only proxies.
            self._objs = {}
            self._meta = {}
            self._frozen = False
        self._materialise.cache_clear()


# ────────────────────────────────────────────────────────────────────────
# Structured object for metrics
# ────────────────────────────────────────────────────────────────────────


@dataclass(frozen=True)
class MetricSpec:
    """Immutable record describing one registered metric.

    Instances compare by value, which the registry relies on to treat an
    identical re-registration as a no-op rather than a collision.
    """

    # Function that computes the metric (annotated as taking two arguments).
    compute: Callable[[Any, Any], Any]
    # Reduces an iterable of computed values to a single score.
    aggregate: Callable[[Iterable[Any]], float]
    # Whether larger aggregated values are better.
    higher_is_better: bool = True
    # Optional free-form tag; NOTE(review): semantics are defined by callers.
    output_type: Union[str, None] = None
    # Optional list of names; NOTE(review): semantics are defined by callers.
    requires: Union[list[str], None] = None


# ────────────────────────────────────────────────────────────────────────
# Canonical registries
# ────────────────────────────────────────────────────────────────────────

from lm_eval.api.model import LM  # noqa: E402


# Models must be subclasses of LM; the registry enforces this via base_cls.
model_registry: Registry[type[LM]] = Registry("model", base_cls=LM)
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
    "metric aggregation"
)
higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag")
filter_registry: Registry[Callable] = Registry("filter")

# Public helper aliases ------------------------------------------------------
# Bound methods of the registries above, exposed as module-level functions.

register_model = model_registry.register
get_model = model_registry.get

register_task = task_registry.register
get_task = task_registry.get

register_filter = filter_registry.register
get_filter = filter_registry.get

# Metric helpers need thin wrappers to build MetricSpec ----------------------
Baber's avatar
Baber committed
267

Baber's avatar
Baber committed
268

269
270
def register_metric(**kw):
    name = kw["metric"]
Baber's avatar
Baber committed
271

272
    def deco(fn):
Baber's avatar
Baber committed
273
274
        spec = MetricSpec(
            compute=fn,
275
276
277
278
279
280
281
282
            aggregate=(
                metric_agg_registry.get(kw["aggregation"])
                if "aggregation" in kw
                else lambda _: {}
            ),
            higher_is_better=kw.get("higher_is_better", True),
            output_type=kw.get("output_type"),
            requires=kw.get("requires"),
Baber's avatar
Baber committed
283
        )
284
285
        metric_registry.register(name, lazy=spec, metadata=kw)
        higher_is_better_registry.register(name, lazy=spec.higher_is_better)
286
287
        return fn

288
    return deco
289
290


def get_metric(name, hf_evaluate_metric=False):
    """Return the compute function for metric *name*.

    Looks in ``metric_registry`` first.  Entries that carry a ``compute``
    attribute (MetricSpec from ``register_metric``) yield that function; any
    other registered object is returned as-is.  If *name* is not registered,
    falls back to ``evaluate.load(name).compute`` from the Hugging Face
    ``evaluate`` package, logging a warning first unless *hf_evaluate_metric*
    is true.

    Raises:
        KeyError: Not registered and not loadable via ``evaluate``.  The
            underlying error is chained as ``__cause__``.
    """
    try:
        spec = metric_registry.get(name)
    except KeyError:
        pass
    else:
        return getattr(spec, "compute", spec)

    if not hf_evaluate_metric:
        import logging

        logging.getLogger(__name__).warning(
            f"Metric '{name}' not in registry; trying HF evaluate…"
        )
    try:
        import evaluate as hf

        return hf.load(name).compute  # type: ignore[attr-defined]
    except Exception as exc:
        # Covers a missing `evaluate` install as well as unknown metric names.
        raise KeyError(f"Metric '{name}' not found anywhere") from exc

register_metric_aggregation = metric_agg_registry.register
get_metric_aggregation = metric_agg_registry.get

register_higher_is_better = higher_is_better_registry.register
is_higher_better = higher_is_better_registry.get

# Legacy compatibility: older spellings of the aggregation API and the
# registry objects themselves.
register_aggregation = register_metric_aggregation
get_aggregation = get_metric_aggregation
DEFAULT_METRIC_REGISTRY = metric_registry
AGGREGATION_REGISTRY = metric_agg_registry

# Convenience ----------------------------------------------------------------


def freeze_all() -> None:
    """Call ``freeze()`` on every canonical registry defined in this module."""
    for registry in (
        model_registry,
        task_registry,
        metric_registry,
        metric_agg_registry,
        higher_is_better_registry,
        filter_registry,
    ):
        registry.freeze()


# Backwards-compat aliases ---------------------------------------------------
# Old UPPER_CASE names bound to the *same* Registry objects as the lowercase
# names above: not copies, and only read-only once freeze_all() has run.
MODEL_REGISTRY = model_registry  # type: ignore
TASK_REGISTRY = task_registry  # type: ignore
METRIC_REGISTRY = metric_registry  # type: ignore
METRIC_AGGREGATION_REGISTRY = metric_agg_registry  # type: ignore
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry  # type: ignore
FILTER_REGISTRY = filter_registry  # type: ignore