Unverified Commit 70314843 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

Merge pull request #3189 from EleutherAI/lazy_reg

refactor registry
parents 73202a2e 930b4253
from .api import metrics, model, registry # initializes the registries
from .filters import *
__version__ = "0.4.9.1"
......
This diff is collapsed.
from __future__ import annotations
from functools import partial
from typing import Optional, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter
from lm_eval.api.registry import filter_registry, get_filter
from . import custom, extraction, selection, transformation
def build_filter_ensemble(
filter_name: str,
components: list[tuple[str, Optional[dict[str, Union[str, int, float]]]]],
components: list[tuple[str, dict[str, str | int | float] | None]],
) -> FilterEnsemble:
"""
Create a filtering pipeline.
......@@ -21,3 +21,12 @@ def build_filter_ensemble(
partial(get_filter(func), **(kwargs or {})) for func, kwargs in components
],
)
__all__ = [
"custom",
"extraction",
"selection",
"transformation",
"build_filter_ensemble",
]
from . import (
anthropic_llms,
api_models,
dummy,
gguf,
hf_audiolm,
hf_steered,
hf_vlms,
huggingface,
ibm_watsonx_ai,
mamba_lm,
nemo_lm,
neuron_optimum,
openai_completions,
optimum_ipex,
optimum_lm,
sglang_causallms,
sglang_generate_API,
textsynth,
vllm_causallms,
vllm_vlms,
)
# TODO: implement __all__
# Models are now lazily loaded via the registry system
# No need to import them all at once
# Define model mappings for lazy registration
MODEL_MAPPING = {
"anthropic-completions": "lm_eval.models.anthropic_llms:AnthropicLM",
"anthropic-chat": "lm_eval.models.anthropic_llms:AnthropicChatLM",
"anthropic-chat-completions": "lm_eval.models.anthropic_llms:AnthropicCompletionsLM",
"local-completions": "lm_eval.models.openai_completions:LocalCompletionsAPI",
"local-chat-completions": "lm_eval.models.openai_completions:LocalChatCompletion",
"openai-completions": "lm_eval.models.openai_completions:OpenAICompletionsAPI",
"openai-chat-completions": "lm_eval.models.openai_completions:OpenAIChatCompletion",
"dummy": "lm_eval.models.dummy:DummyLM",
"gguf": "lm_eval.models.gguf:GGUFLM",
"ggml": "lm_eval.models.gguf:GGUFLM",
"hf-audiolm-qwen": "lm_eval.models.hf_audiolm:HFAudioLM",
"steered": "lm_eval.models.hf_steered:SteeredHF",
"hf-multimodal": "lm_eval.models.hf_vlms:HFMultimodalLM",
"hf-auto": "lm_eval.models.huggingface:HFLM",
"hf": "lm_eval.models.huggingface:HFLM",
"huggingface": "lm_eval.models.huggingface:HFLM",
"watsonx_llm": "lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI",
"mamba_ssm": "lm_eval.models.mamba_lm:MambaLMWrapper",
"nemo_lm": "lm_eval.models.nemo_lm:NeMoLM",
"neuronx": "lm_eval.models.neuron_optimum:NeuronModelForCausalLM",
"ipex": "lm_eval.models.optimum_ipex:IPEXForCausalLM",
"openvino": "lm_eval.models.optimum_lm:OptimumLM",
"sglang": "lm_eval.models.sglang_causallms:SGLANG",
"sglang-generate": "lm_eval.models.sglang_generate_API:SGAPI",
"textsynth": "lm_eval.models.textsynth:TextSynthLM",
"vllm": "lm_eval.models.vllm_causallms:VLLM",
"vllm-vlm": "lm_eval.models.vllm_vlms:VLLM_VLM",
}
# Register all models lazily
def _register_all_models():
"""Register all known models lazily in the registry."""
from lm_eval.api.registry import model_registry
for name, path in MODEL_MAPPING.items():
# Only register if not already present (avoids conflicts when modules are imported)
if name not in model_registry:
# Register the lazy placeholder using lazy parameter
model_registry.register(name, lazy=path)
# Call registration on module import
_register_all_models()
__all__ = ["MODEL_MAPPING"]
try:
......
from collections.abc import Generator
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Callable, Generator, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
from peft.peft_model import PeftModel
......
......@@ -3,7 +3,7 @@ import json
import logging
import os
import warnings
from functools import lru_cache
from functools import cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm
......@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise ValueError(error_msg)
@lru_cache(maxsize=None)
@cache
def get_watsonx_credentials() -> Dict[str, str]:
"""
Retrieves Watsonx API credentials from environmental variables.
......
......@@ -42,7 +42,7 @@ try:
if parse_version(version("vllm")) >= parse_version("0.8.3"):
from vllm.entrypoints.chat_utils import resolve_hf_chat_template
except ModuleNotFoundError:
pass
print("njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd")
if TYPE_CHECKING:
pass
......
......@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
......@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
d = DomainParser()(open(self.domain_file, "r").read().lower())
p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
......
......@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
......@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
d = DomainParser()(open(self.domain_file, "r").read().lower())
p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
......
......@@ -121,7 +121,7 @@ lint.fixable = ["I001", "F401", "UP"]
lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117", "E741"]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401", "F402", "F403"]
"__init__.py" = ["F401", "F402", "F403", "F405"]
[tool.ruff.lint.isort]
combine-as-imports = true
......
......@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from tqdm import tqdm
# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registryv2 import ALL_TASKS
eval_logger = logging.getLogger(__name__)
......
#!/usr/bin/env python3
"""Comprehensive tests for the registry system."""
import threading
import pytest
from lm_eval.api.model import LM
from lm_eval.api.registry import (
MetricSpec,
Registry,
get_metric,
metric_agg_registry,
metric_registry,
model_registry,
register_metric,
)
# Import metrics module to ensure decorators are executed
# import lm_eval.api.metrics
class TestBasicRegistry:
"""Test basic registry functionality."""
def test_create_registry(self):
"""Test creating a basic registry."""
reg = Registry("test")
assert len(reg) == 0
assert list(reg) == []
def test_decorator_registration(self):
"""Test decorator-based registration."""
reg = Registry("test")
@reg.register("my_class")
class MyClass:
pass
assert "my_class" in reg
assert reg.get("my_class") == MyClass
assert reg["my_class"] == MyClass
def test_decorator_multiple_aliases(self):
"""Test decorator with multiple aliases."""
reg = Registry("test")
@reg.register("alias1", "alias2", "alias3")
class MyClass:
pass
assert reg.get("alias1") == MyClass
assert reg.get("alias2") == MyClass
assert reg.get("alias3") == MyClass
def test_decorator_auto_name(self):
"""Test decorator using class name when no alias provided."""
reg = Registry("test")
@reg.register()
class AutoNamedClass:
pass
assert reg.get("AutoNamedClass") == AutoNamedClass
def test_lazy_registration(self):
"""Test lazy loading with module paths."""
reg = Registry("test")
# Register with lazy loading
reg.register("join", lazy="os.path:join")
# Check it's stored as a string
assert isinstance(reg._objs["join"], str)
# Access triggers materialization
result = reg.get("join")
import os
assert result == os.path.join
assert callable(result)
def test_direct_registration(self):
"""Test direct object registration."""
reg = Registry("test")
class DirectClass:
pass
obj = DirectClass()
reg.register("direct", lazy=obj)
assert reg.get("direct") == obj
def test_metadata_removed(self):
"""Test that metadata parameter is removed from generic registry."""
reg = Registry("test")
# Should work without metadata parameter
@reg.register("test_class")
class TestClass:
pass
assert "test_class" in reg
assert reg.get("test_class") == TestClass
def test_unknown_key_error(self):
"""Test error when accessing unknown key."""
reg = Registry("test")
with pytest.raises(KeyError) as exc_info:
reg.get("unknown")
assert "Unknown test 'unknown'" in str(exc_info.value)
assert "Available:" in str(exc_info.value)
def test_iteration(self):
"""Test registry iteration."""
reg = Registry("test")
reg.register("a", lazy="os:getcwd")
reg.register("b", lazy="os:getenv")
reg.register("c", lazy="os:getpid")
assert list(reg) == ["a", "b", "c"]
assert len(reg) == 3
# Test items()
items = list(reg.items())
assert len(items) == 3
assert items[0][0] == "a"
assert isinstance(items[0][1], str) # Still lazy
def test_mapping_protocol(self):
"""Test that registry implements mapping protocol."""
reg = Registry("test")
reg.register("test", lazy="os:getcwd")
# __getitem__
assert reg["test"] == reg.get("test")
# __contains__
assert "test" in reg
assert "missing" not in reg
# __iter__ and __len__ tested above
class TestTypeConstraints:
"""Test type checking and base class constraints."""
def test_base_class_constraint(self):
"""Test base class validation."""
# Define a base class
class BaseClass:
pass
class GoodSubclass(BaseClass):
pass
class BadClass:
pass
reg = Registry("typed", base_cls=BaseClass)
# Should work - correct subclass
@reg.register("good")
class GoodInline(BaseClass):
pass
# Should fail - wrong type
with pytest.raises(TypeError) as exc_info:
@reg.register("bad")
class BadInline:
pass
assert "must inherit from" in str(exc_info.value)
def test_lazy_type_check(self):
"""Test that type checking happens on materialization for lazy entries."""
class BaseClass:
pass
reg = Registry("typed", base_cls=BaseClass)
# Register a lazy entry that will fail type check
reg.register("bad_lazy", lazy="os.path:join")
# Should fail when accessed - the error message varies
with pytest.raises(TypeError):
reg.get("bad_lazy")
class TestCollisionHandling:
"""Test registration collision scenarios."""
def test_identical_registration(self):
"""Test that identical re-registration is allowed."""
reg = Registry("test")
class MyClass:
pass
# First registration
reg.register("test", lazy=MyClass)
# Identical re-registration should work
reg.register("test", lazy=MyClass)
assert reg.get("test") == MyClass
def test_different_registration_fails(self):
"""Test that different re-registration fails."""
reg = Registry("test")
class Class1:
pass
class Class2:
pass
reg.register("test", lazy=Class1)
with pytest.raises(ValueError) as exc_info:
reg.register("test", lazy=Class2)
assert "already registered" in str(exc_info.value)
def test_lazy_to_concrete_upgrade(self):
"""Test that lazy placeholder can be upgraded to concrete class."""
reg = Registry("test")
# Register lazy
reg.register("myclass", lazy="test_registry:MyUpgradeClass")
# Define and register concrete - should work
@reg.register("myclass")
class MyUpgradeClass:
pass
assert reg.get("myclass") == MyUpgradeClass
class TestThreadSafety:
"""Test thread safety of registry operations."""
def test_concurrent_access(self):
"""Test concurrent access to lazy entries."""
reg = Registry("test")
# Register lazy entry
reg.register("concurrent", lazy="os.path:join")
results = []
errors = []
def access_item():
try:
result = reg.get("concurrent")
results.append(result)
except Exception as e:
errors.append(str(e))
# Launch threads
threads = []
for _ in range(10):
t = threading.Thread(target=access_item)
threads.append(t)
t.start()
# Wait for completion
for t in threads:
t.join()
# Check results
assert len(errors) == 0
assert len(results) == 10
# All should get the same object
assert all(r == results[0] for r in results)
def test_concurrent_registration(self):
"""Test concurrent registration doesn't cause issues."""
reg = Registry("test")
errors = []
def register_item(name, value):
try:
reg.register(name, lazy=value)
except Exception as e:
errors.append(str(e))
# Launch threads with different registrations
threads = []
for i in range(10):
t = threading.Thread(
target=register_item, args=(f"item_{i}", f"module{i}:Class{i}")
)
threads.append(t)
t.start()
# Wait for completion
for t in threads:
t.join()
# Check results
assert len(errors) == 0
assert len(reg) == 10
class TestMetricRegistry:
"""Test metric-specific registry functionality."""
def test_metric_spec(self):
"""Test MetricSpec dataclass."""
def compute_fn(items):
return [1 for _ in items]
def agg_fn(values):
return sum(values) / len(values)
spec = MetricSpec(
compute=compute_fn,
aggregate=agg_fn,
higher_is_better=True,
output_type="probability",
)
assert spec.compute == compute_fn
assert spec.aggregate == agg_fn
assert spec.higher_is_better
assert spec.output_type == "probability"
def test_register_metric_decorator(self):
"""Test @register_metric decorator."""
# Register aggregation function first
@metric_agg_registry.register("test_mean")
def test_mean(values):
return sum(values) / len(values) if values else 0.0
# Register metric
@register_metric(
metric="test_accuracy",
aggregation="test_mean",
higher_is_better=True,
output_type="accuracy",
)
def compute_accuracy(items):
return [1 if item["pred"] == item["gold"] else 0 for item in items]
# Check registration
assert "test_accuracy" in metric_registry
spec = metric_registry.get("test_accuracy")
assert isinstance(spec, MetricSpec)
assert spec.higher_is_better
assert spec.output_type == "accuracy"
# Test compute function
items = [
{"pred": "a", "gold": "a"},
{"pred": "b", "gold": "b"},
{"pred": "c", "gold": "d"},
]
result = spec.compute(items)
assert result == [1, 1, 0]
# Test aggregation
agg_result = spec.aggregate(result)
assert agg_result == 2 / 3
def test_metric_without_aggregation(self):
"""Test metric registration without aggregation."""
@register_metric(metric="no_agg", higher_is_better=False)
def compute_something(items):
return [len(item) for item in items]
spec = metric_registry.get("no_agg")
# Should raise NotImplementedError when aggregate is called
with pytest.raises(NotImplementedError) as exc_info:
spec.aggregate([1, 2, 3])
assert "No aggregation function specified" in str(exc_info.value)
def test_get_metric_helper(self):
"""Test get_metric helper function."""
@register_metric(
metric="helper_test",
aggregation="mean", # Assuming 'mean' exists in metric_agg_registry
)
def compute_helper(items):
return items
# get_metric returns just the compute function
compute_fn = get_metric("helper_test")
assert callable(compute_fn)
assert compute_fn([1, 2, 3]) == [1, 2, 3]
class TestRegistryUtilities:
"""Test utility methods."""
def test_freeze(self):
"""Test freezing a registry."""
reg = Registry("test")
# Add some items
reg.register("item1", lazy="os:getcwd")
reg.register("item2", lazy="os:getenv")
# Freeze the registry
reg.freeze()
# Should not be able to register new items
with pytest.raises(TypeError):
reg._objs["new"] = "value"
# Should still be able to access items
assert "item1" in reg
assert callable(reg.get("item1"))
def test_clear(self):
"""Test clearing a registry."""
reg = Registry("test")
# Add items
reg.register("item1", lazy="os:getcwd")
reg.register("item2", lazy="os:getenv")
assert len(reg) == 2
# Clear
reg._clear()
assert len(reg) == 0
assert list(reg) == []
def test_origin(self):
"""Test origin tracking."""
reg = Registry("test")
# Lazy entry - no origin
reg.register("lazy", lazy="os:getcwd")
assert reg.origin("lazy") is None
# Concrete class - should have origin
@reg.register("concrete")
class ConcreteClass:
pass
origin = reg.origin("concrete")
assert origin is not None
assert "test_registry.py" in origin
assert ":" in origin # Has line number
class TestBackwardCompatibility:
"""Test backward compatibility features."""
def test_model_registry_alias(self):
"""Test MODEL_REGISTRY backward compatibility."""
from lm_eval.api.registry import MODEL_REGISTRY
# Should be same object as model_registry
assert MODEL_REGISTRY is model_registry
# Should reflect current state
before_count = len(MODEL_REGISTRY)
# Add new model
@model_registry.register("test_model_compat")
class TestModelCompat(LM):
pass
# MODEL_REGISTRY should immediately reflect the change
assert len(MODEL_REGISTRY) == before_count + 1
assert "test_model_compat" in MODEL_REGISTRY
def test_legacy_functions(self):
"""Test legacy helper functions."""
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_model,
register_model,
)
# register_model should work
@register_model("legacy_model")
class LegacyModel(LM):
pass
# get_model should work
assert get_model("legacy_model") == LegacyModel
# Check other aliases
assert DEFAULT_METRIC_REGISTRY is metric_registry
assert AGGREGATION_REGISTRY is metric_agg_registry
class TestEdgeCases:
"""Test edge cases and error conditions."""
def test_invalid_lazy_format(self):
"""Test error on invalid lazy format."""
reg = Registry("test")
reg.register("bad", lazy="no_colon_here")
with pytest.raises(ValueError) as exc_info:
reg.get("bad")
assert "expected 'module:object'" in str(exc_info.value)
def test_lazy_module_not_found(self):
"""Test error when lazy module doesn't exist."""
reg = Registry("test")
reg.register("missing", lazy="nonexistent_module:Class")
with pytest.raises(ModuleNotFoundError):
reg.get("missing")
def test_lazy_attribute_not_found(self):
"""Test error when lazy attribute doesn't exist."""
reg = Registry("test")
reg.register("missing_attr", lazy="os:nonexistent_function")
with pytest.raises(AttributeError):
reg.get("missing_attr")
def test_multiple_aliases_with_lazy(self):
"""Test that multiple aliases with lazy fails."""
reg = Registry("test")
with pytest.raises(ValueError) as exc_info:
reg.register("alias1", "alias2", lazy="os:getcwd")
assert "Exactly one alias required" in str(exc_info.value)
if __name__ == "__main__":
pytest.main([__file__, "-v"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment