Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
70314843
Unverified
Commit
70314843
authored
Sep 26, 2025
by
Baber Abbasi
Committed by
GitHub
Sep 26, 2025
Browse files
Merge pull request #3189 from EleutherAI/lazy_reg
refactor registry
parents
73202a2e
930b4253
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1165 additions
and
210 deletions
+1165
-210
lm_eval/__init__.py
lm_eval/__init__.py
+4
-0
lm_eval/api/registry.py
lm_eval/api/registry.py
+532
-170
lm_eval/filters/__init__.py
lm_eval/filters/__init__.py
+12
-3
lm_eval/models/__init__.py
lm_eval/models/__init__.py
+51
-25
lm_eval/models/hf_steered.py
lm_eval/models/hf_steered.py
+2
-1
lm_eval/models/ibm_watsonx_ai.py
lm_eval/models/ibm_watsonx_ai.py
+2
-2
lm_eval/models/vllm_causallms.py
lm_eval/models/vllm_causallms.py
+1
-1
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
+3
-3
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
+3
-3
pyproject.toml
pyproject.toml
+1
-1
scripts/build_benchmark.py
scripts/build_benchmark.py
+1
-1
test_registry.py
test_registry.py
+553
-0
No files found.
lm_eval/__init__.py
View file @
70314843
from
.api
import
metrics
,
model
,
registry
# initializes the registries
from
.filters
import
*
__version__
=
"0.4.9.1"
...
...
lm_eval/api/registry.py
View file @
70314843
This diff is collapsed.
Click to expand it.
lm_eval/filters/__init__.py
View file @
70314843
from
__future__
import
annotations
from
functools
import
partial
from
typing
import
Optional
,
Union
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.registry
import
get_filter
from
lm_eval.api.registry
import
filter_registry
,
get_filter
from
.
import
custom
,
extraction
,
selection
,
transformation
def
build_filter_ensemble
(
filter_name
:
str
,
components
:
list
[
tuple
[
str
,
Optional
[
dict
[
str
,
Union
[
str
,
int
,
float
]
]]
]],
components
:
list
[
tuple
[
str
,
dict
[
str
,
str
|
int
|
float
]
|
None
]],
)
->
FilterEnsemble
:
"""
Create a filtering pipeline.
...
...
@@ -21,3 +21,12 @@ def build_filter_ensemble(
partial
(
get_filter
(
func
),
**
(
kwargs
or
{}))
for
func
,
kwargs
in
components
],
)
__all__
=
[
"custom"
,
"extraction"
,
"selection"
,
"transformation"
,
"build_filter_ensemble"
,
]
lm_eval/models/__init__.py
View file @
70314843
from
.
import
(
anthropic_llms
,
api_models
,
dummy
,
gguf
,
hf_audiolm
,
hf_steered
,
hf_vlms
,
huggingface
,
ibm_watsonx_ai
,
mamba_lm
,
nemo_lm
,
neuron_optimum
,
openai_completions
,
optimum_ipex
,
optimum_lm
,
sglang_causallms
,
sglang_generate_API
,
textsynth
,
vllm_causallms
,
vllm_vlms
,
)
# TODO: implement __all__
# Models are now lazily loaded via the registry system
# No need to import them all at once
# Define model mappings for lazy registration
MODEL_MAPPING
=
{
"anthropic-completions"
:
"lm_eval.models.anthropic_llms:AnthropicLM"
,
"anthropic-chat"
:
"lm_eval.models.anthropic_llms:AnthropicChatLM"
,
"anthropic-chat-completions"
:
"lm_eval.models.anthropic_llms:AnthropicCompletionsLM"
,
"local-completions"
:
"lm_eval.models.openai_completions:LocalCompletionsAPI"
,
"local-chat-completions"
:
"lm_eval.models.openai_completions:LocalChatCompletion"
,
"openai-completions"
:
"lm_eval.models.openai_completions:OpenAICompletionsAPI"
,
"openai-chat-completions"
:
"lm_eval.models.openai_completions:OpenAIChatCompletion"
,
"dummy"
:
"lm_eval.models.dummy:DummyLM"
,
"gguf"
:
"lm_eval.models.gguf:GGUFLM"
,
"ggml"
:
"lm_eval.models.gguf:GGUFLM"
,
"hf-audiolm-qwen"
:
"lm_eval.models.hf_audiolm:HFAudioLM"
,
"steered"
:
"lm_eval.models.hf_steered:SteeredHF"
,
"hf-multimodal"
:
"lm_eval.models.hf_vlms:HFMultimodalLM"
,
"hf-auto"
:
"lm_eval.models.huggingface:HFLM"
,
"hf"
:
"lm_eval.models.huggingface:HFLM"
,
"huggingface"
:
"lm_eval.models.huggingface:HFLM"
,
"watsonx_llm"
:
"lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI"
,
"mamba_ssm"
:
"lm_eval.models.mamba_lm:MambaLMWrapper"
,
"nemo_lm"
:
"lm_eval.models.nemo_lm:NeMoLM"
,
"neuronx"
:
"lm_eval.models.neuron_optimum:NeuronModelForCausalLM"
,
"ipex"
:
"lm_eval.models.optimum_ipex:IPEXForCausalLM"
,
"openvino"
:
"lm_eval.models.optimum_lm:OptimumLM"
,
"sglang"
:
"lm_eval.models.sglang_causallms:SGLANG"
,
"sglang-generate"
:
"lm_eval.models.sglang_generate_API:SGAPI"
,
"textsynth"
:
"lm_eval.models.textsynth:TextSynthLM"
,
"vllm"
:
"lm_eval.models.vllm_causallms:VLLM"
,
"vllm-vlm"
:
"lm_eval.models.vllm_vlms:VLLM_VLM"
,
}
# Register all models lazily
def
_register_all_models
():
"""Register all known models lazily in the registry."""
from
lm_eval.api.registry
import
model_registry
for
name
,
path
in
MODEL_MAPPING
.
items
():
# Only register if not already present (avoids conflicts when modules are imported)
if
name
not
in
model_registry
:
# Register the lazy placeholder using lazy parameter
model_registry
.
register
(
name
,
lazy
=
path
)
# Call registration on module import
_register_all_models
()
__all__
=
[
"MODEL_MAPPING"
]
try
:
...
...
lm_eval/models/hf_steered.py
View file @
70314843
from
collections.abc
import
Generator
from
contextlib
import
contextmanager
from
functools
import
partial
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Generator
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
from
peft.peft_model
import
PeftModel
...
...
lm_eval/models/ibm_watsonx_ai.py
View file @
70314843
...
...
@@ -3,7 +3,7 @@ import json
import
logging
import
os
import
warnings
from
functools
import
lru_
cache
from
functools
import
cache
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
cast
from
tqdm
import
tqdm
...
...
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise
ValueError
(
error_msg
)
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
get_watsonx_credentials
()
->
Dict
[
str
,
str
]:
"""
Retrieves Watsonx API credentials from environmental variables.
...
...
lm_eval/models/vllm_causallms.py
View file @
70314843
...
...
@@ -42,7 +42,7 @@ try:
if
parse_version
(
version
(
"vllm"
))
>=
parse_version
(
"0.8.3"
):
from
vllm.entrypoints.chat_utils
import
resolve_hf_chat_template
except
ModuleNotFoundError
:
p
ass
p
rint
(
"njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd"
)
if
TYPE_CHECKING
:
pass
...
...
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
View file @
70314843
...
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
...
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
View file @
70314843
...
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
...
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
pyproject.toml
View file @
70314843
...
...
@@ -121,7 +121,7 @@ lint.fixable = ["I001", "F401", "UP"]
lint.ignore
=
[
"E402"
,
"E731"
,
"E501"
,
"E111"
,
"E114"
,
"E117"
,
"E741"
]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py"
=
[
"F401"
,
"F402"
,
"F403"
]
"__init__.py"
=
[
"F401"
,
"F402"
,
"F403"
,
"F405"
]
[tool.ruff.lint.isort]
combine-as-imports
=
true
...
...
scripts/build_benchmark.py
View file @
70314843
...
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from
tqdm
import
tqdm
# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registry
v2
import ALL_TASKS
eval_logger
=
logging
.
getLogger
(
__name__
)
...
...
test_registry.py
0 → 100644
View file @
70314843
#!/usr/bin/env python3
"""Comprehensive tests for the registry system."""
import
threading
import
pytest
from
lm_eval.api.model
import
LM
from
lm_eval.api.registry
import
(
MetricSpec
,
Registry
,
get_metric
,
metric_agg_registry
,
metric_registry
,
model_registry
,
register_metric
,
)
# Import metrics module to ensure decorators are executed
# import lm_eval.api.metrics
class
TestBasicRegistry
:
"""Test basic registry functionality."""
def
test_create_registry
(
self
):
"""Test creating a basic registry."""
reg
=
Registry
(
"test"
)
assert
len
(
reg
)
==
0
assert
list
(
reg
)
==
[]
def
test_decorator_registration
(
self
):
"""Test decorator-based registration."""
reg
=
Registry
(
"test"
)
@
reg
.
register
(
"my_class"
)
class
MyClass
:
pass
assert
"my_class"
in
reg
assert
reg
.
get
(
"my_class"
)
==
MyClass
assert
reg
[
"my_class"
]
==
MyClass
def
test_decorator_multiple_aliases
(
self
):
"""Test decorator with multiple aliases."""
reg
=
Registry
(
"test"
)
@
reg
.
register
(
"alias1"
,
"alias2"
,
"alias3"
)
class
MyClass
:
pass
assert
reg
.
get
(
"alias1"
)
==
MyClass
assert
reg
.
get
(
"alias2"
)
==
MyClass
assert
reg
.
get
(
"alias3"
)
==
MyClass
def
test_decorator_auto_name
(
self
):
"""Test decorator using class name when no alias provided."""
reg
=
Registry
(
"test"
)
@
reg
.
register
()
class
AutoNamedClass
:
pass
assert
reg
.
get
(
"AutoNamedClass"
)
==
AutoNamedClass
def
test_lazy_registration
(
self
):
"""Test lazy loading with module paths."""
reg
=
Registry
(
"test"
)
# Register with lazy loading
reg
.
register
(
"join"
,
lazy
=
"os.path:join"
)
# Check it's stored as a string
assert
isinstance
(
reg
.
_objs
[
"join"
],
str
)
# Access triggers materialization
result
=
reg
.
get
(
"join"
)
import
os
assert
result
==
os
.
path
.
join
assert
callable
(
result
)
def
test_direct_registration
(
self
):
"""Test direct object registration."""
reg
=
Registry
(
"test"
)
class
DirectClass
:
pass
obj
=
DirectClass
()
reg
.
register
(
"direct"
,
lazy
=
obj
)
assert
reg
.
get
(
"direct"
)
==
obj
def
test_metadata_removed
(
self
):
"""Test that metadata parameter is removed from generic registry."""
reg
=
Registry
(
"test"
)
# Should work without metadata parameter
@
reg
.
register
(
"test_class"
)
class
TestClass
:
pass
assert
"test_class"
in
reg
assert
reg
.
get
(
"test_class"
)
==
TestClass
def
test_unknown_key_error
(
self
):
"""Test error when accessing unknown key."""
reg
=
Registry
(
"test"
)
with
pytest
.
raises
(
KeyError
)
as
exc_info
:
reg
.
get
(
"unknown"
)
assert
"Unknown test 'unknown'"
in
str
(
exc_info
.
value
)
assert
"Available:"
in
str
(
exc_info
.
value
)
def
test_iteration
(
self
):
"""Test registry iteration."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"a"
,
lazy
=
"os:getcwd"
)
reg
.
register
(
"b"
,
lazy
=
"os:getenv"
)
reg
.
register
(
"c"
,
lazy
=
"os:getpid"
)
assert
list
(
reg
)
==
[
"a"
,
"b"
,
"c"
]
assert
len
(
reg
)
==
3
# Test items()
items
=
list
(
reg
.
items
())
assert
len
(
items
)
==
3
assert
items
[
0
][
0
]
==
"a"
assert
isinstance
(
items
[
0
][
1
],
str
)
# Still lazy
def
test_mapping_protocol
(
self
):
"""Test that registry implements mapping protocol."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"test"
,
lazy
=
"os:getcwd"
)
# __getitem__
assert
reg
[
"test"
]
==
reg
.
get
(
"test"
)
# __contains__
assert
"test"
in
reg
assert
"missing"
not
in
reg
# __iter__ and __len__ tested above
class
TestTypeConstraints
:
"""Test type checking and base class constraints."""
def
test_base_class_constraint
(
self
):
"""Test base class validation."""
# Define a base class
class
BaseClass
:
pass
class
GoodSubclass
(
BaseClass
):
pass
class
BadClass
:
pass
reg
=
Registry
(
"typed"
,
base_cls
=
BaseClass
)
# Should work - correct subclass
@
reg
.
register
(
"good"
)
class
GoodInline
(
BaseClass
):
pass
# Should fail - wrong type
with
pytest
.
raises
(
TypeError
)
as
exc_info
:
@
reg
.
register
(
"bad"
)
class
BadInline
:
pass
assert
"must inherit from"
in
str
(
exc_info
.
value
)
def
test_lazy_type_check
(
self
):
"""Test that type checking happens on materialization for lazy entries."""
class
BaseClass
:
pass
reg
=
Registry
(
"typed"
,
base_cls
=
BaseClass
)
# Register a lazy entry that will fail type check
reg
.
register
(
"bad_lazy"
,
lazy
=
"os.path:join"
)
# Should fail when accessed - the error message varies
with
pytest
.
raises
(
TypeError
):
reg
.
get
(
"bad_lazy"
)
class
TestCollisionHandling
:
"""Test registration collision scenarios."""
def
test_identical_registration
(
self
):
"""Test that identical re-registration is allowed."""
reg
=
Registry
(
"test"
)
class
MyClass
:
pass
# First registration
reg
.
register
(
"test"
,
lazy
=
MyClass
)
# Identical re-registration should work
reg
.
register
(
"test"
,
lazy
=
MyClass
)
assert
reg
.
get
(
"test"
)
==
MyClass
def
test_different_registration_fails
(
self
):
"""Test that different re-registration fails."""
reg
=
Registry
(
"test"
)
class
Class1
:
pass
class
Class2
:
pass
reg
.
register
(
"test"
,
lazy
=
Class1
)
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
reg
.
register
(
"test"
,
lazy
=
Class2
)
assert
"already registered"
in
str
(
exc_info
.
value
)
def
test_lazy_to_concrete_upgrade
(
self
):
"""Test that lazy placeholder can be upgraded to concrete class."""
reg
=
Registry
(
"test"
)
# Register lazy
reg
.
register
(
"myclass"
,
lazy
=
"test_registry:MyUpgradeClass"
)
# Define and register concrete - should work
@
reg
.
register
(
"myclass"
)
class
MyUpgradeClass
:
pass
assert
reg
.
get
(
"myclass"
)
==
MyUpgradeClass
class
TestThreadSafety
:
"""Test thread safety of registry operations."""
def
test_concurrent_access
(
self
):
"""Test concurrent access to lazy entries."""
reg
=
Registry
(
"test"
)
# Register lazy entry
reg
.
register
(
"concurrent"
,
lazy
=
"os.path:join"
)
results
=
[]
errors
=
[]
def
access_item
():
try
:
result
=
reg
.
get
(
"concurrent"
)
results
.
append
(
result
)
except
Exception
as
e
:
errors
.
append
(
str
(
e
))
# Launch threads
threads
=
[]
for
_
in
range
(
10
):
t
=
threading
.
Thread
(
target
=
access_item
)
threads
.
append
(
t
)
t
.
start
()
# Wait for completion
for
t
in
threads
:
t
.
join
()
# Check results
assert
len
(
errors
)
==
0
assert
len
(
results
)
==
10
# All should get the same object
assert
all
(
r
==
results
[
0
]
for
r
in
results
)
def
test_concurrent_registration
(
self
):
"""Test concurrent registration doesn't cause issues."""
reg
=
Registry
(
"test"
)
errors
=
[]
def
register_item
(
name
,
value
):
try
:
reg
.
register
(
name
,
lazy
=
value
)
except
Exception
as
e
:
errors
.
append
(
str
(
e
))
# Launch threads with different registrations
threads
=
[]
for
i
in
range
(
10
):
t
=
threading
.
Thread
(
target
=
register_item
,
args
=
(
f
"item_
{
i
}
"
,
f
"module
{
i
}
:Class
{
i
}
"
)
)
threads
.
append
(
t
)
t
.
start
()
# Wait for completion
for
t
in
threads
:
t
.
join
()
# Check results
assert
len
(
errors
)
==
0
assert
len
(
reg
)
==
10
class
TestMetricRegistry
:
"""Test metric-specific registry functionality."""
def
test_metric_spec
(
self
):
"""Test MetricSpec dataclass."""
def
compute_fn
(
items
):
return
[
1
for
_
in
items
]
def
agg_fn
(
values
):
return
sum
(
values
)
/
len
(
values
)
spec
=
MetricSpec
(
compute
=
compute_fn
,
aggregate
=
agg_fn
,
higher_is_better
=
True
,
output_type
=
"probability"
,
)
assert
spec
.
compute
==
compute_fn
assert
spec
.
aggregate
==
agg_fn
assert
spec
.
higher_is_better
assert
spec
.
output_type
==
"probability"
def
test_register_metric_decorator
(
self
):
"""Test @register_metric decorator."""
# Register aggregation function first
@
metric_agg_registry
.
register
(
"test_mean"
)
def
test_mean
(
values
):
return
sum
(
values
)
/
len
(
values
)
if
values
else
0.0
# Register metric
@
register_metric
(
metric
=
"test_accuracy"
,
aggregation
=
"test_mean"
,
higher_is_better
=
True
,
output_type
=
"accuracy"
,
)
def
compute_accuracy
(
items
):
return
[
1
if
item
[
"pred"
]
==
item
[
"gold"
]
else
0
for
item
in
items
]
# Check registration
assert
"test_accuracy"
in
metric_registry
spec
=
metric_registry
.
get
(
"test_accuracy"
)
assert
isinstance
(
spec
,
MetricSpec
)
assert
spec
.
higher_is_better
assert
spec
.
output_type
==
"accuracy"
# Test compute function
items
=
[
{
"pred"
:
"a"
,
"gold"
:
"a"
},
{
"pred"
:
"b"
,
"gold"
:
"b"
},
{
"pred"
:
"c"
,
"gold"
:
"d"
},
]
result
=
spec
.
compute
(
items
)
assert
result
==
[
1
,
1
,
0
]
# Test aggregation
agg_result
=
spec
.
aggregate
(
result
)
assert
agg_result
==
2
/
3
def
test_metric_without_aggregation
(
self
):
"""Test metric registration without aggregation."""
@
register_metric
(
metric
=
"no_agg"
,
higher_is_better
=
False
)
def
compute_something
(
items
):
return
[
len
(
item
)
for
item
in
items
]
spec
=
metric_registry
.
get
(
"no_agg"
)
# Should raise NotImplementedError when aggregate is called
with
pytest
.
raises
(
NotImplementedError
)
as
exc_info
:
spec
.
aggregate
([
1
,
2
,
3
])
assert
"No aggregation function specified"
in
str
(
exc_info
.
value
)
def
test_get_metric_helper
(
self
):
"""Test get_metric helper function."""
@
register_metric
(
metric
=
"helper_test"
,
aggregation
=
"mean"
,
# Assuming 'mean' exists in metric_agg_registry
)
def
compute_helper
(
items
):
return
items
# get_metric returns just the compute function
compute_fn
=
get_metric
(
"helper_test"
)
assert
callable
(
compute_fn
)
assert
compute_fn
([
1
,
2
,
3
])
==
[
1
,
2
,
3
]
class
TestRegistryUtilities
:
"""Test utility methods."""
def
test_freeze
(
self
):
"""Test freezing a registry."""
reg
=
Registry
(
"test"
)
# Add some items
reg
.
register
(
"item1"
,
lazy
=
"os:getcwd"
)
reg
.
register
(
"item2"
,
lazy
=
"os:getenv"
)
# Freeze the registry
reg
.
freeze
()
# Should not be able to register new items
with
pytest
.
raises
(
TypeError
):
reg
.
_objs
[
"new"
]
=
"value"
# Should still be able to access items
assert
"item1"
in
reg
assert
callable
(
reg
.
get
(
"item1"
))
def
test_clear
(
self
):
"""Test clearing a registry."""
reg
=
Registry
(
"test"
)
# Add items
reg
.
register
(
"item1"
,
lazy
=
"os:getcwd"
)
reg
.
register
(
"item2"
,
lazy
=
"os:getenv"
)
assert
len
(
reg
)
==
2
# Clear
reg
.
_clear
()
assert
len
(
reg
)
==
0
assert
list
(
reg
)
==
[]
def
test_origin
(
self
):
"""Test origin tracking."""
reg
=
Registry
(
"test"
)
# Lazy entry - no origin
reg
.
register
(
"lazy"
,
lazy
=
"os:getcwd"
)
assert
reg
.
origin
(
"lazy"
)
is
None
# Concrete class - should have origin
@
reg
.
register
(
"concrete"
)
class
ConcreteClass
:
pass
origin
=
reg
.
origin
(
"concrete"
)
assert
origin
is
not
None
assert
"test_registry.py"
in
origin
assert
":"
in
origin
# Has line number
class
TestBackwardCompatibility
:
"""Test backward compatibility features."""
def
test_model_registry_alias
(
self
):
"""Test MODEL_REGISTRY backward compatibility."""
from
lm_eval.api.registry
import
MODEL_REGISTRY
# Should be same object as model_registry
assert
MODEL_REGISTRY
is
model_registry
# Should reflect current state
before_count
=
len
(
MODEL_REGISTRY
)
# Add new model
@
model_registry
.
register
(
"test_model_compat"
)
class
TestModelCompat
(
LM
):
pass
# MODEL_REGISTRY should immediately reflect the change
assert
len
(
MODEL_REGISTRY
)
==
before_count
+
1
assert
"test_model_compat"
in
MODEL_REGISTRY
def
test_legacy_functions
(
self
):
"""Test legacy helper functions."""
from
lm_eval.api.registry
import
(
AGGREGATION_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
get_model
,
register_model
,
)
# register_model should work
@
register_model
(
"legacy_model"
)
class
LegacyModel
(
LM
):
pass
# get_model should work
assert
get_model
(
"legacy_model"
)
==
LegacyModel
# Check other aliases
assert
DEFAULT_METRIC_REGISTRY
is
metric_registry
assert
AGGREGATION_REGISTRY
is
metric_agg_registry
class
TestEdgeCases
:
"""Test edge cases and error conditions."""
def
test_invalid_lazy_format
(
self
):
"""Test error on invalid lazy format."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"bad"
,
lazy
=
"no_colon_here"
)
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
reg
.
get
(
"bad"
)
assert
"expected 'module:object'"
in
str
(
exc_info
.
value
)
def
test_lazy_module_not_found
(
self
):
"""Test error when lazy module doesn't exist."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"missing"
,
lazy
=
"nonexistent_module:Class"
)
with
pytest
.
raises
(
ModuleNotFoundError
):
reg
.
get
(
"missing"
)
def
test_lazy_attribute_not_found
(
self
):
"""Test error when lazy attribute doesn't exist."""
reg
=
Registry
(
"test"
)
reg
.
register
(
"missing_attr"
,
lazy
=
"os:nonexistent_function"
)
with
pytest
.
raises
(
AttributeError
):
reg
.
get
(
"missing_attr"
)
def
test_multiple_aliases_with_lazy
(
self
):
"""Test that multiple aliases with lazy fails."""
reg
=
Registry
(
"test"
)
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
reg
.
register
(
"alias1"
,
"alias2"
,
lazy
=
"os:getcwd"
)
assert
"Exactly one alias required"
in
str
(
exc_info
.
value
)
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
,
"-v"
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment