Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
70314843
Unverified
Commit
70314843
authored
Sep 26, 2025
by
Baber Abbasi
Committed by
GitHub
Sep 26, 2025
Browse files
Merge pull request #3189 from EleutherAI/lazy_reg
refactor registry
parents
73202a2e
930b4253
Changes
12
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1165 additions
and
210 deletions
+1165
-210
lm_eval/__init__.py
lm_eval/__init__.py
+4
-0
lm_eval/api/registry.py
lm_eval/api/registry.py
+532
-170
lm_eval/filters/__init__.py
lm_eval/filters/__init__.py
+12
-3
lm_eval/models/__init__.py
lm_eval/models/__init__.py
+51
-25
lm_eval/models/hf_steered.py
lm_eval/models/hf_steered.py
+2
-1
lm_eval/models/ibm_watsonx_ai.py
lm_eval/models/ibm_watsonx_ai.py
+2
-2
lm_eval/models/vllm_causallms.py
lm_eval/models/vllm_causallms.py
+1
-1
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
+3
-3
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
+3
-3
pyproject.toml
pyproject.toml
+1
-1
scripts/build_benchmark.py
scripts/build_benchmark.py
+1
-1
test_registry.py
test_registry.py
+553
-0
No files found.
lm_eval/__init__.py
View file @
70314843
from
.api
import
metrics
,
model
,
registry
# initializes the registries
from
.filters
import
*
__version__
=
"0.4.9.1"
__version__
=
"0.4.9.1"
...
...
lm_eval/api/registry.py
View file @
70314843
This diff is collapsed.
Click to expand it.
lm_eval/filters/__init__.py
View file @
70314843
from
__future__
import
annotations
from
functools
import
partial
from
functools
import
partial
from
typing
import
Optional
,
Union
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.registry
import
get_filter
from
lm_eval.api.registry
import
filter_registry
,
get_filter
from
.
import
custom
,
extraction
,
selection
,
transformation
from
.
import
custom
,
extraction
,
selection
,
transformation
def
build_filter_ensemble
(
def
build_filter_ensemble
(
filter_name
:
str
,
filter_name
:
str
,
components
:
list
[
tuple
[
str
,
Optional
[
dict
[
str
,
Union
[
str
,
int
,
float
]
]]
]],
components
:
list
[
tuple
[
str
,
dict
[
str
,
str
|
int
|
float
]
|
None
]],
)
->
FilterEnsemble
:
)
->
FilterEnsemble
:
"""
"""
Create a filtering pipeline.
Create a filtering pipeline.
...
@@ -21,3 +21,12 @@ def build_filter_ensemble(
...
@@ -21,3 +21,12 @@ def build_filter_ensemble(
partial
(
get_filter
(
func
),
**
(
kwargs
or
{}))
for
func
,
kwargs
in
components
partial
(
get_filter
(
func
),
**
(
kwargs
or
{}))
for
func
,
kwargs
in
components
],
],
)
)
__all__
=
[
"custom"
,
"extraction"
,
"selection"
,
"transformation"
,
"build_filter_ensemble"
,
]
lm_eval/models/__init__.py
View file @
70314843
from
.
import
(
# Models are now lazily loaded via the registry system
anthropic_llms
,
# No need to import them all at once
api_models
,
dummy
,
# Define model mappings for lazy registration
gguf
,
MODEL_MAPPING
=
{
hf_audiolm
,
"anthropic-completions"
:
"lm_eval.models.anthropic_llms:AnthropicLM"
,
hf_steered
,
"anthropic-chat"
:
"lm_eval.models.anthropic_llms:AnthropicChatLM"
,
hf_vlms
,
"anthropic-chat-completions"
:
"lm_eval.models.anthropic_llms:AnthropicCompletionsLM"
,
huggingface
,
"local-completions"
:
"lm_eval.models.openai_completions:LocalCompletionsAPI"
,
ibm_watsonx_ai
,
"local-chat-completions"
:
"lm_eval.models.openai_completions:LocalChatCompletion"
,
mamba_lm
,
"openai-completions"
:
"lm_eval.models.openai_completions:OpenAICompletionsAPI"
,
nemo_lm
,
"openai-chat-completions"
:
"lm_eval.models.openai_completions:OpenAIChatCompletion"
,
neuron_optimum
,
"dummy"
:
"lm_eval.models.dummy:DummyLM"
,
openai_completions
,
"gguf"
:
"lm_eval.models.gguf:GGUFLM"
,
optimum_ipex
,
"ggml"
:
"lm_eval.models.gguf:GGUFLM"
,
optimum_lm
,
"hf-audiolm-qwen"
:
"lm_eval.models.hf_audiolm:HFAudioLM"
,
sglang_causallms
,
"steered"
:
"lm_eval.models.hf_steered:SteeredHF"
,
sglang_generate_API
,
"hf-multimodal"
:
"lm_eval.models.hf_vlms:HFMultimodalLM"
,
textsynth
,
"hf-auto"
:
"lm_eval.models.huggingface:HFLM"
,
vllm_causallms
,
"hf"
:
"lm_eval.models.huggingface:HFLM"
,
vllm_vlms
,
"huggingface"
:
"lm_eval.models.huggingface:HFLM"
,
)
"watsonx_llm"
:
"lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI"
,
"mamba_ssm"
:
"lm_eval.models.mamba_lm:MambaLMWrapper"
,
"nemo_lm"
:
"lm_eval.models.nemo_lm:NeMoLM"
,
# TODO: implement __all__
"neuronx"
:
"lm_eval.models.neuron_optimum:NeuronModelForCausalLM"
,
"ipex"
:
"lm_eval.models.optimum_ipex:IPEXForCausalLM"
,
"openvino"
:
"lm_eval.models.optimum_lm:OptimumLM"
,
"sglang"
:
"lm_eval.models.sglang_causallms:SGLANG"
,
"sglang-generate"
:
"lm_eval.models.sglang_generate_API:SGAPI"
,
"textsynth"
:
"lm_eval.models.textsynth:TextSynthLM"
,
"vllm"
:
"lm_eval.models.vllm_causallms:VLLM"
,
"vllm-vlm"
:
"lm_eval.models.vllm_vlms:VLLM_VLM"
,
}
# Register all models lazily
def
_register_all_models
():
"""Register all known models lazily in the registry."""
from
lm_eval.api.registry
import
model_registry
for
name
,
path
in
MODEL_MAPPING
.
items
():
# Only register if not already present (avoids conflicts when modules are imported)
if
name
not
in
model_registry
:
# Register the lazy placeholder using lazy parameter
model_registry
.
register
(
name
,
lazy
=
path
)
# Call registration on module import
_register_all_models
()
__all__
=
[
"MODEL_MAPPING"
]
try
:
try
:
...
...
lm_eval/models/hf_steered.py
View file @
70314843
from
collections.abc
import
Generator
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
functools
import
partial
from
functools
import
partial
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Generator
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
from
peft.peft_model
import
PeftModel
from
peft.peft_model
import
PeftModel
...
...
lm_eval/models/ibm_watsonx_ai.py
View file @
70314843
...
@@ -3,7 +3,7 @@ import json
...
@@ -3,7 +3,7 @@ import json
import
logging
import
logging
import
os
import
os
import
warnings
import
warnings
from
functools
import
lru_
cache
from
functools
import
cache
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
cast
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
cast
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
...
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise
ValueError
(
error_msg
)
raise
ValueError
(
error_msg
)
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
get_watsonx_credentials
()
->
Dict
[
str
,
str
]:
def
get_watsonx_credentials
()
->
Dict
[
str
,
str
]:
"""
"""
Retrieves Watsonx API credentials from environmental variables.
Retrieves Watsonx API credentials from environmental variables.
...
...
lm_eval/models/vllm_causallms.py
View file @
70314843
...
@@ -42,7 +42,7 @@ try:
...
@@ -42,7 +42,7 @@ try:
if
parse_version
(
version
(
"vllm"
))
>=
parse_version
(
"0.8.3"
):
if
parse_version
(
version
(
"vllm"
))
>=
parse_version
(
"0.8.3"
):
from
vllm.entrypoints.chat_utils
import
resolve_hf_chat_template
from
vllm.entrypoints.chat_utils
import
resolve_hf_chat_template
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
p
ass
p
rint
(
"njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd"
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
pass
pass
...
...
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
View file @
70314843
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
with
open
(
GRAMMAR_FILE
)
as
f
:
...
@@ -556,8 +556,8 @@ class STRIPS:
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
View file @
70314843
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
with
open
(
GRAMMAR_FILE
)
as
f
:
...
@@ -556,8 +556,8 @@ class STRIPS:
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
pyproject.toml
View file @
70314843
...
@@ -121,7 +121,7 @@ lint.fixable = ["I001", "F401", "UP"]
...
@@ -121,7 +121,7 @@ lint.fixable = ["I001", "F401", "UP"]
lint.ignore
=
[
"E402"
,
"E731"
,
"E501"
,
"E111"
,
"E114"
,
"E117"
,
"E741"
]
lint.ignore
=
[
"E402"
,
"E731"
,
"E501"
,
"E111"
,
"E114"
,
"E117"
,
"E741"
]
[tool.ruff.lint.extend-per-file-ignores]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py"
=
[
"F401"
,
"F402"
,
"F403"
]
"__init__.py"
=
[
"F401"
,
"F402"
,
"F403"
,
"F405"
]
[tool.ruff.lint.isort]
[tool.ruff.lint.isort]
combine-as-imports
=
true
combine-as-imports
=
true
...
...
scripts/build_benchmark.py
View file @
70314843
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from
tqdm
import
tqdm
from
tqdm
import
tqdm
# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registry
v2
import ALL_TASKS
eval_logger
=
logging
.
getLogger
(
__name__
)
eval_logger
=
logging
.
getLogger
(
__name__
)
...
...
test_registry.py
0 → 100644
View file @
70314843
#!/usr/bin/env python3
"""Comprehensive tests for the registry system."""
import
threading
import
pytest
from
lm_eval.api.model
import
LM
from
lm_eval.api.registry
import
(
MetricSpec
,
Registry
,
get_metric
,
metric_agg_registry
,
metric_registry
,
model_registry
,
register_metric
,
)
# Import metrics module to ensure decorators are executed
# import lm_eval.api.metrics
class
TestBasicRegistry
:
"""Test basic registry functionality."""
def
test_create_registry
(
self
):
"""Test creating a basic registry."""
reg
=
Registry
(
"test"
)
assert
len
(
reg
)
==
0
assert
list
(
reg
)
==
[]
def
test_decorator_registration
(
self
):
"""Test decorator-based registration."""
reg
=
Registry
(
"test"
)
@
reg
.
register
(
"my_class"
)
class
MyClass
:
pass
assert
"my_class"
in
reg
assert
reg
.
get
(
"my_class"
)
==
MyClass
assert
reg
[
"my_class"
]
==
MyClass
def
test_decorator_multiple_aliases
(
self
):
"""Test decorator with multiple aliases."""
reg
=
Registry
(
"test"
)
@
reg
.
register
(
"alias1"
,
"alias2"
,
"alias3"
)
class
MyClass
:
pass
assert
reg
.
get
(
"alias1"
)
==
MyClass
assert
reg
.
get
(
"alias2"
)
==
MyClass
assert
reg
.
get
(
"alias3"
)
==
MyClass
def
test_decorator_auto_name
(
self
):
"""Test decorator using class name when no alias provided."""
reg
=
Registry
(
"test"
)
@
reg
.
register
()
class
AutoNamedClass
:
pass
assert
reg
.
get
(
"AutoNamedClass"
)
==
AutoNamedClass
def
test_lazy_registration
(
self
):
"""Test lazy loading with module paths."""
reg
=
Registry
(
"test"
)
# Register with lazy loading
reg
.
register
(
"join"
,
lazy
=
"os.path:join"
)
# Check it's stored as a string
assert
isinstance
(
reg
.
_objs
[
"join"
],
str
)
# Access triggers materialization
result
=
reg
.
get
(
"join"
)
import
os
assert
result
==
os
.
path
.
join
assert
callable
(
result
)
def
test_direct_registration
(
self
):
"""Test direct object registration."""
reg
=
Registry
(
"test"
)
class
DirectClass
:
pass
obj
=
DirectClass
()
reg
.
register
(
"direct"
,
lazy
=
obj
)
assert
reg
.
get
(
"direct"
)
==
obj
def
test_metadata_removed
(
self
):
"""Test that metadata parameter is removed from generic registry."""
reg
=
Registry
(
"test"
)
# Should work without metadata parameter
@
reg
.
register
(
"test_class"
)
class
TestClass
:
pass
assert
"test_class"
in
reg
assert
reg
.
get
(
"test_class"
)
==
TestClass
def
test_unknown_key_error
(
self
):
"""Test error when accessing unknown key."""
reg
=
Registry
(
"test"
)
with
pytest
.
raises
(
KeyError
)
as
exc_info
:
reg
.
get
(
"unknown"
)
assert
"Unknown test 'unknown'"
in
str
(
exc_info
.
value
)
assert
"Available:"
in
str
(
exc_info
.
value
)
def
test_iteration
(
self
):
"""Test registry iteration."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"a"
,
lazy
=
"os:getcwd"
)
reg
.
register
(
"b"
,
lazy
=
"os:getenv"
)
reg
.
register
(
"c"
,
lazy
=
"os:getpid"
)
assert
list
(
reg
)
==
[
"a"
,
"b"
,
"c"
]
assert
len
(
reg
)
==
3
# Test items()
items
=
list
(
reg
.
items
())
assert
len
(
items
)
==
3
assert
items
[
0
][
0
]
==
"a"
assert
isinstance
(
items
[
0
][
1
],
str
)
# Still lazy
def
test_mapping_protocol
(
self
):
"""Test that registry implements mapping protocol."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"test"
,
lazy
=
"os:getcwd"
)
# __getitem__
assert
reg
[
"test"
]
==
reg
.
get
(
"test"
)
# __contains__
assert
"test"
in
reg
assert
"missing"
not
in
reg
# __iter__ and __len__ tested above
class
TestTypeConstraints
:
"""Test type checking and base class constraints."""
def
test_base_class_constraint
(
self
):
"""Test base class validation."""
# Define a base class
class
BaseClass
:
pass
class
GoodSubclass
(
BaseClass
):
pass
class
BadClass
:
pass
reg
=
Registry
(
"typed"
,
base_cls
=
BaseClass
)
# Should work - correct subclass
@
reg
.
register
(
"good"
)
class
GoodInline
(
BaseClass
):
pass
# Should fail - wrong type
with
pytest
.
raises
(
TypeError
)
as
exc_info
:
@
reg
.
register
(
"bad"
)
class
BadInline
:
pass
assert
"must inherit from"
in
str
(
exc_info
.
value
)
def
test_lazy_type_check
(
self
):
"""Test that type checking happens on materialization for lazy entries."""
class
BaseClass
:
pass
reg
=
Registry
(
"typed"
,
base_cls
=
BaseClass
)
# Register a lazy entry that will fail type check
reg
.
register
(
"bad_lazy"
,
lazy
=
"os.path:join"
)
# Should fail when accessed - the error message varies
with
pytest
.
raises
(
TypeError
):
reg
.
get
(
"bad_lazy"
)
class
TestCollisionHandling
:
"""Test registration collision scenarios."""
def
test_identical_registration
(
self
):
"""Test that identical re-registration is allowed."""
reg
=
Registry
(
"test"
)
class
MyClass
:
pass
# First registration
reg
.
register
(
"test"
,
lazy
=
MyClass
)
# Identical re-registration should work
reg
.
register
(
"test"
,
lazy
=
MyClass
)
assert
reg
.
get
(
"test"
)
==
MyClass
def
test_different_registration_fails
(
self
):
"""Test that different re-registration fails."""
reg
=
Registry
(
"test"
)
class
Class1
:
pass
class
Class2
:
pass
reg
.
register
(
"test"
,
lazy
=
Class1
)
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
reg
.
register
(
"test"
,
lazy
=
Class2
)
assert
"already registered"
in
str
(
exc_info
.
value
)
def
test_lazy_to_concrete_upgrade
(
self
):
"""Test that lazy placeholder can be upgraded to concrete class."""
reg
=
Registry
(
"test"
)
# Register lazy
reg
.
register
(
"myclass"
,
lazy
=
"test_registry:MyUpgradeClass"
)
# Define and register concrete - should work
@
reg
.
register
(
"myclass"
)
class
MyUpgradeClass
:
pass
assert
reg
.
get
(
"myclass"
)
==
MyUpgradeClass
class
TestThreadSafety
:
"""Test thread safety of registry operations."""
def
test_concurrent_access
(
self
):
"""Test concurrent access to lazy entries."""
reg
=
Registry
(
"test"
)
# Register lazy entry
reg
.
register
(
"concurrent"
,
lazy
=
"os.path:join"
)
results
=
[]
errors
=
[]
def
access_item
():
try
:
result
=
reg
.
get
(
"concurrent"
)
results
.
append
(
result
)
except
Exception
as
e
:
errors
.
append
(
str
(
e
))
# Launch threads
threads
=
[]
for
_
in
range
(
10
):
t
=
threading
.
Thread
(
target
=
access_item
)
threads
.
append
(
t
)
t
.
start
()
# Wait for completion
for
t
in
threads
:
t
.
join
()
# Check results
assert
len
(
errors
)
==
0
assert
len
(
results
)
==
10
# All should get the same object
assert
all
(
r
==
results
[
0
]
for
r
in
results
)
def
test_concurrent_registration
(
self
):
"""Test concurrent registration doesn't cause issues."""
reg
=
Registry
(
"test"
)
errors
=
[]
def
register_item
(
name
,
value
):
try
:
reg
.
register
(
name
,
lazy
=
value
)
except
Exception
as
e
:
errors
.
append
(
str
(
e
))
# Launch threads with different registrations
threads
=
[]
for
i
in
range
(
10
):
t
=
threading
.
Thread
(
target
=
register_item
,
args
=
(
f
"item_
{
i
}
"
,
f
"module
{
i
}
:Class
{
i
}
"
)
)
threads
.
append
(
t
)
t
.
start
()
# Wait for completion
for
t
in
threads
:
t
.
join
()
# Check results
assert
len
(
errors
)
==
0
assert
len
(
reg
)
==
10
class
TestMetricRegistry
:
"""Test metric-specific registry functionality."""
def
test_metric_spec
(
self
):
"""Test MetricSpec dataclass."""
def
compute_fn
(
items
):
return
[
1
for
_
in
items
]
def
agg_fn
(
values
):
return
sum
(
values
)
/
len
(
values
)
spec
=
MetricSpec
(
compute
=
compute_fn
,
aggregate
=
agg_fn
,
higher_is_better
=
True
,
output_type
=
"probability"
,
)
assert
spec
.
compute
==
compute_fn
assert
spec
.
aggregate
==
agg_fn
assert
spec
.
higher_is_better
assert
spec
.
output_type
==
"probability"
def
test_register_metric_decorator
(
self
):
"""Test @register_metric decorator."""
# Register aggregation function first
@
metric_agg_registry
.
register
(
"test_mean"
)
def
test_mean
(
values
):
return
sum
(
values
)
/
len
(
values
)
if
values
else
0.0
# Register metric
@
register_metric
(
metric
=
"test_accuracy"
,
aggregation
=
"test_mean"
,
higher_is_better
=
True
,
output_type
=
"accuracy"
,
)
def
compute_accuracy
(
items
):
return
[
1
if
item
[
"pred"
]
==
item
[
"gold"
]
else
0
for
item
in
items
]
# Check registration
assert
"test_accuracy"
in
metric_registry
spec
=
metric_registry
.
get
(
"test_accuracy"
)
assert
isinstance
(
spec
,
MetricSpec
)
assert
spec
.
higher_is_better
assert
spec
.
output_type
==
"accuracy"
# Test compute function
items
=
[
{
"pred"
:
"a"
,
"gold"
:
"a"
},
{
"pred"
:
"b"
,
"gold"
:
"b"
},
{
"pred"
:
"c"
,
"gold"
:
"d"
},
]
result
=
spec
.
compute
(
items
)
assert
result
==
[
1
,
1
,
0
]
# Test aggregation
agg_result
=
spec
.
aggregate
(
result
)
assert
agg_result
==
2
/
3
def
test_metric_without_aggregation
(
self
):
"""Test metric registration without aggregation."""
@
register_metric
(
metric
=
"no_agg"
,
higher_is_better
=
False
)
def
compute_something
(
items
):
return
[
len
(
item
)
for
item
in
items
]
spec
=
metric_registry
.
get
(
"no_agg"
)
# Should raise NotImplementedError when aggregate is called
with
pytest
.
raises
(
NotImplementedError
)
as
exc_info
:
spec
.
aggregate
([
1
,
2
,
3
])
assert
"No aggregation function specified"
in
str
(
exc_info
.
value
)
def
test_get_metric_helper
(
self
):
"""Test get_metric helper function."""
@
register_metric
(
metric
=
"helper_test"
,
aggregation
=
"mean"
,
# Assuming 'mean' exists in metric_agg_registry
)
def
compute_helper
(
items
):
return
items
# get_metric returns just the compute function
compute_fn
=
get_metric
(
"helper_test"
)
assert
callable
(
compute_fn
)
assert
compute_fn
([
1
,
2
,
3
])
==
[
1
,
2
,
3
]
class
TestRegistryUtilities
:
"""Test utility methods."""
def
test_freeze
(
self
):
"""Test freezing a registry."""
reg
=
Registry
(
"test"
)
# Add some items
reg
.
register
(
"item1"
,
lazy
=
"os:getcwd"
)
reg
.
register
(
"item2"
,
lazy
=
"os:getenv"
)
# Freeze the registry
reg
.
freeze
()
# Should not be able to register new items
with
pytest
.
raises
(
TypeError
):
reg
.
_objs
[
"new"
]
=
"value"
# Should still be able to access items
assert
"item1"
in
reg
assert
callable
(
reg
.
get
(
"item1"
))
def
test_clear
(
self
):
"""Test clearing a registry."""
reg
=
Registry
(
"test"
)
# Add items
reg
.
register
(
"item1"
,
lazy
=
"os:getcwd"
)
reg
.
register
(
"item2"
,
lazy
=
"os:getenv"
)
assert
len
(
reg
)
==
2
# Clear
reg
.
_clear
()
assert
len
(
reg
)
==
0
assert
list
(
reg
)
==
[]
def
test_origin
(
self
):
"""Test origin tracking."""
reg
=
Registry
(
"test"
)
# Lazy entry - no origin
reg
.
register
(
"lazy"
,
lazy
=
"os:getcwd"
)
assert
reg
.
origin
(
"lazy"
)
is
None
# Concrete class - should have origin
@
reg
.
register
(
"concrete"
)
class
ConcreteClass
:
pass
origin
=
reg
.
origin
(
"concrete"
)
assert
origin
is
not
None
assert
"test_registry.py"
in
origin
assert
":"
in
origin
# Has line number
class
TestBackwardCompatibility
:
"""Test backward compatibility features."""
def
test_model_registry_alias
(
self
):
"""Test MODEL_REGISTRY backward compatibility."""
from
lm_eval.api.registry
import
MODEL_REGISTRY
# Should be same object as model_registry
assert
MODEL_REGISTRY
is
model_registry
# Should reflect current state
before_count
=
len
(
MODEL_REGISTRY
)
# Add new model
@
model_registry
.
register
(
"test_model_compat"
)
class
TestModelCompat
(
LM
):
pass
# MODEL_REGISTRY should immediately reflect the change
assert
len
(
MODEL_REGISTRY
)
==
before_count
+
1
assert
"test_model_compat"
in
MODEL_REGISTRY
def
test_legacy_functions
(
self
):
"""Test legacy helper functions."""
from
lm_eval.api.registry
import
(
AGGREGATION_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
get_model
,
register_model
,
)
# register_model should work
@
register_model
(
"legacy_model"
)
class
LegacyModel
(
LM
):
pass
# get_model should work
assert
get_model
(
"legacy_model"
)
==
LegacyModel
# Check other aliases
assert
DEFAULT_METRIC_REGISTRY
is
metric_registry
assert
AGGREGATION_REGISTRY
is
metric_agg_registry
class
TestEdgeCases
:
"""Test edge cases and error conditions."""
def
test_invalid_lazy_format
(
self
):
"""Test error on invalid lazy format."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"bad"
,
lazy
=
"no_colon_here"
)
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
reg
.
get
(
"bad"
)
assert
"expected 'module:object'"
in
str
(
exc_info
.
value
)
def
test_lazy_module_not_found
(
self
):
"""Test error when lazy module doesn't exist."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"missing"
,
lazy
=
"nonexistent_module:Class"
)
with
pytest
.
raises
(
ModuleNotFoundError
):
reg
.
get
(
"missing"
)
def
test_lazy_attribute_not_found
(
self
):
"""Test error when lazy attribute doesn't exist."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"missing_attr"
,
lazy
=
"os:nonexistent_function"
)
with
pytest
.
raises
(
AttributeError
):
reg
.
get
(
"missing_attr"
)
def
test_multiple_aliases_with_lazy
(
self
):
"""Test that multiple aliases with lazy fails."""
reg
=
Registry
(
"test"
)
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
reg
.
register
(
"alias1"
,
"alias2"
,
lazy
=
"os:getcwd"
)
assert
"Exactly one alias required"
in
str
(
exc_info
.
value
)
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
,
"-v"
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment