Unverified Commit 70314843 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

Merge pull request #3189 from EleutherAI/lazy_reg

refactor registry
parents 73202a2e 930b4253
from .api import metrics, model, registry # initializes the registries
from .filters import *
__version__ = "0.4.9.1"
......
"""Registry system for lm_eval components.
This module provides a centralized registration system for models, tasks, metrics,
filters, and other components in the lm_eval framework. The registry supports:
- Lazy loading with placeholders to improve startup time
- Type checking and validation
- Thread-safe registration and lookup
- Plugin discovery via entry points
- Backwards compatibility with legacy registration patterns
## Usage Examples
### Registering a Model
```python
from lm_eval.api.registry import register_model
from lm_eval.api.model import LM
@register_model("my-model")
class MyModel(LM):
def __init__(self, **kwargs):
...
```
### Registering a Metric
```python
from lm_eval.api.registry import register_metric
@register_metric(
metric="my_accuracy",
aggregation="mean",
higher_is_better=True
)
def my_accuracy_fn(items):
...
```
### Registering with Lazy Loading
```python
# Register without importing the actual implementation
model_registry.register("lazy-model", lazy="my_package.models:LazyModel")
```
### Looking up Components
```python
from lm_eval.api.registry import get_model, get_metric
# Get a model class
model_cls = get_model("gpt-j")
model = model_cls(**config)
# Get a metric function
metric_fn = get_metric("accuracy")
```
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable
import importlib
import inspect
import threading
from collections.abc import Iterable
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from typing import Any, Callable, Generic, TypeVar, Union, cast
from lm_eval.api.filter import Filter
try:
import importlib.metadata as md # Python ≥3.10
except ImportError: # pragma: no cover – fallback for 3.8/3.9
import importlib_metadata as md # type: ignore
LEGACY_EXPORTS = [
"DEFAULT_METRIC_REGISTRY",
"AGGREGATION_REGISTRY",
"register_model",
"get_model",
"register_task",
"get_task",
"register_metric",
"get_metric",
"register_metric_aggregation",
"get_metric_aggregation",
"register_higher_is_better",
"is_higher_better",
"register_filter",
"get_filter",
"register_aggregation",
"get_aggregation",
"MODEL_REGISTRY",
"TASK_REGISTRY",
"METRIC_REGISTRY",
"METRIC_AGGREGATION_REGISTRY",
"HIGHER_IS_BETTER_REGISTRY",
"FILTER_REGISTRY",
]
__all__ = [
# canonical
"Registry",
"MetricSpec",
"model_registry",
"task_registry",
"metric_registry",
"metric_agg_registry",
"higher_is_better_registry",
"filter_registry",
"freeze_all",
*LEGACY_EXPORTS,
] # type: ignore
T = TypeVar("T")
Placeholder = Union[str, md.EntryPoint]
@lru_cache(maxsize=16)
def _materialise_placeholder(ph: Placeholder) -> Any:
"""Materialize a lazy placeholder into the actual object.
This is at module level to avoid memory leaks from lru_cache on instance methods.
Args:
ph: Either a string path "module:object" or an EntryPoint instance
Returns:
The loaded object
Raises:
ValueError: If the string format is invalid
ImportError: If the module cannot be imported
AttributeError: If the object doesn't exist in the module
"""
if isinstance(ph, str):
mod, _, attr = ph.partition(":")
if not attr:
raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")
return getattr(importlib.import_module(mod), attr)
return ph.load()
# Metric-specific metadata storage --------------------------------------------
_metric_meta: dict[str, dict[str, Any]] = {}
class Registry(Generic[T]):
"""A thread-safe registry for named objects with lazy loading support.
The Registry provides a central location for registering and retrieving
components by name. It supports:
- Direct registration of objects
- Lazy registration with placeholders (strings or entry points)
- Type checking against a base class
- Thread-safe operations
- Freezing to prevent further modifications
Example:
>>> from lm_eval.api.model import LM
>>> registry = Registry("models", base_cls=LM)
>>>
>>> # Direct registration
>>> @registry.register("my-model")
>>> class MyModel(LM):
... pass
>>>
>>> # Lazy registration
>>> registry.register("lazy-model", lazy="mypackage:LazyModel")
>>>
>>> # Retrieval (triggers lazy loading if needed)
>>> model_cls = registry.get("my-model")
>>> model = model_cls()
"""
def __init__(
    self,
    name: str,
    *,
    base_cls: type[T] | None = None,
) -> None:
    """Create an empty registry.

    Args:
        name: Short human-readable label used in error messages
            (e.g. "model", "metric").
        base_cls: If given, every concrete registered class must be a
            subclass of this type.
    """
    # Reentrant lock: register()/get() may be re-entered on the same thread.
    self._lock = threading.RLock()
    self._objs: dict[str, T | Placeholder] = {}
    self._base_cls = base_cls
    self._name = name
# Registration (decorator or direct call) --------------------------------------
def register(
    self,
    *aliases: str,
    lazy: T | Placeholder | None = None,
) -> Callable[[T], T]:
    """Register an object under one or more aliases.

    Can be used as a decorator or called directly for lazy registration.

    Args:
        *aliases: Names to register the object under. If empty, uses object's __name__
        lazy: For direct calls only - a placeholder string "module:object",
            an EntryPoint, or a pre-built object to store as-is

    Returns:
        Decorator function (or a no-op identity decorator if lazy registration)

    Examples:
        >>> # As decorator
        >>> @model_registry.register("name1", "name2")
        >>> class MyModel(LM):
        ...     pass
        >>>
        >>> # Direct lazy registration
        >>> model_registry.register("lazy-name", lazy="mymodule:MyModel")

    Raises:
        ValueError: If alias already registered with different target,
            or if 'lazy=' is used with anything other than exactly one alias
        TypeError: If object doesn't inherit from base_cls (when specified)
    """

    def _store(alias: str, target: T | Placeholder) -> None:
        # Single write path used by both the decorator and lazy branches.
        # Callers must hold self._lock.
        current = self._objs.get(alias)
        # collision handling ------------------------------------------
        # Identical re-registration (current == target) falls through and
        # simply re-stores the same value.
        if current is not None and current != target:
            # allow placeholder → real object upgrade: a string placeholder
            # may be replaced by the concrete class it pointed at
            if isinstance(current, str) and isinstance(target, type):
                # mod, _, cls = current.partition(":")
                if current == f"{target.__module__}:{target.__name__}":
                    self._objs[alias] = target
                    return
            raise ValueError(
                f"{self._name!r} alias '{alias}' already registered ("
                f"existing={current}, new={target})"
            )
        # type check for concrete classes ----------------------------------------------
        # (placeholders are checked later, on materialization in get())
        if self._base_cls is not None and isinstance(target, type):
            if not issubclass(target, self._base_cls):  # type: ignore[arg-type]
                raise TypeError(
                    f"{target} must inherit from {self._base_cls} to be a {self._name}"
                )
        self._objs[alias] = target

    def decorator(obj: T) -> T:  # type: ignore[valid-type]
        # No aliases given: fall back to the object's own __name__.
        names = aliases or (getattr(obj, "__name__", str(obj)),)
        with self._lock:
            for name in names:
                _store(name, obj)
        return obj

    # Direct call with *lazy* placeholder
    if lazy is not None:
        if len(aliases) != 1:
            raise ValueError("Exactly one alias required when using 'lazy='")
        with self._lock:
            _store(aliases[0], lazy)  # type: ignore[arg-type]
        # return no-op decorator for accidental use
        return lambda x: x  # type: ignore[return-value]
    return decorator
# Lookup & materialisation --------------------------------------------------
def _materialise(self, ph: Placeholder) -> T:
    """Materialize a placeholder using the module-level cached function.

    Thin wrapper: the actual import/load (and its lru_cache) lives in
    module-level _materialise_placeholder so instances are never cached.

    Args:
        ph: Placeholder to materialize

    Returns:
        The materialized object, cast to type T
    """
    return cast(T, _materialise_placeholder(ph))
def get(self, alias: str) -> T:
    """Retrieve an object by alias, materializing if needed.

    Thread-safe lazy loading: if the alias points to a placeholder,
    it will be loaded and cached before returning. Uses a
    check / lock / re-check sequence so only one thread resolves a
    given placeholder.

    Args:
        alias: The registered name to look up

    Returns:
        The registered object

    Raises:
        KeyError: If alias not found
        TypeError: If materialized object doesn't match base_cls
        ImportError/AttributeError: If lazy loading fails
    """
    try:
        target = self._objs[alias]
    except KeyError as exc:
        raise KeyError(
            f"Unknown {self._name} '{alias}'. Available: {', '.join(self._objs)}"
        ) from exc
    if isinstance(target, (str, md.EntryPoint)):
        with self._lock:
            # Re-check under lock (another thread might have resolved it)
            fresh = self._objs[alias]
            if isinstance(fresh, (str, md.EntryPoint)):
                concrete = self._materialise(fresh)
                # Only update if not frozen (MappingProxyType)
                if not isinstance(self._objs, MappingProxyType):
                    self._objs[alias] = concrete
            else:
                concrete = fresh  # another thread did the job
        target = concrete
    # Late type/validator checks — placeholders skip the check at
    # registration time, so enforce base_cls here after materialization.
    # NOTE: issubclass itself raises TypeError if target is not a class.
    if self._base_cls is not None and not issubclass(target, self._base_cls):  # type: ignore[arg-type]
        raise TypeError(
            f"{target} does not inherit from {self._base_cls} (alias '{alias}')"
        )
    return target
if TYPE_CHECKING:
from lm_eval.api.model import LM
def __getitem__(self, alias: str) -> T:
    """Allow dict-style access: registry[alias] (delegates to get(), so it
    also triggers lazy materialization)."""
    return self.get(alias)
eval_logger = logging.getLogger(__name__)
def __iter__(self):
    """Iterate over registered aliases (placeholders are NOT materialized)."""
    return iter(self._objs)
MODEL_REGISTRY = {}
DEFAULTS = {
"model": {"max_length": 2048},
"tasks": {"generate_until": {"max_gen_toks": 256}},
}
def __len__(self):
    """Return number of registered aliases (materialized or not)."""
    return len(self._objs)
def items(self):
    """Return (alias, object) pairs.

    Note: Objects may be placeholders that haven't been materialized yet.
    """
    return self._objs.items()
# either pass a list or a single alias.
# function receives them as a tuple of strings
# Utilities -------------------------------------------------------------
def decorate(cls):
for name in names:
assert issubclass(cls, LM), (
f"Model '{name}' ({cls.__name__}) must extend LM class"
)
def origin(self, alias: str) -> str | None:
    """Get the source location of a registered object.

    Args:
        alias: The registered name

    Returns:
        "path/to/file.py:line_number", or None if the entry is still an
        unmaterialized placeholder, is unknown, or cannot be inspected
    """
    obj = self._objs.get(alias)
    if isinstance(obj, (str, md.EntryPoint)):
        # Placeholder has not been imported yet — no source to report.
        return None
    try:
        path = inspect.getfile(obj)  # type: ignore[arg-type]
        line = inspect.getsourcelines(obj)[1]  # type: ignore[arg-type]
        return f"{path}:{line}"
    except Exception:  # pragma: no cover – best-effort only
        return None
def freeze(self):
    """Make the registry read-only to prevent further modifications.

    After freezing, attempts to register new objects will fail.
    This is useful for ensuring registry contents don't change after
    initialization.

    Note: get() on a frozen registry still materializes placeholders,
    but cannot cache the result back into the mapping.
    """
    with self._lock:
        self._objs = MappingProxyType(dict(self._objs))  # type: ignore[assignment]
# Test helper --------------------------------
def _clear(self):  # pragma: no cover
    """Erase registry (for isolated tests).

    Clears both the registry contents and the materialization cache.
    Only use this in test code to ensure clean state between tests.

    NOTE(review): will raise AttributeError on a frozen registry —
    MappingProxyType has no clear(); confirm tests never freeze first.
    """
    self._objs.clear()
    _materialise_placeholder.cache_clear()
def register_task(name: str):
def decorate(fn):
assert name not in TASK_REGISTRY, (
f"task named '{name}' conflicts with existing registered task!"
# Structured object for metrics ------------------
@dataclass(frozen=True)
class MetricSpec:
    """Specification for a metric including computation and aggregation functions.

    Instances are immutable (frozen dataclass), so a single spec can be
    shared safely across registries and threads.

    Attributes:
        compute: Function to compute metric on individual items
        aggregate: Function to aggregate multiple metric values into a single score
        higher_is_better: Whether higher values indicate better performance
        output_type: Optional type hint for the output (e.g., "generate_until" for perplexity)
        requires: Optional list of other metrics this one depends on
    """

    compute: Callable[[Any, Any], Any]
    aggregate: Callable[[Iterable[Any]], float]
    higher_is_better: bool = True
    output_type: str | None = None
    requires: list[str] | None = None
# Canonical registries aliases ---------------------
from lm_eval.api.model import LM  # noqa: E402

# One Registry instance per component kind. These are the canonical stores;
# the legacy ALL-CAPS names further below alias them for back-compat.
model_registry: Registry[type[LM]] = cast(
    Registry[type[LM]], Registry("model", base_cls=LM)
)
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
    "metric aggregation"
)
higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag")
filter_registry: Registry[type[Filter]] = Registry("filter")

# Public helper aliases ------------------------------------------------------
# Bound methods re-exported as module-level functions.
register_model = model_registry.register
get_model = model_registry.get
register_task = task_registry.register
get_task = task_registry.get
register_filter = filter_registry.register
get_filter = filter_registry.get
# Metric helpers need thin wrappers to build MetricSpec ----------------------
def _no_aggregation_fn(values: Iterable[Any]) -> float:
"""Default aggregation that raises NotImplementedError.
Args:
values: Metric values to aggregate (unused)
Raises:
NotImplementedError: Always - this is a placeholder for metrics
that haven't specified an aggregation function
"""
raise NotImplementedError(
"No aggregation function specified for this metric. "
"Please specify 'aggregation' parameter in @register_metric."
)
def register_metric(**kw):
    """Decorator for registering metric functions.

    Creates a MetricSpec from the decorated function and keyword arguments,
    then registers it (and its higher-is-better flag) in the canonical
    registries.

    Fix: removed merge-leftover legacy lines that wrote to undefined
    TASK_REGISTRY / ALL_TASKS / func2task_index (NameError at decoration
    time) and a stray ``return decorate`` that shadowed ``return deco``.

    Args:
        **kw: Keyword arguments including:
            - metric: Name to register the metric under (required)
            - aggregation: Name of aggregation function in metric_agg_registry
            - higher_is_better: Whether higher scores are better (default: True)
            - output_type: Optional output type hint
            - requires: Optional list of required metrics

    Returns:
        Decorator function that registers the metric and returns it unchanged

    Raises:
        KeyError: If 'metric' is missing, or 'aggregation' names an
            unknown aggregation function

    Example:
        >>> @register_metric(
        ...     metric="my_accuracy",
        ...     aggregation="mean",
        ...     higher_is_better=True
        ... )
        ... def compute_accuracy(items):
        ...     return sum(item["correct"] for item in items) / len(items)
    """
    name = kw["metric"]

    def deco(fn):
        # Resolve the aggregation eagerly so a bad name fails at
        # decoration time rather than at first use.
        spec = MetricSpec(
            compute=fn,
            aggregate=(
                metric_agg_registry.get(kw["aggregation"])
                if "aggregation" in kw
                else _no_aggregation_fn
            ),
            higher_is_better=kw.get("higher_is_better", True),
            output_type=kw.get("output_type"),
            requires=kw.get("requires"),
        )
        # lazy= stores the already-built spec/flag objects as-is.
        metric_registry.register(name, lazy=spec)
        # Keep the raw kwargs around for introspection / back-compat.
        _metric_meta[name] = kw
        higher_is_better_registry.register(name, lazy=spec.higher_is_better)
        return fn

    return deco
def register_group(name):
def decorate(fn):
func_name = func2task_index[fn.__name__]
if name in GROUP_REGISTRY:
GROUP_REGISTRY[name].append(func_name)
else:
GROUP_REGISTRY[name] = [func_name]
ALL_TASKS.add(name)
return fn
def get_metric(name, hf_evaluate_metric=False):
"""Get a metric compute function by name.
return decorate
# Legacy mutable dict registries. NOTE(review): these appear superseded by
# the Registry-instance aliases assigned later in this module — confirm
# nothing still mutates these dicts directly before removing them.
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: dict[str, Callable[[], dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {}
# Default metric names attached to each task output type.
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
}
def register_metric(**args):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert "metric" in args
name = args["metric"]
for key, registry in [
("metric", METRIC_REGISTRY),
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
("aggregation", METRIC_AGGREGATION_REGISTRY),
]:
if key in args:
value = args[key]
assert value not in registry, (
f"{key} named '{value}' conflicts with existing registered {key}!"
)
First checks the local metric registry, then optionally falls back
to HuggingFace evaluate library.
if key == "metric":
registry[name] = fn
elif key == "aggregation":
registry[name] = AGGREGATION_REGISTRY[value]
else:
registry[name] = value
return fn
Args:
name: Metric name to retrieve
hf_evaluate_metric: If True, suppress warning when falling back to HF
return decorate
def get_metric(name: str, hf_evaluate_metric=False) -> Callable[..., Any] | None:
    """Get a metric compute function by name.

    First checks the local metric registry, then optionally falls back
    to the HuggingFace evaluate library.

    Fix: this span had the old and new implementations of get_metric
    interleaved by a bad merge (including docstring fragments with no
    opening quotes, a SyntaxError); reconstructed the new implementation
    and dropped a redundant local ``import logging`` (logging is imported
    at module top).

    Args:
        name: Metric name to retrieve
        hf_evaluate_metric: If True, suppress warning when falling back to HF

    Returns:
        The metric's compute function

    Raises:
        KeyError: If metric not found in registry or HF evaluate
    """
    try:
        spec = metric_registry.get(name)
        return spec.compute  # type: ignore[attr-defined]
    except KeyError:
        if not hf_evaluate_metric:
            logging.getLogger(__name__).warning(
                f"Metric '{name}' not in registry; trying HF evaluate…"
            )
        try:
            # Deferred import: 'evaluate' is an optional dependency.
            import evaluate as hf

            return hf.load(name).compute  # type: ignore[attr-defined]
        except Exception:
            raise KeyError(f"Metric '{name}' not found anywhere")
def get_aggregation(name: str) -> Callable[..., Any] | None:
    """Look up a legacy aggregation function; warn and return None when absent."""
    if name not in AGGREGATION_REGISTRY:
        eval_logger.warning(f"{name} not a registered aggregation metric!")
        return None
    return AGGREGATION_REGISTRY[name]
register_metric_aggregation = metric_agg_registry.register
get_metric_aggregation = metric_agg_registry.get
def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable[..., Any]]]:
    """Return the default aggregation for *name*, falling back to 'mean'."""
    if name in METRIC_AGGREGATION_REGISTRY:
        return METRIC_AGGREGATION_REGISTRY[name]
    eval_logger.warning(
        f"{name} metric is not assigned a default aggregation!. Using default aggregation mean"
    )
    return AGGREGATION_REGISTRY["mean"]
register_higher_is_better = higher_is_better_registry.register
is_higher_better = higher_is_better_registry.get
# Legacy compatibility
register_aggregation = metric_agg_registry.register
get_aggregation = metric_agg_registry.get
DEFAULT_METRIC_REGISTRY = metric_registry
AGGREGATION_REGISTRY = metric_agg_registry
def is_higher_better(metric_name: str) -> bool:
    """Return the higher-is-better flag for *metric_name* (defaults to True)."""
    if metric_name in HIGHER_IS_BETTER_REGISTRY:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    eval_logger.warning(
        f"higher_is_better not specified for metric '{metric_name}'!. Will default to True."
    )
    return True
def freeze_all():
    """Freeze all registries to prevent further modifications.

    This is useful for ensuring registry contents are immutable after
    initialization, preventing accidental modifications during runtime.
    """
    # Freeze every canonical registry in one pass.
    for r in (
        model_registry,
        task_registry,
        metric_registry,
        metric_agg_registry,
        higher_is_better_registry,
        filter_registry,
    ):
        r.freeze()
return decorate
# Backwards‑compat aliases ----------------------------------------
def get_filter(filter_name: str | Callable) -> Callable:
    """Resolve a filter by registered name; ready-made callables pass through."""
    try:
        return FILTER_REGISTRY[filter_name]
    except KeyError as e:
        # Callers may hand us a callable directly instead of a name.
        if not callable(filter_name):
            eval_logger.warning(f"filter `{filter_name}` is not registered!")
            raise e
        return filter_name
# Point the legacy ALL-CAPS names at the canonical Registry instances so
# old call sites keep working.
MODEL_REGISTRY = model_registry
TASK_REGISTRY = task_registry
METRIC_REGISTRY = metric_registry
METRIC_AGGREGATION_REGISTRY = metric_agg_registry
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry
FILTER_REGISTRY = filter_registry
from __future__ import annotations
from functools import partial
from typing import Optional, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter
from lm_eval.api.registry import filter_registry, get_filter
from . import custom, extraction, selection, transformation
def build_filter_ensemble(
filter_name: str,
components: list[tuple[str, Optional[dict[str, Union[str, int, float]]]]],
components: list[tuple[str, dict[str, str | int | float] | None]],
) -> FilterEnsemble:
"""
Create a filtering pipeline.
......@@ -21,3 +21,12 @@ def build_filter_ensemble(
partial(get_filter(func), **(kwargs or {})) for func, kwargs in components
],
)
__all__ = [
"custom",
"extraction",
"selection",
"transformation",
"build_filter_ensemble",
]
from . import (
anthropic_llms,
api_models,
dummy,
gguf,
hf_audiolm,
hf_steered,
hf_vlms,
huggingface,
ibm_watsonx_ai,
mamba_lm,
nemo_lm,
neuron_optimum,
openai_completions,
optimum_ipex,
optimum_lm,
sglang_causallms,
sglang_generate_API,
textsynth,
vllm_causallms,
vllm_vlms,
)
# TODO: implement __all__
# Models are now lazily loaded via the registry system
# No need to import them all at once
# Define model mappings for lazy registration
# Maps registry alias -> "module:ClassName" placeholder string; nothing here
# is imported until the alias is first requested from the model registry.
MODEL_MAPPING = {
    "anthropic-completions": "lm_eval.models.anthropic_llms:AnthropicLM",
    "anthropic-chat": "lm_eval.models.anthropic_llms:AnthropicChatLM",
    "anthropic-chat-completions": "lm_eval.models.anthropic_llms:AnthropicCompletionsLM",
    "local-completions": "lm_eval.models.openai_completions:LocalCompletionsAPI",
    "local-chat-completions": "lm_eval.models.openai_completions:LocalChatCompletion",
    "openai-completions": "lm_eval.models.openai_completions:OpenAICompletionsAPI",
    "openai-chat-completions": "lm_eval.models.openai_completions:OpenAIChatCompletion",
    "dummy": "lm_eval.models.dummy:DummyLM",
    "gguf": "lm_eval.models.gguf:GGUFLM",
    "ggml": "lm_eval.models.gguf:GGUFLM",
    "hf-audiolm-qwen": "lm_eval.models.hf_audiolm:HFAudioLM",
    "steered": "lm_eval.models.hf_steered:SteeredHF",
    "hf-multimodal": "lm_eval.models.hf_vlms:HFMultimodalLM",
    "hf-auto": "lm_eval.models.huggingface:HFLM",
    "hf": "lm_eval.models.huggingface:HFLM",
    "huggingface": "lm_eval.models.huggingface:HFLM",
    "watsonx_llm": "lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI",
    "mamba_ssm": "lm_eval.models.mamba_lm:MambaLMWrapper",
    "nemo_lm": "lm_eval.models.nemo_lm:NeMoLM",
    "neuronx": "lm_eval.models.neuron_optimum:NeuronModelForCausalLM",
    "ipex": "lm_eval.models.optimum_ipex:IPEXForCausalLM",
    "openvino": "lm_eval.models.optimum_lm:OptimumLM",
    "sglang": "lm_eval.models.sglang_causallms:SGLANG",
    "sglang-generate": "lm_eval.models.sglang_generate_API:SGAPI",
    "textsynth": "lm_eval.models.textsynth:TextSynthLM",
    "vllm": "lm_eval.models.vllm_causallms:VLLM",
    "vllm-vlm": "lm_eval.models.vllm_vlms:VLLM_VLM",
}
# Register all models lazily
def _register_all_models():
    """Lazily register every alias in MODEL_MAPPING with the model registry."""
    from lm_eval.api.registry import model_registry

    for alias, target_path in MODEL_MAPPING.items():
        # Skip aliases already present (e.g. registered by a direct module
        # import) to avoid collision errors.
        if alias in model_registry:
            continue
        model_registry.register(alias, lazy=target_path)
# Call registration on module import
_register_all_models()
__all__ = ["MODEL_MAPPING"]
try:
......
from collections.abc import Generator
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Callable, Generator, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
from peft.peft_model import PeftModel
......
......@@ -3,7 +3,7 @@ import json
import logging
import os
import warnings
from functools import lru_cache
from functools import cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm
......@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise ValueError(error_msg)
@lru_cache(maxsize=None)
@cache
def get_watsonx_credentials() -> Dict[str, str]:
"""
Retrieves Watsonx API credentials from environmental variables.
......
......@@ -42,7 +42,7 @@ try:
if parse_version(version("vllm")) >= parse_version("0.8.3"):
from vllm.entrypoints.chat_utils import resolve_hf_chat_template
except ModuleNotFoundError:
pass
print("njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd")
if TYPE_CHECKING:
pass
......
......@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
......@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
d = DomainParser()(open(self.domain_file, "r").read().lower())
p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
......
......@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
......@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
d = DomainParser()(open(self.domain_file, "r").read().lower())
p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
......
......@@ -121,7 +121,7 @@ lint.fixable = ["I001", "F401", "UP"]
lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117", "E741"]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401", "F402", "F403"]
"__init__.py" = ["F401", "F402", "F403", "F405"]
[tool.ruff.lint.isort]
combine-as-imports = true
......
......@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from tqdm import tqdm
# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registryv2 import ALL_TASKS
eval_logger = logging.getLogger(__name__)
......
#!/usr/bin/env python3
"""Comprehensive tests for the registry system."""
import threading
import pytest
from lm_eval.api.model import LM
from lm_eval.api.registry import (
MetricSpec,
Registry,
get_metric,
metric_agg_registry,
metric_registry,
model_registry,
register_metric,
)
# Import metrics module to ensure decorators are executed
# import lm_eval.api.metrics
class TestBasicRegistry:
    """Test basic registry functionality."""

    def test_create_registry(self):
        """A fresh registry is empty."""
        registry = Registry("test")
        assert len(registry) == 0
        assert list(registry) == []

    def test_decorator_registration(self):
        """Decorator-based registration stores and returns the class."""
        registry = Registry("test")

        @registry.register("my_class")
        class MyClass:
            pass

        assert "my_class" in registry
        assert registry.get("my_class") is MyClass
        assert registry["my_class"] is MyClass

    def test_decorator_multiple_aliases(self):
        """Every alias resolves to the same class."""
        registry = Registry("test")

        @registry.register("alias1", "alias2", "alias3")
        class MyClass:
            pass

        for alias in ("alias1", "alias2", "alias3"):
            assert registry.get(alias) is MyClass

    def test_decorator_auto_name(self):
        """With no alias, the class __name__ is used."""
        registry = Registry("test")

        @registry.register()
        class AutoNamedClass:
            pass

        assert registry.get("AutoNamedClass") is AutoNamedClass

    def test_lazy_registration(self):
        """String placeholders are materialized on first access."""
        registry = Registry("test")
        registry.register("join", lazy="os.path:join")
        # Stored as an unresolved string until the first get().
        assert isinstance(registry._objs["join"], str)
        resolved = registry.get("join")
        import os

        assert resolved == os.path.join
        assert callable(resolved)

    def test_direct_registration(self):
        """Concrete objects can be stored via the lazy= parameter."""
        registry = Registry("test")

        class DirectClass:
            pass

        instance = DirectClass()
        registry.register("direct", lazy=instance)
        assert registry.get("direct") == instance

    def test_metadata_removed(self):
        """The generic registry works without any metadata parameter."""
        registry = Registry("test")

        @registry.register("test_class")
        class TestClass:
            pass

        assert "test_class" in registry
        assert registry.get("test_class") is TestClass

    def test_unknown_key_error(self):
        """Unknown aliases raise a descriptive KeyError."""
        registry = Registry("test")
        with pytest.raises(KeyError) as exc_info:
            registry.get("unknown")
        message = str(exc_info.value)
        assert "Unknown test 'unknown'" in message
        assert "Available:" in message

    def test_iteration(self):
        """Iteration, len() and items() behave like a mapping."""
        registry = Registry("test")
        for alias, path in [("a", "os:getcwd"), ("b", "os:getenv"), ("c", "os:getpid")]:
            registry.register(alias, lazy=path)

        assert list(registry) == ["a", "b", "c"]
        assert len(registry) == 3

        pairs = list(registry.items())
        assert len(pairs) == 3
        assert pairs[0][0] == "a"
        assert isinstance(pairs[0][1], str)  # still a lazy placeholder

    def test_mapping_protocol(self):
        """__getitem__ and __contains__ mirror get()."""
        registry = Registry("test")
        registry.register("test", lazy="os:getcwd")
        assert registry["test"] == registry.get("test")
        assert "test" in registry
        assert "missing" not in registry
class TestTypeConstraints:
    """Test type checking and base class constraints."""

    def test_base_class_constraint(self):
        """Concrete registrations must subclass base_cls."""

        class BaseClass:
            pass

        class GoodSubclass(BaseClass):
            pass

        class BadClass:
            pass

        registry = Registry("typed", base_cls=BaseClass)

        # Correct subclass registers cleanly.
        @registry.register("good")
        class GoodInline(BaseClass):
            pass

        # Wrong type is rejected at registration time.
        with pytest.raises(TypeError) as exc_info:

            @registry.register("bad")
            class BadInline:
                pass

        assert "must inherit from" in str(exc_info.value)

    def test_lazy_type_check(self):
        """Lazy entries are only type-checked when materialized."""

        class BaseClass:
            pass

        registry = Registry("typed", base_cls=BaseClass)
        registry.register("bad_lazy", lazy="os.path:join")
        # Fails on access — exact message varies.
        with pytest.raises(TypeError):
            registry.get("bad_lazy")
class TestCollisionHandling:
    """Test registration collision scenarios."""

    def test_identical_registration(self):
        """Re-registering the same object under the same alias is allowed."""
        registry = Registry("test")

        class MyClass:
            pass

        registry.register("test", lazy=MyClass)
        registry.register("test", lazy=MyClass)  # identical → no error
        assert registry.get("test") is MyClass

    def test_different_registration_fails(self):
        """Registering a different object under a taken alias fails."""
        registry = Registry("test")

        class Class1:
            pass

        class Class2:
            pass

        registry.register("test", lazy=Class1)
        with pytest.raises(ValueError) as exc_info:
            registry.register("test", lazy=Class2)
        assert "already registered" in str(exc_info.value)

    def test_lazy_to_concrete_upgrade(self):
        """A matching placeholder may be upgraded to the concrete class."""
        registry = Registry("test")
        registry.register("myclass", lazy="test_registry:MyUpgradeClass")

        @registry.register("myclass")
        class MyUpgradeClass:
            pass

        assert registry.get("myclass") is MyUpgradeClass
class TestThreadSafety:
    """Test thread safety of registry operations."""

    def test_concurrent_access(self):
        """Threads resolving one lazy entry all see the same object."""
        registry = Registry("test")
        registry.register("concurrent", lazy="os.path:join")
        results, errors = [], []

        def worker():
            try:
                results.append(registry.get("concurrent"))
            except Exception as exc:
                errors.append(str(exc))

        threads = [threading.Thread(target=worker) for _ in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors
        assert len(results) == 10
        # Every thread must observe the same materialized object.
        assert all(item == results[0] for item in results)

    def test_concurrent_registration(self):
        """Distinct concurrent registrations all land without error."""
        registry = Registry("test")
        errors = []

        def worker(alias, placeholder):
            try:
                registry.register(alias, lazy=placeholder)
            except Exception as exc:
                errors.append(str(exc))

        threads = [
            threading.Thread(target=worker, args=(f"item_{i}", f"module{i}:Class{i}"))
            for i in range(10)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors
        assert len(registry) == 10
class TestMetricRegistry:
    """Behaviour of the metric-specific registry and its helpers."""

    def test_metric_spec(self):
        """MetricSpec carries its callables and metadata through unchanged."""

        def compute_fn(items):
            return [1 for _ in items]

        def agg_fn(values):
            return sum(values) / len(values)

        spec_kwargs = {
            "compute": compute_fn,
            "aggregate": agg_fn,
            "higher_is_better": True,
            "output_type": "probability",
        }
        spec = MetricSpec(**spec_kwargs)

        assert spec.compute == compute_fn
        assert spec.aggregate == agg_fn
        assert spec.higher_is_better
        assert spec.output_type == "probability"

    def test_register_metric_decorator(self):
        """@register_metric wires compute + aggregation into metric_registry."""

        # The aggregation must exist before the metric that references it.
        @metric_agg_registry.register("test_mean")
        def test_mean(values):
            return sum(values) / len(values) if values else 0.0

        @register_metric(
            metric="test_accuracy",
            aggregation="test_mean",
            higher_is_better=True,
            output_type="accuracy",
        )
        def compute_accuracy(items):
            return [int(item["pred"] == item["gold"]) for item in items]

        assert "test_accuracy" in metric_registry
        spec = metric_registry.get("test_accuracy")
        assert isinstance(spec, MetricSpec)
        assert spec.higher_is_better
        assert spec.output_type == "accuracy"

        # Two correct predictions, one wrong one.
        pairs = [("a", "a"), ("b", "b"), ("c", "d")]
        items = [{"pred": pred, "gold": gold} for pred, gold in pairs]
        scores = spec.compute(items)
        assert scores == [1, 1, 0]
        assert spec.aggregate(scores) == 2 / 3

    def test_metric_without_aggregation(self):
        """A metric with no aggregation raises when aggregate() is invoked."""

        @register_metric(metric="no_agg", higher_is_better=False)
        def compute_something(items):
            return [len(item) for item in items]

        spec = metric_registry.get("no_agg")
        with pytest.raises(NotImplementedError) as exc_info:
            spec.aggregate([1, 2, 3])
        assert "No aggregation function specified" in str(exc_info.value)

    def test_get_metric_helper(self):
        """get_metric returns only the compute callable of a registered metric."""

        @register_metric(
            metric="helper_test",
            aggregation="mean",  # 'mean' is a stock aggregation in metric_agg_registry
        )
        def compute_helper(items):
            return items

        compute_fn = get_metric("helper_test")
        assert callable(compute_fn)
        assert compute_fn([1, 2, 3]) == [1, 2, 3]
class TestRegistryUtilities:
    """Utility methods: freeze, clear, and origin tracking."""

    def test_freeze(self):
        """A frozen registry rejects writes but still serves lookups."""
        registry = Registry("test")
        registry.register("item1", lazy="os:getcwd")
        registry.register("item2", lazy="os:getenv")

        registry.freeze()

        # Direct mutation of the underlying mapping must now fail.
        with pytest.raises(TypeError):
            registry._objs["new"] = "value"

        # Reads are unaffected by freezing.
        assert "item1" in registry
        assert callable(registry.get("item1"))

    def test_clear(self):
        """_clear() empties the registry completely."""
        registry = Registry("test")
        registry.register("item1", lazy="os:getcwd")
        registry.register("item2", lazy="os:getenv")
        assert len(registry) == 2

        registry._clear()
        assert len(registry) == 0
        assert list(registry) == []

    def test_origin(self):
        """origin() reports file:line for concrete entries and None for lazy ones."""
        registry = Registry("test")

        registry.register("lazy", lazy="os:getcwd")
        assert registry.origin("lazy") is None

        @registry.register("concrete")
        class ConcreteClass:
            pass

        origin = registry.origin("concrete")
        assert origin is not None
        assert "test_registry.py" in origin
        assert ":" in origin  # formatted as "<file>:<lineno>"
class TestBackwardCompatibility:
    """Legacy aliases and helper functions keep working."""

    def test_model_registry_alias(self):
        """MODEL_REGISTRY is a live alias of model_registry."""
        from lm_eval.api.registry import MODEL_REGISTRY

        assert MODEL_REGISTRY is model_registry

        before_count = len(MODEL_REGISTRY)

        @model_registry.register("test_model_compat")
        class TestModelCompat(LM):
            pass

        # New registrations must be visible through the alias immediately.
        assert len(MODEL_REGISTRY) == before_count + 1
        assert "test_model_compat" in MODEL_REGISTRY

    def test_legacy_functions(self):
        """register_model/get_model and the legacy registry aliases still work."""
        from lm_eval.api.registry import (
            AGGREGATION_REGISTRY,
            DEFAULT_METRIC_REGISTRY,
            get_model,
            register_model,
        )

        @register_model("legacy_model")
        class LegacyModel(LM):
            pass

        assert get_model("legacy_model") == LegacyModel

        # The legacy names are the same registry objects, not copies.
        assert DEFAULT_METRIC_REGISTRY is metric_registry
        assert AGGREGATION_REGISTRY is metric_agg_registry
class TestEdgeCases:
    """Error paths for malformed or unresolvable registrations."""

    def test_invalid_lazy_format(self):
        """A lazy spec without a colon is rejected at resolution time."""
        registry = Registry("test")
        registry.register("bad", lazy="no_colon_here")
        with pytest.raises(ValueError) as exc_info:
            registry.get("bad")
        assert "expected 'module:object'" in str(exc_info.value)

    def test_lazy_module_not_found(self):
        """Resolving a lazy entry whose module is missing raises ModuleNotFoundError."""
        registry = Registry("test")
        registry.register("missing", lazy="nonexistent_module:Class")
        with pytest.raises(ModuleNotFoundError):
            registry.get("missing")

    def test_lazy_attribute_not_found(self):
        """Resolving a lazy entry whose attribute is missing raises AttributeError."""
        registry = Registry("test")
        registry.register("missing_attr", lazy="os:nonexistent_function")
        with pytest.raises(AttributeError):
            registry.get("missing_attr")

    def test_multiple_aliases_with_lazy(self):
        """Lazy registration accepts exactly one alias, never several."""
        registry = Registry("test")
        with pytest.raises(ValueError) as exc_info:
            registry.register("alias1", "alias2", lazy="os:getcwd")
        assert "Exactly one alias required" in str(exc_info.value)
if __name__ == "__main__":
    # Allow running this test module directly (e.g. `python test_registry.py`)
    # instead of through the pytest CLI; "-v" enables verbose per-test output.
    pytest.main([__file__, "-v"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment