Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
2b32f7be
Commit
2b32f7be
authored
Jul 28, 2025
by
Baber
Browse files
add tests
parent
124d3049
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
569 additions
and
5 deletions
+569
-5
lm_eval/__init__.py
lm_eval/__init__.py
+3
-0
lm_eval/filters/__init__.py
lm_eval/filters/__init__.py
+11
-3
pyproject.toml
pyproject.toml
+2
-2
test_registry.py
test_registry.py
+553
-0
No files found.
lm_eval/__init__.py
View file @
2b32f7be
import
logging
import
os
from
.api
import
metrics
,
registry
# initializes the registries
from
.filters
import
*
__version__
=
"0.4.9.1"
...
...
lm_eval/filters/__init__.py
View file @
2b32f7be
from
functools
import
partial
from
typing
import
List
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.registry
import
get_filter
from
lm_eval.api.registry
import
filter_registry
,
get_filter
from
.
import
custom
,
extraction
,
selection
,
transformation
def
build_filter_ensemble
(
filter_name
:
str
,
components
:
L
ist
[
L
ist
[
str
]]
filter_name
:
str
,
components
:
l
ist
[
l
ist
[
str
]]
)
->
FilterEnsemble
:
"""
Create a filtering pipeline.
...
...
@@ -23,3 +22,12 @@ def build_filter_ensemble(
filters
.
append
(
f
)
return
FilterEnsemble
(
name
=
filter_name
,
filters
=
filters
)
__all__
=
[
"custom"
,
"extraction"
,
"selection"
,
"transformation"
,
"build_filter_ensemble"
,
]
pyproject.toml
View file @
2b32f7be
...
...
@@ -108,14 +108,14 @@ plugins.md029.allow_extended_start_values = true # ol-prefix
plugins.md034.enabled
=
false
# no-bare-urls
[tool.ruff.lint]
extend-select
=
[
"I"
,
"W605"
]
extend-select
=
[
"I"
,
"W605"
,
"UP"
]
[tool.ruff.lint.isort]
lines-after-imports
=
2
known-first-party
=
["lm_eval"]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py"
=
["F401","F402","F403"]
"__init__.py"
=
["F401","F402","F403"
,"F405"
]
"utils.py"
=
["F401"]
[dependency-groups]
...
...
test_registry.py
0 → 100644
View file @
2b32f7be
#!/usr/bin/env python3
"""Comprehensive tests for the registry system."""
import
threading
import
pytest
from
lm_eval.api.model
import
LM
from
lm_eval.api.registry
import
(
MetricSpec
,
Registry
,
get_metric
,
metric_agg_registry
,
metric_registry
,
model_registry
,
register_metric
,
)
# Import metrics module to ensure decorators are executed
# import lm_eval.api.metrics
class
TestBasicRegistry
:
"""Test basic registry functionality."""
def
test_create_registry
(
self
):
"""Test creating a basic registry."""
reg
=
Registry
(
"test"
)
assert
len
(
reg
)
==
0
assert
list
(
reg
)
==
[]
def
test_decorator_registration
(
self
):
"""Test decorator-based registration."""
reg
=
Registry
(
"test"
)
@
reg
.
register
(
"my_class"
)
class
MyClass
:
pass
assert
"my_class"
in
reg
assert
reg
.
get
(
"my_class"
)
==
MyClass
assert
reg
[
"my_class"
]
==
MyClass
def
test_decorator_multiple_aliases
(
self
):
"""Test decorator with multiple aliases."""
reg
=
Registry
(
"test"
)
@
reg
.
register
(
"alias1"
,
"alias2"
,
"alias3"
)
class
MyClass
:
pass
assert
reg
.
get
(
"alias1"
)
==
MyClass
assert
reg
.
get
(
"alias2"
)
==
MyClass
assert
reg
.
get
(
"alias3"
)
==
MyClass
def
test_decorator_auto_name
(
self
):
"""Test decorator using class name when no alias provided."""
reg
=
Registry
(
"test"
)
@
reg
.
register
()
class
AutoNamedClass
:
pass
assert
reg
.
get
(
"AutoNamedClass"
)
==
AutoNamedClass
def
test_lazy_registration
(
self
):
"""Test lazy loading with module paths."""
reg
=
Registry
(
"test"
)
# Register with lazy loading
reg
.
register
(
"join"
,
lazy
=
"os.path:join"
)
# Check it's stored as a string
assert
isinstance
(
reg
.
_objs
[
"join"
],
str
)
# Access triggers materialization
result
=
reg
.
get
(
"join"
)
import
os
assert
result
==
os
.
path
.
join
assert
callable
(
result
)
def
test_direct_registration
(
self
):
"""Test direct object registration."""
reg
=
Registry
(
"test"
)
class
DirectClass
:
pass
obj
=
DirectClass
()
reg
.
register
(
"direct"
,
lazy
=
obj
)
assert
reg
.
get
(
"direct"
)
==
obj
def
test_metadata_removed
(
self
):
"""Test that metadata parameter is removed from generic registry."""
reg
=
Registry
(
"test"
)
# Should work without metadata parameter
@
reg
.
register
(
"test_class"
)
class
TestClass
:
pass
assert
"test_class"
in
reg
assert
reg
.
get
(
"test_class"
)
==
TestClass
def
test_unknown_key_error
(
self
):
"""Test error when accessing unknown key."""
reg
=
Registry
(
"test"
)
with
pytest
.
raises
(
KeyError
)
as
exc_info
:
reg
.
get
(
"unknown"
)
assert
"Unknown test 'unknown'"
in
str
(
exc_info
.
value
)
assert
"Available:"
in
str
(
exc_info
.
value
)
def
test_iteration
(
self
):
"""Test registry iteration."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"a"
,
lazy
=
"os:getcwd"
)
reg
.
register
(
"b"
,
lazy
=
"os:getenv"
)
reg
.
register
(
"c"
,
lazy
=
"os:getpid"
)
assert
list
(
reg
)
==
[
"a"
,
"b"
,
"c"
]
assert
len
(
reg
)
==
3
# Test items()
items
=
list
(
reg
.
items
())
assert
len
(
items
)
==
3
assert
items
[
0
][
0
]
==
"a"
assert
isinstance
(
items
[
0
][
1
],
str
)
# Still lazy
def
test_mapping_protocol
(
self
):
"""Test that registry implements mapping protocol."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"test"
,
lazy
=
"os:getcwd"
)
# __getitem__
assert
reg
[
"test"
]
==
reg
.
get
(
"test"
)
# __contains__
assert
"test"
in
reg
assert
"missing"
not
in
reg
# __iter__ and __len__ tested above
class
TestTypeConstraints
:
"""Test type checking and base class constraints."""
def
test_base_class_constraint
(
self
):
"""Test base class validation."""
# Define a base class
class
BaseClass
:
pass
class
GoodSubclass
(
BaseClass
):
pass
class
BadClass
:
pass
reg
=
Registry
(
"typed"
,
base_cls
=
BaseClass
)
# Should work - correct subclass
@
reg
.
register
(
"good"
)
class
GoodInline
(
BaseClass
):
pass
# Should fail - wrong type
with
pytest
.
raises
(
TypeError
)
as
exc_info
:
@
reg
.
register
(
"bad"
)
class
BadInline
:
pass
assert
"must inherit from"
in
str
(
exc_info
.
value
)
def
test_lazy_type_check
(
self
):
"""Test that type checking happens on materialization for lazy entries."""
class
BaseClass
:
pass
reg
=
Registry
(
"typed"
,
base_cls
=
BaseClass
)
# Register a lazy entry that will fail type check
reg
.
register
(
"bad_lazy"
,
lazy
=
"os.path:join"
)
# Should fail when accessed - the error message varies
with
pytest
.
raises
(
TypeError
):
reg
.
get
(
"bad_lazy"
)
class
TestCollisionHandling
:
"""Test registration collision scenarios."""
def
test_identical_registration
(
self
):
"""Test that identical re-registration is allowed."""
reg
=
Registry
(
"test"
)
class
MyClass
:
pass
# First registration
reg
.
register
(
"test"
,
lazy
=
MyClass
)
# Identical re-registration should work
reg
.
register
(
"test"
,
lazy
=
MyClass
)
assert
reg
.
get
(
"test"
)
==
MyClass
def
test_different_registration_fails
(
self
):
"""Test that different re-registration fails."""
reg
=
Registry
(
"test"
)
class
Class1
:
pass
class
Class2
:
pass
reg
.
register
(
"test"
,
lazy
=
Class1
)
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
reg
.
register
(
"test"
,
lazy
=
Class2
)
assert
"already registered"
in
str
(
exc_info
.
value
)
def
test_lazy_to_concrete_upgrade
(
self
):
"""Test that lazy placeholder can be upgraded to concrete class."""
reg
=
Registry
(
"test"
)
# Register lazy
reg
.
register
(
"myclass"
,
lazy
=
"test_registry:MyUpgradeClass"
)
# Define and register concrete - should work
@
reg
.
register
(
"myclass"
)
class
MyUpgradeClass
:
pass
assert
reg
.
get
(
"myclass"
)
==
MyUpgradeClass
class
TestThreadSafety
:
"""Test thread safety of registry operations."""
def
test_concurrent_access
(
self
):
"""Test concurrent access to lazy entries."""
reg
=
Registry
(
"test"
)
# Register lazy entry
reg
.
register
(
"concurrent"
,
lazy
=
"os.path:join"
)
results
=
[]
errors
=
[]
def
access_item
():
try
:
result
=
reg
.
get
(
"concurrent"
)
results
.
append
(
result
)
except
Exception
as
e
:
errors
.
append
(
str
(
e
))
# Launch threads
threads
=
[]
for
_
in
range
(
10
):
t
=
threading
.
Thread
(
target
=
access_item
)
threads
.
append
(
t
)
t
.
start
()
# Wait for completion
for
t
in
threads
:
t
.
join
()
# Check results
assert
len
(
errors
)
==
0
assert
len
(
results
)
==
10
# All should get the same object
assert
all
(
r
==
results
[
0
]
for
r
in
results
)
def
test_concurrent_registration
(
self
):
"""Test concurrent registration doesn't cause issues."""
reg
=
Registry
(
"test"
)
errors
=
[]
def
register_item
(
name
,
value
):
try
:
reg
.
register
(
name
,
lazy
=
value
)
except
Exception
as
e
:
errors
.
append
(
str
(
e
))
# Launch threads with different registrations
threads
=
[]
for
i
in
range
(
10
):
t
=
threading
.
Thread
(
target
=
register_item
,
args
=
(
f
"item_
{
i
}
"
,
f
"module
{
i
}
:Class
{
i
}
"
)
)
threads
.
append
(
t
)
t
.
start
()
# Wait for completion
for
t
in
threads
:
t
.
join
()
# Check results
assert
len
(
errors
)
==
0
assert
len
(
reg
)
==
10
class
TestMetricRegistry
:
"""Test metric-specific registry functionality."""
def
test_metric_spec
(
self
):
"""Test MetricSpec dataclass."""
def
compute_fn
(
items
):
return
[
1
for
_
in
items
]
def
agg_fn
(
values
):
return
sum
(
values
)
/
len
(
values
)
spec
=
MetricSpec
(
compute
=
compute_fn
,
aggregate
=
agg_fn
,
higher_is_better
=
True
,
output_type
=
"probability"
,
)
assert
spec
.
compute
==
compute_fn
assert
spec
.
aggregate
==
agg_fn
assert
spec
.
higher_is_better
assert
spec
.
output_type
==
"probability"
def
test_register_metric_decorator
(
self
):
"""Test @register_metric decorator."""
# Register aggregation function first
@
metric_agg_registry
.
register
(
"test_mean"
)
def
test_mean
(
values
):
return
sum
(
values
)
/
len
(
values
)
if
values
else
0.0
# Register metric
@
register_metric
(
metric
=
"test_accuracy"
,
aggregation
=
"test_mean"
,
higher_is_better
=
True
,
output_type
=
"accuracy"
,
)
def
compute_accuracy
(
items
):
return
[
1
if
item
[
"pred"
]
==
item
[
"gold"
]
else
0
for
item
in
items
]
# Check registration
assert
"test_accuracy"
in
metric_registry
spec
=
metric_registry
.
get
(
"test_accuracy"
)
assert
isinstance
(
spec
,
MetricSpec
)
assert
spec
.
higher_is_better
assert
spec
.
output_type
==
"accuracy"
# Test compute function
items
=
[
{
"pred"
:
"a"
,
"gold"
:
"a"
},
{
"pred"
:
"b"
,
"gold"
:
"b"
},
{
"pred"
:
"c"
,
"gold"
:
"d"
},
]
result
=
spec
.
compute
(
items
)
assert
result
==
[
1
,
1
,
0
]
# Test aggregation
agg_result
=
spec
.
aggregate
(
result
)
assert
agg_result
==
2
/
3
def
test_metric_without_aggregation
(
self
):
"""Test metric registration without aggregation."""
@
register_metric
(
metric
=
"no_agg"
,
higher_is_better
=
False
)
def
compute_something
(
items
):
return
[
len
(
item
)
for
item
in
items
]
spec
=
metric_registry
.
get
(
"no_agg"
)
# Should raise NotImplementedError when aggregate is called
with
pytest
.
raises
(
NotImplementedError
)
as
exc_info
:
spec
.
aggregate
([
1
,
2
,
3
])
assert
"No aggregation function specified"
in
str
(
exc_info
.
value
)
def
test_get_metric_helper
(
self
):
"""Test get_metric helper function."""
@
register_metric
(
metric
=
"helper_test"
,
aggregation
=
"mean"
,
# Assuming 'mean' exists in metric_agg_registry
)
def
compute_helper
(
items
):
return
items
# get_metric returns just the compute function
compute_fn
=
get_metric
(
"helper_test"
)
assert
callable
(
compute_fn
)
assert
compute_fn
([
1
,
2
,
3
])
==
[
1
,
2
,
3
]
class
TestRegistryUtilities
:
"""Test utility methods."""
def
test_freeze
(
self
):
"""Test freezing a registry."""
reg
=
Registry
(
"test"
)
# Add some items
reg
.
register
(
"item1"
,
lazy
=
"os:getcwd"
)
reg
.
register
(
"item2"
,
lazy
=
"os:getenv"
)
# Freeze the registry
reg
.
freeze
()
# Should not be able to register new items
with
pytest
.
raises
(
TypeError
):
reg
.
_objs
[
"new"
]
=
"value"
# Should still be able to access items
assert
"item1"
in
reg
assert
callable
(
reg
.
get
(
"item1"
))
def
test_clear
(
self
):
"""Test clearing a registry."""
reg
=
Registry
(
"test"
)
# Add items
reg
.
register
(
"item1"
,
lazy
=
"os:getcwd"
)
reg
.
register
(
"item2"
,
lazy
=
"os:getenv"
)
assert
len
(
reg
)
==
2
# Clear
reg
.
_clear
()
assert
len
(
reg
)
==
0
assert
list
(
reg
)
==
[]
def
test_origin
(
self
):
"""Test origin tracking."""
reg
=
Registry
(
"test"
)
# Lazy entry - no origin
reg
.
register
(
"lazy"
,
lazy
=
"os:getcwd"
)
assert
reg
.
origin
(
"lazy"
)
is
None
# Concrete class - should have origin
@
reg
.
register
(
"concrete"
)
class
ConcreteClass
:
pass
origin
=
reg
.
origin
(
"concrete"
)
assert
origin
is
not
None
assert
"test_registry.py"
in
origin
assert
":"
in
origin
# Has line number
class
TestBackwardCompatibility
:
"""Test backward compatibility features."""
def
test_model_registry_alias
(
self
):
"""Test MODEL_REGISTRY backward compatibility."""
from
lm_eval.api.registry
import
MODEL_REGISTRY
# Should be same object as model_registry
assert
MODEL_REGISTRY
is
model_registry
# Should reflect current state
before_count
=
len
(
MODEL_REGISTRY
)
# Add new model
@
model_registry
.
register
(
"test_model_compat"
)
class
TestModelCompat
(
LM
):
pass
# MODEL_REGISTRY should immediately reflect the change
assert
len
(
MODEL_REGISTRY
)
==
before_count
+
1
assert
"test_model_compat"
in
MODEL_REGISTRY
def
test_legacy_functions
(
self
):
"""Test legacy helper functions."""
from
lm_eval.api.registry
import
(
AGGREGATION_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
get_model
,
register_model
,
)
# register_model should work
@
register_model
(
"legacy_model"
)
class
LegacyModel
(
LM
):
pass
# get_model should work
assert
get_model
(
"legacy_model"
)
==
LegacyModel
# Check other aliases
assert
DEFAULT_METRIC_REGISTRY
is
metric_registry
assert
AGGREGATION_REGISTRY
is
metric_agg_registry
class
TestEdgeCases
:
"""Test edge cases and error conditions."""
def
test_invalid_lazy_format
(
self
):
"""Test error on invalid lazy format."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"bad"
,
lazy
=
"no_colon_here"
)
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
reg
.
get
(
"bad"
)
assert
"expected 'module:object'"
in
str
(
exc_info
.
value
)
def
test_lazy_module_not_found
(
self
):
"""Test error when lazy module doesn't exist."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"missing"
,
lazy
=
"nonexistent_module:Class"
)
with
pytest
.
raises
(
ModuleNotFoundError
):
reg
.
get
(
"missing"
)
def
test_lazy_attribute_not_found
(
self
):
"""Test error when lazy attribute doesn't exist."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"missing_attr"
,
lazy
=
"os:nonexistent_function"
)
with
pytest
.
raises
(
AttributeError
):
reg
.
get
(
"missing_attr"
)
def
test_multiple_aliases_with_lazy
(
self
):
"""Test that multiple aliases with lazy fails."""
reg
=
Registry
(
"test"
)
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
reg
.
register
(
"alias1"
,
"alias2"
,
lazy
=
"os:getcwd"
)
assert
"Exactly one alias required"
in
str
(
exc_info
.
value
)
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
,
"-v"
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment