"vscode:/vscode.git/clone" did not exist on "6c68b56b4678a246d7abea544c2c53c4ecb15814"
Commit 2b32f7be authored by Baber
Browse files

add tests

parent 124d3049
import logging import logging
import os import os
from .api import metrics, registry # initializes the registries
from .filters import *
__version__ = "0.4.9.1" __version__ = "0.4.9.1"
......
from functools import partial from functools import partial
from typing import List
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter from lm_eval.api.registry import filter_registry, get_filter
from . import custom, extraction, selection, transformation from . import custom, extraction, selection, transformation
def build_filter_ensemble( def build_filter_ensemble(
filter_name: str, components: List[List[str]] filter_name: str, components: list[list[str]]
) -> FilterEnsemble: ) -> FilterEnsemble:
""" """
Create a filtering pipeline. Create a filtering pipeline.
...@@ -23,3 +22,12 @@ def build_filter_ensemble( ...@@ -23,3 +22,12 @@ def build_filter_ensemble(
filters.append(f) filters.append(f)
return FilterEnsemble(name=filter_name, filters=filters) return FilterEnsemble(name=filter_name, filters=filters)
__all__ = [
"custom",
"extraction",
"selection",
"transformation",
"build_filter_ensemble",
]
...@@ -108,14 +108,14 @@ plugins.md029.allow_extended_start_values = true # ol-prefix ...@@ -108,14 +108,14 @@ plugins.md029.allow_extended_start_values = true # ol-prefix
plugins.md034.enabled = false # no-bare-urls plugins.md034.enabled = false # no-bare-urls
[tool.ruff.lint] [tool.ruff.lint]
extend-select = ["I", "W605"] extend-select = ["I", "W605", "UP"]
[tool.ruff.lint.isort] [tool.ruff.lint.isort]
lines-after-imports = 2 lines-after-imports = 2
known-first-party = ["lm_eval"] known-first-party = ["lm_eval"]
[tool.ruff.lint.extend-per-file-ignores] [tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401","F402","F403"] "__init__.py" = ["F401","F402","F403","F405"]
"utils.py" = ["F401"] "utils.py" = ["F401"]
[dependency-groups] [dependency-groups]
......
#!/usr/bin/env python3
"""Comprehensive tests for the registry system."""
import threading
import pytest
from lm_eval.api.model import LM
from lm_eval.api.registry import (
MetricSpec,
Registry,
get_metric,
metric_agg_registry,
metric_registry,
model_registry,
register_metric,
)
# Import metrics module to ensure decorators are executed
# import lm_eval.api.metrics
class TestBasicRegistry:
    """Exercise the core behaviour of a standalone ``Registry`` instance."""

    def test_create_registry(self):
        """A freshly constructed registry contains no entries."""
        registry = Registry("test")

        assert len(registry) == 0
        assert list(registry) == []

    def test_decorator_registration(self):
        """``@register(name)`` stores the decorated class under ``name``."""
        registry = Registry("test")

        @registry.register("my_class")
        class MyClass:
            pass

        assert "my_class" in registry
        # Both lookup spellings must hand back the decorated class.
        assert registry.get("my_class") == MyClass
        assert registry["my_class"] == MyClass

    def test_decorator_multiple_aliases(self):
        """One decorator call may register the same class under several names."""
        registry = Registry("test")
        aliases = ("alias1", "alias2", "alias3")

        @registry.register(*aliases)
        class MyClass:
            pass

        for alias in aliases:
            assert registry.get(alias) == MyClass

    def test_decorator_auto_name(self):
        """With no alias given, the class is looked up by its own name."""
        registry = Registry("test")

        @registry.register()
        class AutoNamedClass:
            pass

        assert registry.get("AutoNamedClass") == AutoNamedClass

    def test_lazy_registration(self):
        """A ``"module:attr"`` string is kept as-is until the entry is fetched."""
        import os

        registry = Registry("test")
        registry.register("join", lazy="os.path:join")

        # Before any lookup the raw placeholder string is what is stored.
        assert isinstance(registry._objs["join"], str)

        # ``get`` resolves the placeholder to the real object.
        resolved = registry.get("join")
        assert resolved == os.path.join
        assert callable(resolved)

    def test_direct_registration(self):
        """An already-built object can be passed through the ``lazy`` argument."""
        registry = Registry("test")

        class DirectClass:
            pass

        instance = DirectClass()
        registry.register("direct", lazy=instance)

        assert registry.get("direct") == instance

    def test_metadata_removed(self):
        """Registration works when no metadata argument is supplied."""
        registry = Registry("test")

        @registry.register("test_class")
        class TestClass:
            pass

        assert "test_class" in registry
        assert registry.get("test_class") == TestClass

    def test_unknown_key_error(self):
        """Looking up a missing key raises ``KeyError`` with a helpful message."""
        registry = Registry("test")

        with pytest.raises(KeyError) as exc_info:
            registry.get("unknown")

        message = str(exc_info.value)
        assert "Unknown test 'unknown'" in message
        assert "Available:" in message

    def test_iteration(self):
        """Iteration yields keys in registration order; ``items()`` stays lazy."""
        registry = Registry("test")
        for key, target in (("a", "os:getcwd"), ("b", "os:getenv"), ("c", "os:getpid")):
            registry.register(key, lazy=target)

        assert list(registry) == ["a", "b", "c"]
        assert len(registry) == 3

        pairs = list(registry.items())
        assert len(pairs) == 3
        first_key, first_value = pairs[0][0], pairs[0][1]
        assert first_key == "a"
        # ``items()`` exposes the stored placeholder, not the resolved object.
        assert isinstance(first_value, str)

    def test_mapping_protocol(self):
        """Subscript access and ``in`` behave consistently with ``get``."""
        registry = Registry("test")
        registry.register("test", lazy="os:getcwd")

        # __getitem__
        assert registry["test"] == registry.get("test")

        # __contains__
        assert "test" in registry
        assert "missing" not in registry

        # __iter__ and __len__ are covered by test_iteration.
class TestTypeConstraints:
    """Test type checking and base class constraints."""

    def test_base_class_constraint(self):
        """A registry built with ``base_cls`` accepts subclasses and rejects others."""

        class BaseClass:
            pass

        reg = Registry("typed", base_cls=BaseClass)

        # Should work - correct subclass
        @reg.register("good")
        class GoodInline(BaseClass):
            pass

        # Verify the accepted registration actually landed in the registry.
        assert "good" in reg

        # Should fail - the class does not derive from BaseClass
        with pytest.raises(TypeError) as exc_info:

            @reg.register("bad")
            class BadInline:
                pass

        assert "must inherit from" in str(exc_info.value)

    def test_lazy_type_check(self):
        """For lazy entries the type check is deferred until the entry is fetched."""

        class BaseClass:
            pass

        reg = Registry("typed", base_cls=BaseClass)

        # Registering the placeholder succeeds even though os.path.join is not
        # a BaseClass subclass; nothing has been imported or checked yet.
        reg.register("bad_lazy", lazy="os.path:join")

        # Should fail when accessed - the error message varies
        with pytest.raises(TypeError):
            reg.get("bad_lazy")
class TestCollisionHandling:
    """Cover what happens when the same alias is registered more than once."""

    def test_identical_registration(self):
        """Registering the very same object twice under one alias is accepted."""
        registry = Registry("test")

        class MyClass:
            pass

        # Two registrations of the identical object under the same alias.
        for _ in range(2):
            registry.register("test", lazy=MyClass)

        assert registry.get("test") == MyClass

    def test_different_registration_fails(self):
        """A second, different object under an existing alias raises ``ValueError``."""
        registry = Registry("test")

        class Class1:
            pass

        class Class2:
            pass

        registry.register("test", lazy=Class1)

        with pytest.raises(ValueError) as exc_info:
            registry.register("test", lazy=Class2)

        message = str(exc_info.value)
        assert "already registered" in message

    def test_lazy_to_concrete_upgrade(self):
        """A lazy placeholder may be replaced by a concrete class for the same alias."""
        registry = Registry("test")

        # Start with a string placeholder for the alias ...
        registry.register("myclass", lazy="test_registry:MyUpgradeClass")

        # ... then register a concrete class under that alias.
        @registry.register("myclass")
        class MyUpgradeClass:
            pass

        assert registry.get("myclass") == MyUpgradeClass
class TestThreadSafety:
    """Test thread safety of registry operations."""

    @staticmethod
    def _run_in_lockstep(worker, arg_sets):
        """Run ``worker(*args)`` once per entry of ``arg_sets``, one thread each.

        Every thread first waits on a shared barrier, so the ``worker`` calls
        begin together instead of one after another as the threads are
        started. Without this, early threads can finish before later ones are
        even created and the calls never overlap.

        Args:
            worker: Callable invoked in each thread.
            arg_sets: Sequence of argument tuples; its length is the thread count.

        Returns:
            The list of exceptions raised inside the threads (including a
            broken or timed-out barrier). Empty when every call succeeded.
        """
        barrier = threading.Barrier(len(arg_sets))
        failures = []

        def runner(*args):
            try:
                # Timeout so a thread that never arrives cannot hang the test.
                barrier.wait(timeout=10)
                worker(*args)
            except Exception as exc:  # collected here, asserted on by the caller
                failures.append(exc)

        threads = [threading.Thread(target=runner, args=args) for args in arg_sets]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        return failures

    def test_concurrent_access(self):
        """Test concurrent access to lazy entries."""
        reg = Registry("test")

        # Register lazy entry
        reg.register("concurrent", lazy="os.path:join")

        results = []

        def access_item():
            results.append(reg.get("concurrent"))

        # Ten threads fetch the same lazy entry at the same time.
        failures = self._run_in_lockstep(access_item, [()] * 10)

        # Comparing against [] makes pytest print the exceptions on failure.
        assert failures == []
        assert len(results) == 10
        # All should get the same object
        assert all(r == results[0] for r in results)

    def test_concurrent_registration(self):
        """Test concurrent registration doesn't cause issues."""
        reg = Registry("test")

        def register_item(name, value):
            reg.register(name, lazy=value)

        # Ten threads each register a distinct alias at the same time.
        registrations = [(f"item_{i}", f"module{i}:Class{i}") for i in range(10)]
        failures = self._run_in_lockstep(register_item, registrations)

        assert failures == []
        assert len(reg) == 10
class TestMetricRegistry:
    """Test metric-specific registry functionality.

    Note: these tests register entries in the module-level ``metric_registry``
    and ``metric_agg_registry``, so the names they use must stay unique.
    """

    def test_metric_spec(self):
        """Test that MetricSpec stores the values it is constructed with."""

        def compute_fn(items):
            return [1 for _ in items]

        def agg_fn(values):
            return sum(values) / len(values)

        spec = MetricSpec(
            compute=compute_fn,
            aggregate=agg_fn,
            higher_is_better=True,
            output_type="probability",
        )

        assert spec.compute == compute_fn
        assert spec.aggregate == agg_fn
        assert spec.higher_is_better
        assert spec.output_type == "probability"

    def test_register_metric_decorator(self):
        """Test @register_metric decorator."""

        # Register aggregation function first
        @metric_agg_registry.register("test_mean")
        def test_mean(values):
            return sum(values) / len(values) if values else 0.0

        # Register metric
        @register_metric(
            metric="test_accuracy",
            aggregation="test_mean",
            higher_is_better=True,
            output_type="accuracy",
        )
        def compute_accuracy(items):
            return [1 if item["pred"] == item["gold"] else 0 for item in items]

        # Check registration
        assert "test_accuracy" in metric_registry
        spec = metric_registry.get("test_accuracy")
        assert isinstance(spec, MetricSpec)
        assert spec.higher_is_better
        assert spec.output_type == "accuracy"

        # Test compute function
        items = [
            {"pred": "a", "gold": "a"},
            {"pred": "b", "gold": "b"},
            {"pred": "c", "gold": "d"},
        ]
        result = spec.compute(items)
        assert result == [1, 1, 0]

        # Test aggregation. Compare with a tolerance: the expected value is a
        # non-terminating binary fraction, so exact float equality is brittle.
        agg_result = spec.aggregate(result)
        assert agg_result == pytest.approx(2 / 3)

    def test_metric_without_aggregation(self):
        """Test metric registration without aggregation."""

        @register_metric(metric="no_agg", higher_is_better=False)
        def compute_something(items):
            return [len(item) for item in items]

        spec = metric_registry.get("no_agg")

        # Should raise NotImplementedError when aggregate is called
        with pytest.raises(NotImplementedError) as exc_info:
            spec.aggregate([1, 2, 3])

        assert "No aggregation function specified" in str(exc_info.value)

    def test_get_metric_helper(self):
        """Test get_metric helper function."""

        @register_metric(
            metric="helper_test",
            aggregation="mean",  # Assuming 'mean' exists in metric_agg_registry
        )
        def compute_helper(items):
            return items

        # get_metric returns just the compute function
        compute_fn = get_metric("helper_test")
        assert callable(compute_fn)
        assert compute_fn([1, 2, 3]) == [1, 2, 3]
class TestRegistryUtilities:
    """Cover the auxiliary ``freeze``, ``_clear`` and ``origin`` helpers."""

    def test_freeze(self):
        """After ``freeze()`` the backing mapping rejects writes but still reads."""
        registry = Registry("test")
        registry.register("item1", lazy="os:getcwd")
        registry.register("item2", lazy="os:getenv")

        registry.freeze()

        # Writing straight into the underlying mapping must now be refused.
        with pytest.raises(TypeError):
            registry._objs["new"] = "value"

        # Existing entries remain reachable and resolvable.
        assert "item1" in registry
        resolved = registry.get("item1")
        assert callable(resolved)

    def test_clear(self):
        """``_clear()`` drops every entry from the registry."""
        registry = Registry("test")
        for key, target in (("item1", "os:getcwd"), ("item2", "os:getenv")):
            registry.register(key, lazy=target)
        assert len(registry) == 2

        registry._clear()

        assert len(registry) == 0
        assert list(registry) == []

    def test_origin(self):
        """``origin`` is ``None`` for lazy entries and ``file:line``-like for classes."""
        registry = Registry("test")

        # A string placeholder carries no origin information.
        registry.register("lazy", lazy="os:getcwd")
        assert registry.origin("lazy") is None

        # A class registered through the decorator does.
        @registry.register("concrete")
        class ConcreteClass:
            pass

        location = registry.origin("concrete")
        assert location is not None
        assert "test_registry.py" in location
        assert ":" in location  # separator before the line number
class TestBackwardCompatibility:
    """Check that the legacy names exported by the registry module still work."""

    def test_model_registry_alias(self):
        """``MODEL_REGISTRY`` is the same object as ``model_registry``."""
        from lm_eval.api.registry import MODEL_REGISTRY

        assert MODEL_REGISTRY is model_registry

        # Remember the size so the effect of one new registration can be seen.
        size_before = len(MODEL_REGISTRY)

        @model_registry.register("test_model_compat")
        class TestModelCompat(LM):
            pass

        # The alias must see the new entry straight away.
        assert len(MODEL_REGISTRY) == size_before + 1
        assert "test_model_compat" in MODEL_REGISTRY

    def test_legacy_functions(self):
        """``register_model`` / ``get_model`` and the registry aliases are usable."""
        from lm_eval.api.registry import (
            AGGREGATION_REGISTRY,
            DEFAULT_METRIC_REGISTRY,
            get_model,
            register_model,
        )

        # Round-trip a model class through the legacy decorator and getter.
        @register_model("legacy_model")
        class LegacyModel(LM):
            pass

        fetched = get_model("legacy_model")
        assert fetched == LegacyModel

        # The remaining legacy names are aliases of the current registries.
        assert DEFAULT_METRIC_REGISTRY is metric_registry
        assert AGGREGATION_REGISTRY is metric_agg_registry
class TestEdgeCases:
    """Cover malformed input and failure paths of ``Registry``."""

    def test_invalid_lazy_format(self):
        """A lazy string without a colon is rejected when the entry is fetched."""
        registry = Registry("test")
        registry.register("bad", lazy="no_colon_here")

        with pytest.raises(ValueError) as exc_info:
            registry.get("bad")

        message = str(exc_info.value)
        assert "expected 'module:object'" in message

    def test_lazy_module_not_found(self):
        """Fetching a lazy entry whose module cannot be imported raises."""
        registry = Registry("test")
        registry.register("missing", lazy="nonexistent_module:Class")

        with pytest.raises(ModuleNotFoundError):
            registry.get("missing")

    def test_lazy_attribute_not_found(self):
        """Fetching a lazy entry whose attribute is absent from the module raises."""
        registry = Registry("test")
        registry.register("missing_attr", lazy="os:nonexistent_function")

        with pytest.raises(AttributeError):
            registry.get("missing_attr")

    def test_multiple_aliases_with_lazy(self):
        """Passing ``lazy`` together with more than one alias is an error."""
        registry = Registry("test")

        with pytest.raises(ValueError) as exc_info:
            registry.register("alias1", "alias2", lazy="os:getcwd")

        message = str(exc_info.value)
        assert "Exactly one alias required" in message
if __name__ == "__main__":
    # pytest.main returns the run's exit code; raise it as the process exit
    # status so a failing run is visible to the shell / CI when this file is
    # executed directly.
    raise SystemExit(pytest.main([__file__, "-v"]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment