registry.py 17.9 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""Registry system for lm_eval components.

This module provides a centralized registration system for models, tasks, metrics,
filters, and other components in the lm_eval framework. The registry supports:

- Lazy loading with placeholders to improve startup time
- Type checking and validation
- Thread-safe registration and lookup
- Plugin discovery via entry points
- Backwards compatibility with legacy registration patterns

## Usage Examples

### Registering a Model
```python
from lm_eval.api.registry import register_model
from lm_eval.api.model import LM

@register_model("my-model")
class MyModel(LM):
    def __init__(self, **kwargs):
        ...
```

### Registering a Metric
```python
from lm_eval.api.registry import register_metric

@register_metric(
    metric="my_accuracy",
    aggregation="mean",
    higher_is_better=True
)
def my_accuracy_fn(items):
    ...
```

### Registering with Lazy Loading
```python
# Register without importing the actual implementation
model_registry.register("lazy-model", lazy="my_package.models:LazyModel")
```

### Looking up Components
```python
from lm_eval.api.registry import get_model, get_metric

# Get a model class
model_cls = get_model("gpt-j")
model = model_cls(**config)

# Get a metric function
metric_fn = get_metric("accuracy")
```
"""

Baber's avatar
Baber committed
57
58
59
60
61
from __future__ import annotations

import importlib
import inspect
import threading
Baber's avatar
Baber committed
62
from collections.abc import Iterable
Baber's avatar
Baber committed
63
64
65
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
Baber's avatar
Baber committed
66
from typing import Any, Callable, Generic, TypeVar, Union, cast
Baber's avatar
Baber committed
67

Baber's avatar
Baber committed
68
69
from lm_eval.api.filter import Filter

Baber's avatar
Baber committed
70

71
72
73
try:
    import importlib.metadata as md  # Python ≥3.10
except ImportError:  # pragma: no cover – fallback for 3.8/3.9
Baber's avatar
Baber committed
74
75
    import importlib_metadata as md  # type: ignore

Baber's avatar
Baber committed
76
# Names from the pre-Registry API that this module still exports; spliced into
# ``__all__`` so existing ``from ... import`` statements keep resolving.
LEGACY_EXPORTS = [
    "DEFAULT_METRIC_REGISTRY",
    "AGGREGATION_REGISTRY",
    "register_model",
    "get_model",
    "register_task",
    "get_task",
    "register_metric",
    "get_metric",
    "register_metric_aggregation",
    "get_metric_aggregation",
    "register_higher_is_better",
    "is_higher_better",
    "register_filter",
    "get_filter",
    "register_aggregation",
    "get_aggregation",
    "MODEL_REGISTRY",
    "TASK_REGISTRY",
    "METRIC_REGISTRY",
    "METRIC_AGGREGATION_REGISTRY",
    "HIGHER_IS_BETTER_REGISTRY",
    "FILTER_REGISTRY",
]

Baber's avatar
Baber committed
101
102
103
104
105
106
107
108
109
110
111
112
# Public API of this module: the canonical names first, then every legacy
# name listed in LEGACY_EXPORTS.
__all__ = [
    # canonical
    "Registry",
    "MetricSpec",
    "model_registry",
    "task_registry",
    "metric_registry",
    "metric_agg_registry",
    "higher_is_better_registry",
    "filter_registry",
    "freeze_all",
    *LEGACY_EXPORTS,
]  # type: ignore
Baber's avatar
Baber committed
114

Baber's avatar
Baber committed
115
# Type of the objects a given Registry instance holds.
T = TypeVar("T")
# A registration that has not been imported yet: either a "module:object"
# string or an entry point that can be loaded on demand.
Placeholder = Union[str, md.EntryPoint]
Baber's avatar
Baber committed
117
118
119
120
121
122
123


@lru_cache(maxsize=16)
def _materialise_placeholder(ph: Placeholder) -> Any:
    """Materialize a lazy placeholder into the actual object.

    This is at module level to avoid memory leaks from lru_cache on instance methods.

    Args:
        ph: Either a string path "module:object" or an EntryPoint instance.
            The object part of a string path may be dotted
            ("module:Class.attribute"); each component is resolved with
            ``getattr`` in turn.

    Returns:
        The loaded object

    Raises:
        ValueError: If the string format is invalid (missing module or object part)
        ImportError: If the module cannot be imported
        AttributeError: If the object doesn't exist in the module
    """
    if not isinstance(ph, str):
        # Entry points know how to import themselves.
        return ph.load()

    mod, _, attr = ph.partition(":")
    if not mod or not attr:
        raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")

    obj = importlib.import_module(mod)
    # Walk dotted attribute paths one component at a time so nested objects
    # can be referenced from a string path as well.
    for part in attr.split("."):
        obj = getattr(obj, part)
    return obj


Baber's avatar
Baber committed
144
# Metric-specific metadata storage --------------------------------------------
Baber's avatar
Baber committed
145
146
147
148

# Maps a metric name to the raw keyword arguments that were passed to
# register_metric() for it.
_metric_meta: dict[str, dict[str, Any]] = {}


Baber's avatar
Baber committed
149
class Registry(Generic[T]):
    """A thread-safe registry for named objects with lazy loading support.

    The Registry provides a central location for registering and retrieving
    components by name. It supports:

    - Direct registration of objects
    - Lazy registration with placeholders (strings or entry points)
    - Type checking against a base class
    - Thread-safe operations
    - Freezing to prevent further modifications

    Example:
        >>> from lm_eval.api.model import LM
        >>> registry = Registry("models", base_cls=LM)
        >>>
        >>> # Direct registration
        >>> @registry.register("my-model")
        ... class MyModel(LM):
        ...     pass
        >>>
        >>> # Lazy registration
        >>> registry.register("lazy-model", lazy="mypackage:LazyModel")
        >>>
        >>> # Retrieval (triggers lazy loading if needed)
        >>> model_cls = registry.get("my-model")
        >>> model = model_cls()
    """

    def __init__(
        self,
        name: str,
        *,
        base_cls: type[T] | None = None,
    ) -> None:
        """Initialize a new registry.

        Args:
            name: Human-readable name for error messages (e.g., "model", "metric")
            base_cls: Optional base class that all registered objects must inherit from
        """
        self._name = name
        self._base_cls = base_cls
        # alias -> concrete object, or a placeholder that get() resolves later
        self._objs: dict[str, T | Placeholder] = {}
        # Re-entrant so a registration triggered while another one holds the
        # lock on the same thread does not deadlock.
        self._lock = threading.RLock()

    # Internal helpers -------------------------------------------------------------

    @staticmethod
    def _is_placeholder(obj: Any) -> bool:
        """Return True if *obj* is an unresolved lazy reference."""
        return isinstance(obj, (str, md.EntryPoint))

    def _store(self, alias: str, target: T | Placeholder) -> None:
        """Validate *target* and bind it to *alias*.

        The caller must hold ``self._lock``.

        Raises:
            ValueError: If alias is already registered with a different target
            TypeError: If a class target doesn't inherit from base_cls, or the
                registry has been frozen
        """
        current = self._objs.get(alias)
        # collision handling ------------------------------------------
        if current is not None and current != target:
            # allow placeholder → real object upgrade: the string placeholder
            # must name exactly the class that is now being registered
            upgrades_placeholder = (
                isinstance(current, str)
                and isinstance(target, type)
                and current == f"{target.__module__}:{target.__name__}"
            )
            if not upgrades_placeholder:
                raise ValueError(
                    f"{self._name!r} alias '{alias}' already registered ("
                    f"existing={current}, new={target})"
                )
        # type check for concrete classes; this also covers the upgrade path
        # above so it cannot bypass base_cls validation
        if (
            self._base_cls is not None
            and isinstance(target, type)
            and not issubclass(target, self._base_cls)
        ):
            raise TypeError(
                f"{target} must inherit from {self._base_cls} to be a {self._name}"
            )
        # A frozen registry is backed by a read-only mapping; fail with an
        # explicit message instead of the mapping's own item-assignment error.
        if isinstance(self._objs, MappingProxyType):
            raise TypeError(
                f"{self._name} registry is frozen; cannot register alias '{alias}'"
            )
        self._objs[alias] = target

    # Registration (decorator or direct call) --------------------------------------

    def register(
        self,
        *aliases: str,
        lazy: T | Placeholder | None = None,
    ) -> Callable[[T], T]:
        """Register an object under one or more aliases.

        Can be used as a decorator or called directly for lazy registration.

        Args:
            *aliases: Names to register the object under. If empty, uses object's __name__
            lazy: For direct calls only - the value to store immediately under
                the single given alias, typically a placeholder string
                "module:object" or an EntryPoint

        Returns:
            Decorator function (or no-op if lazy registration)

        Examples:
            >>> # As decorator
            >>> @model_registry.register("name1", "name2")
            ... class MyModel(LM):
            ...     pass
            >>>
            >>> # Direct lazy registration
            >>> model_registry.register("lazy-name", lazy="mymodule:MyModel")

        Raises:
            ValueError: If alias is already registered with a different target,
                or 'lazy=' is used without exactly one alias
            TypeError: If an object doesn't inherit from base_cls (when specified),
                or the registry is frozen
        """
        # Direct call with *lazy* value: store now, nothing left to decorate.
        if lazy is not None:
            if len(aliases) != 1:
                raise ValueError("Exactly one alias required when using 'lazy='")
            with self._lock:
                self._store(aliases[0], lazy)
            # return no-op decorator for accidental use
            return lambda x: x  # type: ignore[return-value]

        def decorator(obj: T) -> T:  # type: ignore[valid-type]
            # Fall back to the object's own name when no alias was given.
            names = aliases or (getattr(obj, "__name__", str(obj)),)
            with self._lock:
                for name in names:
                    self._store(name, obj)
            return obj

        return decorator

    # Lookup & materialisation --------------------------------------------------

    def _materialise(self, ph: Placeholder) -> T:
        """Materialize a placeholder using the module-level cached function.

        Args:
            ph: Placeholder to materialize

        Returns:
            The materialized object, cast to type T
        """
        return cast(T, _materialise_placeholder(ph))

    def get(self, alias: str) -> T:
        """Retrieve an object by alias, materializing if needed.

        Thread-safe lazy loading: if the alias points to a placeholder,
        it will be loaded and cached before returning.

        Args:
            alias: The registered name to look up

        Returns:
            The registered object

        Raises:
            KeyError: If alias not found
            TypeError: If the resolved object is not a subclass of base_cls
            ImportError/AttributeError: If lazy loading fails
        """
        try:
            target = self._objs[alias]
        except KeyError as exc:
            raise KeyError(
                f"Unknown {self._name} '{alias}'. Available: {', '.join(self._objs)}"
            ) from exc

        if self._is_placeholder(target):
            with self._lock:
                # Re-check under lock (another thread might have resolved it)
                fresh = self._objs[alias]
                if self._is_placeholder(fresh):
                    concrete = self._materialise(fresh)
                    # Only update if not frozen (MappingProxyType)
                    if not isinstance(self._objs, MappingProxyType):
                        self._objs[alias] = concrete
                else:
                    concrete = fresh  # another thread did the job
            target = concrete

        # Late type check. Guard with isinstance first: issubclass() itself
        # raises an unhelpful TypeError when handed something that is not a class.
        if self._base_cls is not None and not (
            isinstance(target, type) and issubclass(target, self._base_cls)
        ):
            raise TypeError(
                f"{target} does not inherit from {self._base_cls} (alias '{alias}')"
            )
        return target

    def __getitem__(self, alias: str) -> T:
        """Allow dict-style access: registry[alias]."""
        return self.get(alias)

    def __iter__(self):
        """Iterate over registered aliases."""
        return iter(self._objs)

    def __len__(self) -> int:
        """Return number of registered aliases."""
        return len(self._objs)

    def items(self):
        """Return (alias, object) pairs.

        Note: Objects may be placeholders that haven't been materialized yet.
        """
        return self._objs.items()

    # Utilities -------------------------------------------------------------

    def origin(self, alias: str) -> str | None:
        """Get the source location of a registered object.

        Args:
            alias: The registered name

        Returns:
            "path/to/file.py:line_number", or None if the alias is unknown,
            still a placeholder, or its source cannot be determined
        """
        obj = self._objs.get(alias)
        if self._is_placeholder(obj):
            return None
        try:
            path = inspect.getfile(obj)  # type: ignore[arg-type]
            line = inspect.getsourcelines(obj)[1]  # type: ignore[arg-type]
        except Exception:  # pragma: no cover – best-effort only
            return None
        return f"{path}:{line}"

    def freeze(self) -> None:
        """Make the registry read-only to prevent further modifications.

        After freezing, attempts to register new objects will fail.
        This is useful for ensuring registry contents don't change after
        initialization.
        """
        with self._lock:
            # Snapshot into a read-only view; calling freeze() again is harmless.
            self._objs = MappingProxyType(dict(self._objs))  # type: ignore[assignment]

    # Test helper --------------------------------
    def _clear(self) -> None:  # pragma: no cover
        """Erase registry (for isolated tests).

        Clears both the registry contents and the materialization cache.
        Only use this in test code to ensure clean state between tests.
        A frozen registry becomes writable again.
        """
        with self._lock:
            if isinstance(self._objs, MappingProxyType):
                # The read-only view has no clear(); start over with a new dict.
                self._objs = {}
            else:
                self._objs.clear()
        _materialise_placeholder.cache_clear()
389
390


Baber's avatar
Baber committed
391
# Structured object for metrics ------------------
392
393


Baber's avatar
Baber committed
394
395
@dataclass(frozen=True)
class MetricSpec:
Baber's avatar
Baber committed
396
397
398
399
400
401
402
403
404
405
    """Specification for a metric including computation and aggregation functions.

    Attributes:
        compute: Function to compute metric on individual items
        aggregate: Function to aggregate multiple metric values into a single score
        higher_is_better: Whether higher values indicate better performance
        output_type: Optional type hint for the output (e.g., "generate_until" for perplexity)
        requires: Optional list of other metrics this one depends on
    """

Baber's avatar
Baber committed
406
    compute: Callable[[Any, Any], Any]
407
    aggregate: Callable[[Iterable[Any]], float]
Baber's avatar
Baber committed
408
    higher_is_better: bool = True
Baber's avatar
Baber committed
409
410
    output_type: str | None = None
    requires: list[str] | None = None
411
412


Baber's avatar
Baber committed
413
# Canonical registry instances ---------------------
414

Baber's avatar
Baber committed
415
from lm_eval.api.model import LM  # noqa: E402
416
417


Baber's avatar
cleanup  
Baber committed
418
# One module-level Registry per component kind. Only the model registry is
# given a base class, so only model classes are subclass-checked.
model_registry = cast(Registry[type[LM]], Registry("model", base_cls=LM))
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
    "metric aggregation"
)
higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag")
filter_registry: Registry[type[Filter]] = Registry("filter")

# Public helper aliases ------------------------------------------------------
# Each helper is the bound method of the matching registry above.

register_model = model_registry.register
get_model = model_registry.get

register_task = task_registry.register
get_task = task_registry.get

register_filter = filter_registry.register
get_filter = filter_registry.get

# Metric helpers need thin wrappers to build MetricSpec ----------------------
Baber's avatar
Baber committed
439

Baber's avatar
Baber committed
440

Baber's avatar
Baber committed
441
def _no_aggregation_fn(values: Iterable[Any]) -> float:
    """Default aggregation that raises NotImplementedError.

    Used as the ``aggregate`` of a metric that was registered without an
    aggregation, so the omission surfaces when aggregation is attempted.

    Args:
        values: Metric values to aggregate (unused)

    Raises:
        NotImplementedError: Always - this is a placeholder for metrics
                           that haven't specified an aggregation function
    """
    raise NotImplementedError(
        "No aggregation function specified for this metric. "
        "Please specify 'aggregation' parameter in @register_metric."
    )


457
def register_metric(**kw):
    """Decorator for registering metric functions.

    Creates a MetricSpec from the decorated function and keyword arguments,
    then registers it in the metric registry. The keyword arguments are also
    recorded in ``_metric_meta`` and the higher-is-better flag is registered
    in ``higher_is_better_registry`` under the same name.

    Args:
        **kw: Keyword arguments including
            - metric: Name to register the metric under (required)
            - aggregation: Name of aggregation function in metric_agg_registry;
              when omitted or None, the metric gets a placeholder aggregation
              that raises NotImplementedError when called
            - higher_is_better: Whether higher scores are better (default: True)
            - output_type: Optional output type hint
            - requires: Optional list of required metrics

    Returns:
        Decorator function that registers the metric and returns the
        decorated function unchanged

    Raises:
        KeyError: If 'metric' is missing, or (when the decorator is applied)
            the named aggregation is not registered

    Example:
        >>> @register_metric(
        ...     metric="my_accuracy",
        ...     aggregation="mean",
        ...     higher_is_better=True
        ... )
        ... def compute_accuracy(items):
        ...     return sum(item["correct"] for item in items) / len(items)
    """
    name = kw["metric"]

    def deco(fn):
        # An explicit aggregation=None means "not specified", exactly like
        # leaving the keyword out, rather than a lookup of the name None.
        agg_name = kw.get("aggregation")
        if agg_name is None:
            aggregate = _no_aggregation_fn
        else:
            aggregate = metric_agg_registry.get(agg_name)

        spec = MetricSpec(
            compute=fn,
            aggregate=aggregate,
            higher_is_better=kw.get("higher_is_better", True),
            output_type=kw.get("output_type"),
            requires=kw.get("requires"),
        )
        # `lazy=` is used here to store ready-made values immediately; the
        # spec and the flag are not placeholders.
        metric_registry.register(name, lazy=spec)
        _metric_meta[name] = kw
        # NOTE(review): an explicit higher_is_better=None reaches register() as
        # lazy=None, which registers nothing — confirm whether callers rely on it.
        higher_is_better_registry.register(name, lazy=spec.higher_is_better)
        return fn

    return deco
503
504


505
def get_metric(name, hf_evaluate_metric=False):
    """Get a metric compute function by name.

    First checks the local metric registry, then falls back to the
    HuggingFace ``evaluate`` library.

    Args:
        name: Metric name to retrieve
        hf_evaluate_metric: If True, suppress warning when falling back to HF

    Returns:
        The metric's compute function

    Raises:
        KeyError: If a metric is not found in registry or HF evaluate; the
            underlying HF failure is attached as the exception's cause
    """
    # Keep the try body to the lookup itself so only a missing registry entry
    # triggers the fallback.
    try:
        spec = metric_registry.get(name)
    except KeyError:
        pass
    else:
        return spec.compute  # type: ignore[attr-defined]

    if not hf_evaluate_metric:
        import logging

        logging.getLogger(__name__).warning(
            "Metric '%s' not in registry; trying HF evaluate…", name
        )

    try:
        # Imported lazily: `evaluate` is only needed for this fallback.
        import evaluate as hf

        return hf.load(name).compute  # type: ignore[attr-defined]
    except Exception as exc:
        # Any failure here (package missing, unknown metric, load error) means
        # the metric is unavailable; chain the cause so it stays diagnosable.
        raise KeyError(f"Metric '{name}' not found anywhere") from exc
537

haileyschoelkopf's avatar
haileyschoelkopf committed
538

539
540
# Helpers for the aggregation and higher-is-better registries; each is the
# bound method of the corresponding registry.
register_metric_aggregation = metric_agg_registry.register
get_metric_aggregation = metric_agg_registry.get

register_higher_is_better = higher_is_better_registry.register
is_higher_better = higher_is_better_registry.get

# Legacy compatibility: older names for the same registries and helpers
register_aggregation = metric_agg_registry.register
get_aggregation = metric_agg_registry.get
DEFAULT_METRIC_REGISTRY = metric_registry
AGGREGATION_REGISTRY = metric_agg_registry
Baber's avatar
Baber committed
550
551


552
def freeze_all() -> None:
    """Freeze all registries to prevent further modifications.

    Calls ``freeze()`` on each of the six canonical registries defined in
    this module (model, task, metric, metric aggregation, higher-is-better
    and filter). This is useful for ensuring registry contents are immutable
    after initialization, preventing accidental modifications during runtime.
    """
    for registry in (
        model_registry,
        task_registry,
        metric_registry,
        metric_agg_registry,
        higher_is_better_registry,
        filter_registry,
    ):
        registry.freeze()
Baber's avatar
Baber committed
567
568


Baber's avatar
Baber committed
569
# Backwards‑compat aliases ----------------------------------------
Baber's avatar
Baber committed
570

Baber's avatar
Baber committed
571
572
573
574
575
576
# Upper-case spellings of the registries. Each name is bound to the very same
# Registry instance as its lower-case counterpart, not to a copy.
FILTER_REGISTRY = filter_registry
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry
METRIC_AGGREGATION_REGISTRY = metric_agg_registry
METRIC_REGISTRY = metric_registry
TASK_REGISTRY = task_registry
MODEL_REGISTRY = model_registry