registry.py 17.8 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""Registry system for lm_eval components.

This module provides a centralized registration system for models, tasks, metrics,
filters, and other components in the lm_eval framework. The registry supports:

- Lazy loading with placeholders to improve startup time
- Type checking and validation
- Thread-safe registration and lookup
- Plugin discovery via entry points
- Backwards compatibility with legacy registration patterns

## Usage Examples

### Registering a Model
```python
from lm_eval.api.registry import register_model
from lm_eval.api.model import LM

@register_model("my-model")
class MyModel(LM):
    def __init__(self, **kwargs):
        ...
```

### Registering a Metric
```python
from lm_eval.api.registry import register_metric

@register_metric(
    metric="my_accuracy",
    aggregation="mean",
    higher_is_better=True
)
def my_accuracy_fn(items):
    ...
```

### Registering with Lazy Loading
```python
# Register without importing the actual implementation
model_registry.register("lazy-model", lazy="my_package.models:LazyModel")
```

### Looking up Components
```python
from lm_eval.api.registry import get_model, get_metric

# Get a model class
model_cls = get_model("gpt-j")
model = model_cls(**config)

# Get a metric function
metric_fn = get_metric("accuracy")
```
"""

Baber's avatar
Baber committed
57
58
59
60
61
from __future__ import annotations

import importlib
import inspect
import threading
Baber's avatar
Baber committed
62
from collections.abc import Iterable
Baber's avatar
Baber committed
63
64
65
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
Baber's avatar
Baber committed
66
from typing import Any, Callable, Generic, TypeVar, Union, cast
Baber's avatar
Baber committed
67

Baber's avatar
Baber committed
68
69
from lm_eval.api.filter import Filter

Baber's avatar
Baber committed
70

71
72
73
try:
    import importlib.metadata as md  # Python ≥3.10
except ImportError:  # pragma: no cover – fallback for 3.8/3.9
Baber's avatar
Baber committed
74
75
    import importlib_metadata as md  # type: ignore

Baber's avatar
Baber committed
76
# Names kept importable from this module for code written against the older
# registration API (function helpers plus the upper-case registry aliases).
LEGACY_EXPORTS = [
    "DEFAULT_METRIC_REGISTRY",
    "AGGREGATION_REGISTRY",
    "register_model",
    "get_model",
    "register_task",
    "get_task",
    "register_metric",
    "get_metric",
    "register_metric_aggregation",
    "get_metric_aggregation",
    "register_higher_is_better",
    "is_higher_better",
    "register_filter",
    "get_filter",
    "register_aggregation",
    "get_aggregation",
    "MODEL_REGISTRY",
    "TASK_REGISTRY",
    "METRIC_REGISTRY",
    "METRIC_AGGREGATION_REGISTRY",
    "HIGHER_IS_BETTER_REGISTRY",
    "FILTER_REGISTRY",
]

# Public API: the canonical names first, then every legacy name listed above.
__all__ = [
    # canonical
    "Registry",
    "MetricSpec",
    "model_registry",
    "task_registry",
    "metric_registry",
    "metric_agg_registry",
    "higher_is_better_registry",
    "filter_registry",
    "freeze_all",
    *LEGACY_EXPORTS,
]  # type: ignore
Baber's avatar
Baber committed
114

Baber's avatar
Baber committed
115
# Type of the objects a given ``Registry`` instance hands back.
T = TypeVar("T")

# A not-yet-imported registration: either a "module:object" string or an
# entry point that can load itself.
Placeholder = Union[str, md.EntryPoint]
Baber's avatar
Baber committed
117
118
119
120
121
122
123


@lru_cache(maxsize=16)
def _materialise_placeholder(ph: Placeholder) -> Any:
    """Materialize a lazy placeholder into the actual object.

    This is at module level to avoid memory leaks from lru_cache on instance methods.

    String placeholders use the ``"module:object"`` form. ``object`` may be a
    dotted attribute path (``"pkg.mod:Class.attr"``); it is resolved one
    attribute at a time, which is also how ``EntryPoint.load()`` treats the
    same syntax.

    Args:
        ph: Either a string path "module:object" or an EntryPoint instance

    Returns:
        The loaded object

    Raises:
        ValueError: If the string format is invalid (missing module or object)
        ImportError: If the module cannot be imported
        AttributeError: If the object doesn't exist in the module
    """
    if not isinstance(ph, str):
        # Entry points know how to import themselves.
        return ph.load()

    mod, _, attr = ph.partition(":")
    # Reject both "module" (no object part) and ":object" (no module part)
    # up front so the caller always gets the same, descriptive error.
    if not mod or not attr:
        raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")

    obj = importlib.import_module(mod)
    # Walk dotted paths such as "Outer.Inner" instead of looking up the whole
    # string as a single attribute name.
    for part in attr.split("."):
        obj = getattr(obj, part)
    return obj


Baber's avatar
Baber committed
144
# Metric-specific metadata storage --------------------------------------------
Baber's avatar
Baber committed
145
146
147
148

# Maps a metric name to the raw keyword arguments that were passed to
# ``register_metric`` when that metric was registered.
_metric_meta: dict[str, dict[str, Any]] = {}


Baber's avatar
Baber committed
149
class Registry(Generic[T]):
    """A thread-safe registry for named objects with lazy loading support.

    The Registry provides a central location for registering and retrieving
    components by name. It supports:

    - Direct registration of objects
    - Lazy registration with placeholders (strings or entry points)
    - Type checking against a base class
    - Registration and lazy materialisation serialised by a re-entrant lock
    - Freezing to prevent further modifications

    Note:
        Any stored value that is a ``str`` or an ``EntryPoint`` is treated as
        a placeholder and imported on first lookup.

    Example:
        >>> from lm_eval.api.model import LM
        >>> registry = Registry("models", base_cls=LM)
        >>>
        >>> # Direct registration
        >>> @registry.register("my-model")
        ... class MyModel(LM):
        ...     pass
        >>>
        >>> # Lazy registration
        >>> registry.register("lazy-model", lazy="mypackage:LazyModel")
        >>>
        >>> # Retrieval (triggers lazy loading if needed)
        >>> model_cls = registry.get("my-model")
        >>> model = model_cls()
    """

    def __init__(
        self,
        name: str,
        *,
        base_cls: type[T] | None = None,
    ) -> None:
        """Initialize a new registry.

        Args:
            name: Human-readable name for error messages (e.g., "model", "metric")
            base_cls: Optional base class that all registered objects must inherit from
        """
        self._name = name
        self._base_cls = base_cls
        # alias -> concrete object, or a placeholder that is resolved on lookup
        self._objs: dict[str, T | Placeholder] = {}
        self._lock = threading.RLock()

    # Registration (decorator or direct call) --------------------------------------

    def register(
        self,
        *aliases: str,
        lazy: T | Placeholder | None = None,
    ) -> Callable[[T], T]:
        """Register an object under one or more aliases.

        Can be used as a decorator or called directly with ``lazy=``.

        Args:
            *aliases: Names to register the object under. If empty, uses object's __name__
            lazy: For direct calls only - a placeholder string "module:object" or
                EntryPoint. A concrete (non-placeholder) object may also be passed;
                it is then stored as-is under the single alias.

        Returns:
            Decorator function (or no-op if ``lazy`` was given)

        Examples:
            >>> # As decorator
            >>> @model_registry.register("name1", "name2")
            ... class MyModel(LM):
            ...     pass
            >>>
            >>> # Direct lazy registration
            >>> model_registry.register("lazy-name", lazy="mymodule:MyModel")

        Raises:
            ValueError: If alias already registered with different target, or if
                ``lazy`` is used with anything other than exactly one alias
            TypeError: If object doesn't inherit from base_cls (when specified),
                or if the registry has been frozen (the underlying mapping is
                then read-only)
        """

        def _store(alias: str, target: T | Placeholder) -> None:
            current = self._objs.get(alias)
            # collision handling ------------------------------------------
            if current is not None and current != target:
                # allow placeholder → real object upgrade, but only when the
                # placeholder string names exactly this class
                if isinstance(current, str) and isinstance(target, type):
                    if current == f"{target.__module__}:{target.__name__}":
                        self._objs[alias] = target
                        return
                raise ValueError(
                    f"{self._name!r} alias '{alias}' already registered ("
                    f"existing={current}, new={target})"
                )
            # type check for concrete classes ------------------------------
            # (placeholders are checked later, once materialised in ``get``)
            if self._base_cls is not None and isinstance(target, type):
                if not issubclass(target, self._base_cls):  # type: ignore[arg-type]
                    raise TypeError(
                        f"{target} must inherit from {self._base_cls} to be a {self._name}"
                    )
            self._objs[alias] = target

        def decorator(obj: T) -> T:  # type: ignore[valid-type]
            names = aliases or (getattr(obj, "__name__", str(obj)),)
            with self._lock:
                for name in names:
                    _store(name, obj)
            return obj

        # Direct call with *lazy* placeholder
        if lazy is not None:
            if len(aliases) != 1:
                raise ValueError("Exactly one alias required when using 'lazy='")
            with self._lock:
                _store(aliases[0], lazy)  # type: ignore[arg-type]
            # return no‑op decorator for accidental use
            return lambda x: x  # type: ignore[return-value]

        return decorator

    # Lookup & materialisation --------------------------------------------------

    def _materialise(self, ph: Placeholder) -> T:
        """Materialize a placeholder using the module-level cached function.

        Args:
            ph: Placeholder to materialize

        Returns:
            The materialized object, cast to type T
        """
        return cast(T, _materialise_placeholder(ph))

    def get(self, alias: str) -> T:
        """Retrieve an object by alias, materializing if needed.

        If the alias points to a placeholder, it is loaded under the registry
        lock and, unless the registry is frozen, written back so later lookups
        return the concrete object directly.

        Args:
            alias: The registered name to look up

        Returns:
            The registered object

        Raises:
            KeyError: If alias not found
            TypeError: If ``base_cls`` is set and the resolved object is not a
                subclass of it (including when it is not a class at all)
            ImportError/AttributeError: If lazy loading fails
        """
        try:
            target = self._objs[alias]
        except KeyError as exc:
            raise KeyError(
                f"Unknown {self._name} '{alias}'. Available: {', '.join(self._objs)}"
            ) from exc

        if isinstance(target, (str, md.EntryPoint)):
            with self._lock:
                # Re‑check under lock (another thread might have resolved it)
                fresh = self._objs[alias]
                if isinstance(fresh, (str, md.EntryPoint)):
                    concrete = self._materialise(fresh)
                    # Only update if not frozen (MappingProxyType)
                    if not isinstance(self._objs, MappingProxyType):
                        self._objs[alias] = concrete
                else:
                    concrete = fresh  # another thread did the job
            target = concrete

        # Late type/validator checks. ``issubclass`` itself raises an opaque
        # TypeError for non-class arguments, so test ``isinstance(..., type)``
        # first and report the problem with the registry's own message.
        if self._base_cls is not None and not (
            isinstance(target, type) and issubclass(target, self._base_cls)  # type: ignore[arg-type]
        ):
            raise TypeError(
                f"{target} does not inherit from {self._base_cls} (alias '{alias}')"
            )
        return target

    def __getitem__(self, alias: str) -> T:
        """Allow dict-style access: registry[alias]."""
        return self.get(alias)

    def __iter__(self):
        """Iterate over registered aliases."""
        return iter(self._objs)

    def __len__(self) -> int:
        """Return number of registered aliases."""
        return len(self._objs)

    def items(self):
        """Return (alias, object) pairs.

        Note: Objects may be placeholders that haven't been materialized yet.
        """
        return self._objs.items()

    # Utilities -------------------------------------------------------------

    def origin(self, alias: str) -> str | None:
        """Get the source location of a registered object.

        Args:
            alias: The registered name

        Returns:
            "path/to/file.py:line_number", or None if the alias is still a
            placeholder or the location cannot be determined
        """
        obj = self._objs.get(alias)
        if isinstance(obj, (str, md.EntryPoint)):
            return None
        try:
            path = inspect.getfile(obj)  # type: ignore[arg-type]
            line = inspect.getsourcelines(obj)[1]  # type: ignore[arg-type]
            return f"{path}:{line}"
        except Exception:  # pragma: no cover – best‑effort only
            return None

    def freeze(self) -> None:
        """Make the registry read-only to prevent further modifications.

        After freezing, attempts to register new objects will fail.
        This is useful for ensuring registry contents don't change after
        initialization.
        """
        with self._lock:
            # Snapshot the current contents behind a read-only view.
            self._objs = MappingProxyType(dict(self._objs))  # type: ignore[assignment]

    # Test helper --------------------------------
    def _clear(self) -> None:  # pragma: no cover
        """Erase registry (for isolated tests).

        Clears both the registry contents and the materialization cache.
        Also works on a frozen registry, which is left empty and writable
        again. Only use this in test code to ensure clean state between tests.
        """
        with self._lock:
            if isinstance(self._objs, MappingProxyType):
                # A frozen registry holds a read-only proxy with no ``clear``
                # method; swap in a fresh dict instead.
                self._objs = {}
            else:
                self._objs.clear()
        _materialise_placeholder.cache_clear()
383
384


Baber's avatar
Baber committed
385
# Structured object for metrics ------------------
386
387


Baber's avatar
Baber committed
388
389
@dataclass(frozen=True)
class MetricSpec:
    """Immutable description of a metric: how to compute and aggregate it.

    Attributes:
        compute: Callable that produces the metric value(s)
        aggregate: Callable that folds an iterable of metric values into one float
        higher_is_better: True when a larger score means better performance
        output_type: Optional output-type tag associated with the metric
        requires: Optional list of names of other metrics this one depends on
    """

    # Required callables.
    compute: Callable[[Any, Any], Any]
    aggregate: Callable[[Iterable[Any]], float]
    # Optional metadata; the defaults describe a plain "bigger is better" metric.
    higher_is_better: bool = True
    output_type: str | None = None
    requires: list[str] | None = None
405
406


Baber's avatar
Baber committed
407
# Canonical registries aliases ---------------------
408

Baber's avatar
Baber committed
409
from lm_eval.api.model import LM  # noqa: E402
410
411


Baber's avatar
Baber committed
412
413
414
# One module-level Registry per component kind. Only the model registry is
# given a base class to validate against; the others accept any object.
model_registry: Registry[type[LM]] = cast(
    Registry[type[LM]], Registry("model", base_cls=LM)
)
task_registry: Registry[Callable[..., Any]] = Registry("task")

# Metric-related registries: the spec itself, named aggregation functions,
# and the per-metric "higher is better" flag.
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
    "metric aggregation"
)
higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag")

filter_registry: Registry[type[Filter]] = Registry("filter")
422

423
# Public helper aliases ------------------------------------------------------
Baber's avatar
Baber committed
424

425
# Function-style helpers: each is the bound ``register`` / ``get`` method of
# the corresponding module-level registry.

# Models
register_model = model_registry.register
get_model = model_registry.get

# Tasks
register_task = task_registry.register
get_task = task_registry.get

# Filters
register_filter = filter_registry.register
get_filter = filter_registry.get

# Metric helpers need thin wrappers to build MetricSpec ----------------------
Baber's avatar
Baber committed
435

Baber's avatar
Baber committed
436

Baber's avatar
Baber committed
437
def _no_aggregation_fn(values: Iterable[Any]) -> float:
    """Stand-in aggregation used when a metric declares none.

    Args:
        values: Metric values to aggregate (ignored)

    Raises:
        NotImplementedError: Always; calling this means the metric was
            registered without an ``aggregation`` argument
    """
    message = (
        "No aggregation function specified for this metric. "
        "Please specify 'aggregation' parameter in @register_metric."
    )
    raise NotImplementedError(message)


453
def register_metric(**kw):
    """Decorator factory that registers a metric function.

    Wraps the decorated function in a ``MetricSpec`` and records it in
    ``metric_registry``; the raw keyword arguments are kept in
    ``_metric_meta`` and the higher-is-better flag is recorded in
    ``higher_is_better_registry``, all under the same metric name.

    Args:
        **kw: Keyword arguments including:
            - metric: Name to register the metric under (required)
            - aggregation: Name of an aggregation function already present in
              ``metric_agg_registry``; when omitted, a stand-in that raises
              ``NotImplementedError`` is used
            - higher_is_better: Whether higher scores are better (default: True)
            - output_type: Optional output type hint
            - requires: Optional list of required metrics

    Returns:
        Decorator that registers the function and returns it unchanged

    Raises:
        KeyError: If ``metric`` is missing, or (when the decorator is applied)
            if the named aggregation is not registered

    Example:
        >>> @register_metric(
        ...     metric="my_accuracy",
        ...     aggregation="mean",
        ...     higher_is_better=True
        ... )
        ... def compute_accuracy(items):
        ...     return sum(item["correct"] for item in items) / len(items)
    """
    metric_name = kw["metric"]

    def _register(fn):
        # Resolve the aggregation by name only if the caller supplied one.
        if "aggregation" in kw:
            aggregate_fn = metric_agg_registry.get(kw["aggregation"])
        else:
            aggregate_fn = _no_aggregation_fn

        spec = MetricSpec(
            compute=fn,
            aggregate=aggregate_fn,
            higher_is_better=kw.get("higher_is_better", True),
            output_type=kw.get("output_type"),
            requires=kw.get("requires"),
        )

        # ``lazy=`` is the direct-call form of ``register``; the concrete
        # objects passed here are stored as-is.
        metric_registry.register(metric_name, lazy=spec)
        _metric_meta[metric_name] = kw
        higher_is_better_registry.register(metric_name, lazy=spec.higher_is_better)
        return fn

    return _register
499
500


501
def get_metric(name, hf_evaluate_metric=False):
    """Get a metric compute function by name.

    First checks the local metric registry, then falls back to the
    HuggingFace ``evaluate`` library, which is imported only when needed.

    Args:
        name: Metric name to retrieve
        hf_evaluate_metric: If True, suppress warning when falling back to HF

    Returns:
        The metric's compute function

    Raises:
        KeyError: If metric not found in registry or HF evaluate. The
            underlying HF failure is attached as the exception's cause.
    """
    try:
        spec = metric_registry.get(name)
        return spec.compute  # type: ignore[attr-defined]
    except KeyError:
        if not hf_evaluate_metric:
            import logging

            logging.getLogger(__name__).warning(
                f"Metric '{name}' not in registry; trying HF evaluate…"
            )
        try:
            import evaluate as hf

            return hf.load(name).compute  # type: ignore[attr-defined]
        except Exception as exc:
            # Any failure of the fallback (library missing, unknown metric,
            # load error) is reported uniformly as KeyError; chain explicitly
            # so the real reason stays visible in the traceback.
            raise KeyError(f"Metric '{name}' not found anywhere") from exc
533

haileyschoelkopf's avatar
haileyschoelkopf committed
534

535
536
# Aggregation functions
register_metric_aggregation = metric_agg_registry.register
get_metric_aggregation = metric_agg_registry.get

# Higher-is-better flags
register_higher_is_better = higher_is_better_registry.register
is_higher_better = higher_is_better_registry.get

# Legacy compatibility: older names for the same bound methods and registries
register_aggregation = metric_agg_registry.register
get_aggregation = metric_agg_registry.get
DEFAULT_METRIC_REGISTRY = metric_registry
AGGREGATION_REGISTRY = metric_agg_registry
Baber's avatar
Baber committed
546
547


548
def freeze_all():
    """Freeze every module-level registry.

    Calls ``freeze()`` on the model, task, metric, metric-aggregation,
    higher-is-better and filter registries, in that order, so their
    contents can no longer be modified.
    """
    registries = (
        model_registry,
        task_registry,
        metric_registry,
        metric_agg_registry,
        higher_is_better_registry,
        filter_registry,
    )
    for registry in registries:
        registry.freeze()
Baber's avatar
Baber committed
563
564


Baber's avatar
Baber committed
565
# Backwards‑compat aliases ----------------------------------------
Baber's avatar
Baber committed
566

Baber's avatar
Baber committed
567
568
569
570
571
572
# Upper-case names bound to the same Registry objects as the lower-case
# canonical names (these are aliases, not copies).
MODEL_REGISTRY = model_registry
TASK_REGISTRY = task_registry
METRIC_REGISTRY = metric_registry
METRIC_AGGREGATION_REGISTRY = metric_agg_registry
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry
FILTER_REGISTRY = filter_registry