Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
70314843
Unverified
Commit
70314843
authored
Sep 26, 2025
by
Baber Abbasi
Committed by
GitHub
Sep 26, 2025
Browse files
Merge pull request #3189 from EleutherAI/lazy_reg
refactor registry
parents
73202a2e
930b4253
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1165 additions
and
210 deletions
+1165
-210
lm_eval/__init__.py
lm_eval/__init__.py
+4
-0
lm_eval/api/registry.py
lm_eval/api/registry.py
+532
-170
lm_eval/filters/__init__.py
lm_eval/filters/__init__.py
+12
-3
lm_eval/models/__init__.py
lm_eval/models/__init__.py
+51
-25
lm_eval/models/hf_steered.py
lm_eval/models/hf_steered.py
+2
-1
lm_eval/models/ibm_watsonx_ai.py
lm_eval/models/ibm_watsonx_ai.py
+2
-2
lm_eval/models/vllm_causallms.py
lm_eval/models/vllm_causallms.py
+1
-1
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
+3
-3
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
+3
-3
pyproject.toml
pyproject.toml
+1
-1
scripts/build_benchmark.py
scripts/build_benchmark.py
+1
-1
test_registry.py
test_registry.py
+553
-0
No files found.
lm_eval/__init__.py
View file @
70314843
from
.api
import
metrics
,
model
,
registry
# initializes the registries
from
.filters
import
*
__version__
=
"0.4.9.1"
...
...
lm_eval/api/registry.py
View file @
70314843
"""Registry system for lm_eval components.
This module provides a centralized registration system for models, tasks, metrics,
filters, and other components in the lm_eval framework. The registry supports:
- Lazy loading with placeholders to improve startup time
- Type checking and validation
- Thread-safe registration and lookup
- Plugin discovery via entry points
- Backwards compatibility with legacy registration patterns
## Usage Examples
### Registering a Model
```python
from lm_eval.api.registry import register_model
from lm_eval.api.model import LM
@register_model("my-model")
class MyModel(LM):
def __init__(self, **kwargs):
...
```
### Registering a Metric
```python
from lm_eval.api.registry import register_metric
@register_metric(
metric="my_accuracy",
aggregation="mean",
higher_is_better=True
)
def my_accuracy_fn(items):
...
```
### Registering with Lazy Loading
```python
# Register without importing the actual implementation
model_registry.register("lazy-model", lazy="my_package.models:LazyModel")
```
### Looking up Components
```python
from lm_eval.api.registry import get_model, get_metric
# Get a model class
model_cls = get_model("gpt-j")
model = model_cls(**config)
# Get a metric function
metric_fn = get_metric("accuracy")
```
"""
from
__future__
import
annotations
import
logging
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
import
importlib
import
inspect
import
threading
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
types
import
MappingProxyType
from
typing
import
Any
,
Callable
,
Generic
,
TypeVar
,
Union
,
cast
from
lm_eval.api.filter
import
Filter
try
:
import
importlib.metadata
as
md
# stdlib since Python 3.8
except
ImportError
:
# pragma: no cover – fallback (backport package) for Python < 3.8
import
importlib_metadata
as
md
# type: ignore
LEGACY_EXPORTS
=
[
"DEFAULT_METRIC_REGISTRY"
,
"AGGREGATION_REGISTRY"
,
"register_model"
,
"get_model"
,
"register_task"
,
"get_task"
,
"register_metric"
,
"get_metric"
,
"register_metric_aggregation"
,
"get_metric_aggregation"
,
"register_higher_is_better"
,
"is_higher_better"
,
"register_filter"
,
"get_filter"
,
"register_aggregation"
,
"get_aggregation"
,
"MODEL_REGISTRY"
,
"TASK_REGISTRY"
,
"METRIC_REGISTRY"
,
"METRIC_AGGREGATION_REGISTRY"
,
"HIGHER_IS_BETTER_REGISTRY"
,
"FILTER_REGISTRY"
,
]
__all__
=
[
# canonical
"Registry"
,
"MetricSpec"
,
"model_registry"
,
"task_registry"
,
"metric_registry"
,
"metric_agg_registry"
,
"higher_is_better_registry"
,
"filter_registry"
,
"freeze_all"
,
*
LEGACY_EXPORTS
,
]
# type: ignore
T
=
TypeVar
(
"T"
)
Placeholder
=
Union
[
str
,
md
.
EntryPoint
]
@lru_cache(maxsize=16)
def _materialise_placeholder(ph: Placeholder) -> Any:
    """Materialize a lazy placeholder into the actual object.

    This is at module level to avoid memory leaks from lru_cache on instance methods.

    Args:
        ph: Either a string path "module:object" or an EntryPoint instance.
            The object part of a string path may be dotted
            ("module:Outer.attr"); each component is resolved with getattr.

    Returns:
        The loaded object

    Raises:
        ValueError: If the string format is invalid (missing module or object part)
        ImportError: If the module cannot be imported
        AttributeError: If the object doesn't exist in the module
    """
    if not isinstance(ph, str):
        # Anything that is not a string path is expected to expose ``load()``
        # (an entry point) and resolves itself.
        return ph.load()

    mod, _, attr = ph.partition(":")
    # Reject both "module" (no object) and ":object" (no module) up front so
    # the caller always gets the same, descriptive error.
    if not mod or not attr:
        raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")

    obj = importlib.import_module(mod)
    # Walk dotted object references one attribute at a time.
    for part in attr.split("."):
        obj = getattr(obj, part)
    return obj
# Metric-specific metadata storage --------------------------------------------
_metric_meta
:
dict
[
str
,
dict
[
str
,
Any
]]
=
{}
class
Registry
(
Generic
[
T
]):
"""A thread-safe registry for named objects with lazy loading support.
The Registry provides a central location for registering and retrieving
components by name. It supports:
- Direct registration of objects
- Lazy registration with placeholders (strings or entry points)
- Type checking against a base class
- Thread-safe operations
- Freezing to prevent further modifications
Example:
>>> from lm_eval.api.model import LM
>>> registry = Registry("models", base_cls=LM)
>>>
>>> # Direct registration
>>> @registry.register("my-model")
>>> class MyModel(LM):
... pass
>>>
>>> # Lazy registration
>>> registry.register("lazy-model", lazy="mypackage:LazyModel")
>>>
>>> # Retrieval (triggers lazy loading if needed)
>>> model_cls = registry.get("my-model")
>>> model = model_cls()
"""
def
__init__
(
self
,
name
:
str
,
*
,
base_cls
:
type
[
T
]
|
None
=
None
,
)
->
None
:
"""Initialize a new registry.
Args:
name: Human-readable name for error messages (e.g., "model", "metric")
base_cls: Optional base class that all registered objects must inherit from
"""
self
.
_name
=
name
self
.
_base_cls
=
base_cls
self
.
_objs
:
dict
[
str
,
T
|
Placeholder
]
=
{}
self
.
_lock
=
threading
.
RLock
()
# Registration (decorator or direct call) --------------------------------------
def
register
(
self
,
*
aliases
:
str
,
lazy
:
T
|
Placeholder
|
None
=
None
,
)
->
Callable
[[
T
],
T
]:
"""Register an object under one or more aliases.
Can be used as a decorator or called directly for lazy registration.
Args:
*aliases: Names to register the object under. If empty, uses object's __name__
lazy: For direct calls only - a placeholder string "module:object" or EntryPoint
Returns:
Decorator function (or no-op if lazy registration)
Examples:
>>> # As decorator
>>> @model_registry.register("name1", "name2")
>>> class MyModel(LM):
... pass
>>>
>>> # Direct lazy registration
>>> model_registry.register("lazy-name", lazy="mymodule:MyModel")
Raises:
ValueError: If alias already registered with different target
TypeError: If object doesn't inherit from base_cls (when specified)
"""
def
_store
(
alias
:
str
,
target
:
T
|
Placeholder
)
->
None
:
current
=
self
.
_objs
.
get
(
alias
)
# collision handling ------------------------------------------
if
current
is
not
None
and
current
!=
target
:
# allow placeholder → real object upgrade
if
isinstance
(
current
,
str
)
and
isinstance
(
target
,
type
):
# mod, _, cls = current.partition(":")
if
current
==
f
"
{
target
.
__module__
}
:
{
target
.
__name__
}
"
:
self
.
_objs
[
alias
]
=
target
return
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
alias '
{
alias
}
' already registered ("
f
"existing=
{
current
}
, new=
{
target
}
)"
)
# type check for concrete classes ----------------------------------------------
if
self
.
_base_cls
is
not
None
and
isinstance
(
target
,
type
):
if
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
to be a
{
self
.
_name
}
"
)
self
.
_objs
[
alias
]
=
target
def
decorator
(
obj
:
T
)
->
T
:
# type: ignore[valid-type]
names
=
aliases
or
(
getattr
(
obj
,
"__name__"
,
str
(
obj
)),)
with
self
.
_lock
:
for
name
in
names
:
_store
(
name
,
obj
)
return
obj
# Direct call with *lazy* placeholder
if
lazy
is
not
None
:
if
len
(
aliases
)
!=
1
:
raise
ValueError
(
"Exactly one alias required when using 'lazy='"
)
with
self
.
_lock
:
_store
(
aliases
[
0
],
lazy
)
# type: ignore[arg-type]
# return no‑op decorator for accidental use
return
lambda
x
:
x
# type: ignore[return-value]
return
decorator
# Lookup & materialisation --------------------------------------------------
def _materialise(self, ph: Placeholder) -> T:
    """Resolve a placeholder through the module-level cached loader.

    Args:
        ph: Placeholder to resolve

    Returns:
        Whatever the loader produced, cast to type T for the type checker
    """
    resolved = _materialise_placeholder(ph)
    return cast(T, resolved)
def get(self, alias: str) -> T:
    """Retrieve an object by alias, materializing if needed.

    If the alias points to a placeholder (string path or EntryPoint), it is
    loaded under the registry lock and, unless the registry is frozen,
    written back so later lookups return the concrete object directly.

    Args:
        alias: The registered name to look up

    Returns:
        The registered object

    Raises:
        KeyError: If alias not found
        TypeError: If base_cls is set and the (materialized) object is not a
            subclass of it; this includes objects that are not classes at all
        ImportError/AttributeError: If lazy loading fails
    """
    try:
        target = self._objs[alias]
    except KeyError as exc:
        raise KeyError(
            f"Unknown {self._name} '{alias}'. Available: {', '.join(self._objs)}"
        ) from exc

    if isinstance(target, (str, md.EntryPoint)):
        with self._lock:
            # Re-check under lock (another thread might have resolved it)
            fresh = self._objs[alias]
            if isinstance(fresh, (str, md.EntryPoint)):
                concrete = self._materialise(fresh)
                # Only update if not frozen (MappingProxyType)
                if not isinstance(self._objs, MappingProxyType):
                    self._objs[alias] = concrete
            else:
                concrete = fresh  # another thread did the job
            target = concrete

    # Late type check. Guard with isinstance(..., type) so a non-class target
    # produces this registry's own message instead of the bare
    # "issubclass() arg 1 must be a class" error.
    if self._base_cls is not None and not (
        isinstance(target, type) and issubclass(target, self._base_cls)
    ):
        raise TypeError(
            f"{target} does not inherit from {self._base_cls} (alias '{alias}')"
        )
    return target
if
TYPE_CHECKING
:
from
lm_eval.api.model
import
LM
def __getitem__(self, alias: str) -> T:
    """Support ``registry[alias]`` by delegating to :meth:`get`."""
    found = self.get(alias)
    return found
eval_logger
=
logging
.
getLogger
(
__name__
)
def __iter__(self):
    """Return an iterator over the registered alias names."""
    aliases = self._objs
    return iter(aliases)
MODEL_REGISTRY
=
{}
DEFAULTS
=
{
"model"
:
{
"max_length"
:
2048
},
"tasks"
:
{
"generate_until"
:
{
"max_gen_toks"
:
256
}},
}
def __len__(self):
    """Return how many aliases are currently registered."""
    count = len(self._objs)
    return count
def
items
(
self
):
"""Return (alias, object) pairs.
def
register_model
(
*
names
):
from
lm_eval.api.model
import
LM
Note: Objects may be placeholders that haven't been materialized yet.
"""
return
self
.
_objs
.
items
()
# either pass a list or a single alias.
# function receives them as a tuple of strings
# Utilities -------------------------------------------------------------
def
decorate
(
cls
):
for
name
in
names
:
assert
issubclass
(
cls
,
LM
),
(
f
"Model '
{
name
}
' (
{
cls
.
__name__
}
) must extend LM class"
)
def
origin
(
self
,
alias
:
str
)
->
str
|
None
:
"""Get the source location of a registered object.
assert
name
not
in
MODEL_REGISTRY
,
(
f
"Model named '
{
name
}
' conflicts with existing model! Please register with a non-conflicting alias instead."
)
MODEL_REGISTRY
[
name
]
=
cls
return
cls
Args:
alias: The registered name
return
decorate
Returns:
"path/to/file.py:line_number" or None if not available
"""
obj
=
self
.
_objs
.
get
(
alias
)
if
isinstance
(
obj
,
(
str
,
md
.
EntryPoint
)):
return
None
try
:
path
=
inspect
.
getfile
(
obj
)
# type: ignore[arg-type]
line
=
inspect
.
getsourcelines
(
obj
)[
1
]
# type: ignore[arg-type]
return
f
"
{
path
}
:
{
line
}
"
except
Exception
:
# pragma: no cover – best‑effort only
return
None
def
freeze
(
self
):
"""Make the registry read-only to prevent further modifications.
def
get_model
(
model_name
:
str
)
->
type
[
LM
]:
try
:
return
MODEL_REGISTRY
[
model_name
]
except
KeyError
as
err
:
available_models
=
", "
.
join
(
MODEL_REGISTRY
.
keys
())
raise
KeyError
(
f
"Model '
{
model_name
}
' not found. Available models:
{
available_models
}
"
)
from
err
After freezing, attempts to register new objects will fail.
This is useful for ensuring registry contents don't change after
initialization.
"""
with
self
.
_lock
:
self
.
_objs
=
MappingProxyType
(
dict
(
self
.
_objs
))
# type: ignore[assignment]
# Test helper --------------------------------
def
_clear
(
self
):
# pragma: no cover
"""Erase registry (for isolated tests).
TASK_REGISTRY
=
{}
GROUP_REGISTRY
=
{}
ALL_TASKS
=
set
()
func2task_index
=
{}
Clears both the registry contents and the materialization cache.
Only use this in test code to ensure clean state between tests.
"""
self
.
_objs
.
clear
()
_materialise_placeholder
.
cache_clear
()
def
register_task
(
name
:
str
):
def
decorate
(
fn
):
assert
name
not
in
TASK_REGISTRY
,
(
f
"task named '
{
name
}
' conflicts with existing registered task!"
# Structured object for metrics ------------------
@
dataclass
(
frozen
=
True
)
class
MetricSpec
:
"""Specification for a metric including computation and aggregation functions.
Attributes:
compute: Function to compute metric on individual items
aggregate: Function to aggregate multiple metric values into a single score
higher_is_better: Whether higher values indicate better performance
output_type: Optional type hint for the output (e.g., "generate_until" for perplexity)
requires: Optional list of other metrics this one depends on
"""
compute
:
Callable
[[
Any
,
Any
],
Any
]
aggregate
:
Callable
[[
Iterable
[
Any
]],
float
]
higher_is_better
:
bool
=
True
output_type
:
str
|
None
=
None
requires
:
list
[
str
]
|
None
=
None
# Canonical registries aliases ---------------------
from
lm_eval.api.model
import
LM
# noqa: E402
model_registry
:
Registry
[
type
[
LM
]]
=
cast
(
Registry
[
type
[
LM
]],
Registry
(
"model"
,
base_cls
=
LM
)
)
task_registry
:
Registry
[
Callable
[...,
Any
]]
=
Registry
(
"task"
)
metric_registry
:
Registry
[
MetricSpec
]
=
Registry
(
"metric"
)
metric_agg_registry
:
Registry
[
Callable
[[
Iterable
[
Any
]],
float
]]
=
Registry
(
"metric aggregation"
)
higher_is_better_registry
:
Registry
[
bool
]
=
Registry
(
"higher‑is‑better flag"
)
filter_registry
:
Registry
[
type
[
Filter
]]
=
Registry
(
"filter"
)
# Public helper aliases ------------------------------------------------------
register_model
=
model_registry
.
register
get_model
=
model_registry
.
get
register_task
=
task_registry
.
register
get_task
=
task_registry
.
get
register_filter
=
filter_registry
.
register
get_filter
=
filter_registry
.
get
# Metric helpers need thin wrappers to build MetricSpec ----------------------
def _no_aggregation_fn(values: Iterable[Any]) -> float:
    """Stand-in aggregation that always fails.

    Args:
        values: Metric values to aggregate (ignored)

    Raises:
        NotImplementedError: Always, telling the caller to pass an
            'aggregation' to @register_metric
    """
    message = (
        "No aggregation function specified for this metric. "
        "Please specify 'aggregation' parameter in @register_metric."
    )
    raise NotImplementedError(message)
def
register_metric
(
**
kw
):
"""Decorator for registering metric functions.
Creates a MetricSpec from the decorated function and keyword arguments,
then registers it in the metric registry.
Args:
**kw: Keyword arguments including:
- metric: Name to register the metric under (required)
- aggregation: Name of aggregation function in metric_agg_registry
- higher_is_better: Whether higher scores are better (default: True)
- output_type: Optional output type hint
- requires: Optional list of required metrics
Returns:
Decorator function that registers the metric
Example:
>>> @register_metric(
... metric="my_accuracy",
... aggregation="mean",
... higher_is_better=True
... )
... def compute_accuracy(items):
... return sum(item["correct"] for item in items) / len(items)
"""
name
=
kw
[
"metric"
]
def
deco
(
fn
):
spec
=
MetricSpec
(
compute
=
fn
,
aggregate
=
(
metric_agg_registry
.
get
(
kw
[
"aggregation"
])
if
"aggregation"
in
kw
else
_no_aggregation_fn
),
higher_is_better
=
kw
.
get
(
"higher_is_better"
,
True
),
output_type
=
kw
.
get
(
"output_type"
),
requires
=
kw
.
get
(
"requires"
),
)
TASK_REGISTRY
[
name
]
=
fn
ALL_TASKS
.
add
(
name
)
func2task_index
[
fn
.
__name__
]
=
name
metric_registry
.
register
(
name
,
lazy
=
spec
)
_metric_meta
[
name
]
=
kw
higher_is_better_registry
.
register
(
name
,
lazy
=
spec
.
higher_is_better
)
return
fn
return
deco
rate
return
deco
def
register_group
(
name
):
def
decorate
(
fn
):
func_name
=
func2task_index
[
fn
.
__name__
]
if
name
in
GROUP_REGISTRY
:
GROUP_REGISTRY
[
name
].
append
(
func_name
)
else
:
GROUP_REGISTRY
[
name
]
=
[
func_name
]
ALL_TASKS
.
add
(
name
)
return
fn
def
get_metric
(
name
,
hf_evaluate_metric
=
False
):
"""Get a metric compute function by name.
return
decorate
OUTPUT_TYPE_REGISTRY
=
{}
METRIC_REGISTRY
=
{}
METRIC_AGGREGATION_REGISTRY
=
{}
AGGREGATION_REGISTRY
:
dict
[
str
,
Callable
[[],
dict
[
str
,
Callable
]]]
=
{}
HIGHER_IS_BETTER_REGISTRY
=
{}
FILTER_REGISTRY
=
{}
DEFAULT_METRIC_REGISTRY
=
{
"loglikelihood"
:
[
"perplexity"
,
"acc"
,
],
"loglikelihood_rolling"
:
[
"word_perplexity"
,
"byte_perplexity"
,
"bits_per_byte"
],
"multiple_choice"
:
[
"acc"
,
"acc_norm"
],
"generate_until"
:
[
"exact_match"
],
}
def
register_metric
(
**
args
):
# TODO: do we want to enforce a certain interface to registered metrics?
def
decorate
(
fn
):
assert
"metric"
in
args
name
=
args
[
"metric"
]
for
key
,
registry
in
[
(
"metric"
,
METRIC_REGISTRY
),
(
"higher_is_better"
,
HIGHER_IS_BETTER_REGISTRY
),
(
"aggregation"
,
METRIC_AGGREGATION_REGISTRY
),
]:
if
key
in
args
:
value
=
args
[
key
]
assert
value
not
in
registry
,
(
f
"
{
key
}
named '
{
value
}
' conflicts with existing registered
{
key
}
!"
)
First checks the local metric registry, then optionally falls back
to HuggingFace evaluate library.
if
key
==
"metric"
:
registry
[
name
]
=
fn
elif
key
==
"aggregation"
:
registry
[
name
]
=
AGGREGATION_REGISTRY
[
value
]
else
:
registry
[
name
]
=
value
return
fn
Args:
name: Metric name to retrieve
hf_evaluate_metric: If True, suppress warning when falling back to HF
return
decorate
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Callable
[...,
Any
]
|
None
:
if
not
hf_evaluate_metric
:
if
name
in
METRIC_REGISTRY
:
return
METRIC_REGISTRY
[
name
]
else
:
eval_logger
.
warning
(
f
"Could not find registered metric '
{
name
}
' in lm-eval, searching in HF Evaluate library..."
)
Returns:
The metric's compute function
Raises:
KeyError: If metric not found in registry or HF evaluate
"""
try
:
import
evaluate
as
hf_evaluate
metric_object
=
hf_evaluate
.
load
(
name
)
return
metric_object
.
compute
except
Exception
:
eval_logger
.
error
(
f
"
{
name
}
not found in the evaluate library! Please check https://huggingface.co/evaluate-metric"
,
)
def
register_aggregation
(
name
:
str
):
def
decorate
(
fn
):
assert
name
not
in
AGGREGATION_REGISTRY
,
(
f
"aggregation named '
{
name
}
' conflicts with existing registered aggregation!"
)
AGGREGATION_REGISTRY
[
name
]
=
fn
return
fn
spec
=
metric_registry
.
get
(
name
)
return
spec
.
compute
# type: ignore[attr-defined]
except
KeyError
:
if
not
hf_evaluate_metric
:
import
logging
return
decorate
logging
.
getLogger
(
__name__
).
warning
(
f
"Metric '
{
name
}
' not in registry; trying HF evaluate…"
)
try
:
import
evaluate
as
hf
return
hf
.
load
(
name
).
compute
# type: ignore[attr-defined]
except
Exception
:
raise
KeyError
(
f
"Metric '
{
name
}
' not found anywhere"
)
def
get_aggregation
(
name
:
str
)
->
Callable
[...,
Any
]
|
None
:
try
:
return
AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
register_metric_aggregation
=
metric_agg_registry
.
register
get_metric_aggregation
=
metric_agg_registry
.
get
def
get_metric_aggregation
(
name
:
str
)
->
Callable
[[],
dict
[
str
,
Callable
[...,
Any
]]]:
try
:
return
METRIC_AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!. Using default aggregation mean"
)
return
AGGREGATION_REGISTRY
[
"mean"
]
register_higher_is_better
=
higher_is_better_registry
.
register
is_higher_better
=
higher_is_better_registry
.
get
# Legacy compatibility
register_aggregation
=
metric_agg_registry
.
register
get_aggregation
=
metric_agg_registry
.
get
DEFAULT_METRIC_REGISTRY
=
metric_registry
AGGREGATION_REGISTRY
=
metric_agg_registry
def
is_higher_better
(
metric_name
:
str
)
->
bool
:
try
:
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
eval_logger
.
warning
(
f
"higher_is_better not specified for metric '
{
metric_name
}
'!. Will default to True."
)
return
True
def
freeze_all
():
"""Freeze all registries to prevent further modifications.
def
register_filter
(
name
:
str
):
def
decorate
(
cls
):
if
name
in
FILTER_REGISTRY
:
eval_logger
.
info
(
f
"Registering filter `
{
name
}
` that is already in Registry
{
FILTER_REGISTRY
}
"
)
FILTER_REGISTRY
[
name
]
=
cls
return
cls
This is useful for ensuring registry contents are immutable after
initialization, preventing accidental modifications during runtime.
"""
for
r
in
(
model_registry
,
task_registry
,
metric_registry
,
metric_agg_registry
,
higher_is_better_registry
,
filter_registry
,
):
r
.
freeze
()
return
decorate
# Backwards‑compat aliases ----------------------------------------
def
get_filter
(
filter_name
:
str
|
Callable
)
->
Callable
:
try
:
return
FILTER_REGISTRY
[
filter_name
]
except
KeyError
as
e
:
if
callable
(
filter_name
):
return
filter_name
else
:
eval_logger
.
warning
(
f
"filter `
{
filter_name
}
` is not registered!"
)
raise
e
MODEL_REGISTRY
=
model_registry
TASK_REGISTRY
=
task_registry
METRIC_REGISTRY
=
metric_registry
METRIC_AGGREGATION_REGISTRY
=
metric_agg_registry
HIGHER_IS_BETTER_REGISTRY
=
higher_is_better_registry
FILTER_REGISTRY
=
filter_registry
lm_eval/filters/__init__.py
View file @
70314843
from
__future__
import
annotations
from
functools
import
partial
from
typing
import
Optional
,
Union
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.registry
import
get_filter
from
lm_eval.api.registry
import
filter_registry
,
get_filter
from
.
import
custom
,
extraction
,
selection
,
transformation
def
build_filter_ensemble
(
filter_name
:
str
,
components
:
list
[
tuple
[
str
,
Optional
[
dict
[
str
,
Union
[
str
,
int
,
float
]
]]
]],
components
:
list
[
tuple
[
str
,
dict
[
str
,
str
|
int
|
float
]
|
None
]],
)
->
FilterEnsemble
:
"""
Create a filtering pipeline.
...
...
@@ -21,3 +21,12 @@ def build_filter_ensemble(
partial
(
get_filter
(
func
),
**
(
kwargs
or
{}))
for
func
,
kwargs
in
components
],
)
__all__
=
[
"custom"
,
"extraction"
,
"selection"
,
"transformation"
,
"build_filter_ensemble"
,
]
lm_eval/models/__init__.py
View file @
70314843
from
.
import
(
anthropic_llms
,
api_models
,
dummy
,
gguf
,
hf_audiolm
,
hf_steered
,
hf_vlms
,
huggingface
,
ibm_watsonx_ai
,
mamba_lm
,
nemo_lm
,
neuron_optimum
,
openai_completions
,
optimum_ipex
,
optimum_lm
,
sglang_causallms
,
sglang_generate_API
,
textsynth
,
vllm_causallms
,
vllm_vlms
,
)
# TODO: implement __all__
# Models are now lazily loaded via the registry system
# No need to import them all at once
# Define model mappings for lazy registration
MODEL_MAPPING
=
{
"anthropic-completions"
:
"lm_eval.models.anthropic_llms:AnthropicLM"
,
"anthropic-chat"
:
"lm_eval.models.anthropic_llms:AnthropicChatLM"
,
"anthropic-chat-completions"
:
"lm_eval.models.anthropic_llms:AnthropicCompletionsLM"
,
"local-completions"
:
"lm_eval.models.openai_completions:LocalCompletionsAPI"
,
"local-chat-completions"
:
"lm_eval.models.openai_completions:LocalChatCompletion"
,
"openai-completions"
:
"lm_eval.models.openai_completions:OpenAICompletionsAPI"
,
"openai-chat-completions"
:
"lm_eval.models.openai_completions:OpenAIChatCompletion"
,
"dummy"
:
"lm_eval.models.dummy:DummyLM"
,
"gguf"
:
"lm_eval.models.gguf:GGUFLM"
,
"ggml"
:
"lm_eval.models.gguf:GGUFLM"
,
"hf-audiolm-qwen"
:
"lm_eval.models.hf_audiolm:HFAudioLM"
,
"steered"
:
"lm_eval.models.hf_steered:SteeredHF"
,
"hf-multimodal"
:
"lm_eval.models.hf_vlms:HFMultimodalLM"
,
"hf-auto"
:
"lm_eval.models.huggingface:HFLM"
,
"hf"
:
"lm_eval.models.huggingface:HFLM"
,
"huggingface"
:
"lm_eval.models.huggingface:HFLM"
,
"watsonx_llm"
:
"lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI"
,
"mamba_ssm"
:
"lm_eval.models.mamba_lm:MambaLMWrapper"
,
"nemo_lm"
:
"lm_eval.models.nemo_lm:NeMoLM"
,
"neuronx"
:
"lm_eval.models.neuron_optimum:NeuronModelForCausalLM"
,
"ipex"
:
"lm_eval.models.optimum_ipex:IPEXForCausalLM"
,
"openvino"
:
"lm_eval.models.optimum_lm:OptimumLM"
,
"sglang"
:
"lm_eval.models.sglang_causallms:SGLANG"
,
"sglang-generate"
:
"lm_eval.models.sglang_generate_API:SGAPI"
,
"textsynth"
:
"lm_eval.models.textsynth:TextSynthLM"
,
"vllm"
:
"lm_eval.models.vllm_causallms:VLLM"
,
"vllm-vlm"
:
"lm_eval.models.vllm_vlms:VLLM_VLM"
,
}
# Register all models lazily
def _register_all_models():
    """Add a lazy placeholder to the model registry for every MODEL_MAPPING entry."""
    from lm_eval.api.registry import model_registry

    for alias, lazy_path in MODEL_MAPPING.items():
        # Leave aliases that are already present untouched (avoids conflicts
        # when modules are imported).
        if alias in model_registry:
            continue
        model_registry.register(alias, lazy=lazy_path)
# Call registration on module import
_register_all_models
()
__all__
=
[
"MODEL_MAPPING"
]
try
:
...
...
lm_eval/models/hf_steered.py
View file @
70314843
from
collections.abc
import
Generator
from
contextlib
import
contextmanager
from
functools
import
partial
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Generator
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
from
peft.peft_model
import
PeftModel
...
...
lm_eval/models/ibm_watsonx_ai.py
View file @
70314843
...
...
@@ -3,7 +3,7 @@ import json
import
logging
import
os
import
warnings
from
functools
import
lru_
cache
from
functools
import
cache
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
cast
from
tqdm
import
tqdm
...
...
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise
ValueError
(
error_msg
)
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
get_watsonx_credentials
()
->
Dict
[
str
,
str
]:
"""
Retrieves Watsonx API credentials from environmental variables.
...
...
lm_eval/models/vllm_causallms.py
View file @
70314843
...
...
@@ -42,7 +42,7 @@ try:
if
parse_version
(
version
(
"vllm"
))
>=
parse_version
(
"0.8.3"
):
from
vllm.entrypoints.chat_utils
import
resolve_hf_chat_template
except
ModuleNotFoundError
:
p
ass
p
rint
(
"njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd"
)
if
TYPE_CHECKING
:
pass
...
...
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
View file @
70314843
...
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
...
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
View file @
70314843
...
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
...
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
pyproject.toml
View file @
70314843
...
...
@@ -121,7 +121,7 @@ lint.fixable = ["I001", "F401", "UP"]
lint.ignore
=
[
"E402"
,
"E731"
,
"E501"
,
"E111"
,
"E114"
,
"E117"
,
"E741"
]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py"
=
[
"F401"
,
"F402"
,
"F403"
]
"__init__.py"
=
[
"F401"
,
"F402"
,
"F403"
,
"F405"
]
[tool.ruff.lint.isort]
combine-as-imports
=
true
...
...
scripts/build_benchmark.py
View file @
70314843
...
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from
tqdm
import
tqdm
# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registry
v2
import ALL_TASKS
eval_logger
=
logging
.
getLogger
(
__name__
)
...
...
test_registry.py
0 → 100644
View file @
70314843
#!/usr/bin/env python3
"""Comprehensive tests for the registry system."""
import
threading
import
pytest
from
lm_eval.api.model
import
LM
from
lm_eval.api.registry
import
(
MetricSpec
,
Registry
,
get_metric
,
metric_agg_registry
,
metric_registry
,
model_registry
,
register_metric
,
)
# Import metrics module to ensure decorators are executed
# import lm_eval.api.metrics
class TestBasicRegistry:
    """Exercise the core behaviour of a plain, untyped ``Registry``."""

    def test_create_registry(self):
        """A freshly built registry holds no aliases."""
        empty = Registry("test")

        assert list(empty) == []
        assert len(empty) == 0

    def test_decorator_registration(self):
        """A decorated class is reachable via ``in``, ``get`` and indexing."""
        registry = Registry("test")

        @registry.register("my_class")
        class MyClass:
            pass

        assert "my_class" in registry
        assert registry.get("my_class") == MyClass
        assert registry["my_class"] == MyClass

    def test_decorator_multiple_aliases(self):
        """Every alias passed to the decorator resolves to the same class."""
        registry = Registry("test")
        names = ("alias1", "alias2", "alias3")

        @registry.register(*names)
        class MyClass:
            pass

        for name in names:
            assert registry.get(name) == MyClass

    def test_decorator_auto_name(self):
        """Without aliases the class name becomes the key."""
        registry = Registry("test")

        @registry.register()
        class AutoNamedClass:
            pass

        assert registry.get("AutoNamedClass") == AutoNamedClass

    def test_lazy_registration(self):
        """A "module:object" path stays a string until first lookup."""
        import os

        registry = Registry("test")
        registry.register("join", lazy="os.path:join")

        # Stored as the raw placeholder string
        stored = registry._objs["join"]
        assert isinstance(stored, str)

        # Lookup triggers materialization
        loaded = registry.get("join")
        assert loaded == os.path.join
        assert callable(loaded)

    def test_direct_registration(self):
        """An already-built object can be registered through ``lazy=``."""
        registry = Registry("test")

        class DirectClass:
            pass

        instance = DirectClass()
        registry.register("direct", lazy=instance)

        assert registry.get("direct") == instance

    def test_metadata_removed(self):
        """Registration works with no metadata argument."""
        registry = Registry("test")

        @registry.register("test_class")
        class TestClass:
            pass

        assert "test_class" in registry
        assert registry.get("test_class") == TestClass

    def test_unknown_key_error(self):
        """Looking up a missing alias raises a descriptive KeyError."""
        registry = Registry("test")

        with pytest.raises(KeyError) as exc_info:
            registry.get("unknown")

        message = str(exc_info.value)
        assert "Unknown test 'unknown'" in message
        assert "Available:" in message

    def test_iteration(self):
        """Iteration, ``len`` and ``items`` reflect insertion order."""
        registry = Registry("test")
        lazy_paths = {"a": "os:getcwd", "b": "os:getenv", "c": "os:getpid"}
        for alias, path in lazy_paths.items():
            registry.register(alias, lazy=path)

        assert list(registry) == ["a", "b", "c"]
        assert len(registry) == 3

        # items() exposes the stored values without materializing them
        pairs = list(registry.items())
        assert len(pairs) == 3
        first_alias, first_value = pairs[0]
        assert first_alias == "a"
        assert isinstance(first_value, str)  # Still lazy

    def test_mapping_protocol(self):
        """Indexing and membership behave like a mapping."""
        registry = Registry("test")
        registry.register("test", lazy="os:getcwd")

        # __getitem__
        assert registry["test"] == registry.get("test")

        # __contains__
        assert "test" in registry
        assert "missing" not in registry

        # __iter__ and __len__ are covered by test_iteration
class TestTypeConstraints:
    """Checks for the ``base_cls`` constraint accepted by ``Registry``."""

    def test_base_class_constraint(self):
        """A class outside the ``base_cls`` hierarchy is rejected with ``TypeError``."""

        class BaseClass:
            pass

        reg = Registry("typed", base_cls=BaseClass)

        # A subclass of the configured base must register without error.
        @reg.register("good")
        class GoodInline(BaseClass):
            pass

        # Verify the accepted registration actually landed in the registry.
        assert "good" in reg

        # A class that does not derive from BaseClass must be refused.
        with pytest.raises(TypeError) as exc_info:

            @reg.register("bad")
            class BadInline:
                pass

        assert "must inherit from" in str(exc_info.value)

    def test_lazy_type_check(self):
        """For lazy entries the ``TypeError`` surfaces on lookup, not on register."""

        class BaseClass:
            pass

        reg = Registry("typed", base_cls=BaseClass)

        # Registering the string placeholder itself does not raise here.
        reg.register("bad_lazy", lazy="os.path:join")

        # Only the error type is checked; the message is not asserted.
        with pytest.raises(TypeError):
            reg.get("bad_lazy")
class TestCollisionHandling:
    """Behaviour when one alias is registered more than once."""

    def test_identical_registration(self):
        """Registering the very same object twice under one alias is accepted."""
        registry = Registry("test")

        class MyClass:
            pass

        # Two back-to-back registrations of the same target.
        for _ in range(2):
            registry.register("test", lazy=MyClass)

        assert registry.get("test") == MyClass

    def test_different_registration_fails(self):
        """Binding an existing alias to a different object raises ``ValueError``."""
        registry = Registry("test")

        class Class1:
            pass

        class Class2:
            pass

        registry.register("test", lazy=Class1)

        with pytest.raises(ValueError) as raised:
            registry.register("test", lazy=Class2)

        assert "already registered" in str(raised.value)

    def test_lazy_to_concrete_upgrade(self):
        """A lazy placeholder may be replaced by the concrete class it names."""
        registry = Registry("test")

        # Start with the string placeholder for the class defined below.
        registry.register("myclass", lazy="test_registry:MyUpgradeClass")

        # Registering the concrete class under the same alias must succeed.
        @registry.register("myclass")
        class MyUpgradeClass:
            pass

        assert registry.get("myclass") == MyUpgradeClass
class TestThreadSafety:
    """Exercise the registry from several threads at once."""

    def test_concurrent_access(self):
        """Ten threads resolving one lazy entry all succeed with equal results."""
        registry = Registry("test")
        registry.register("concurrent", lazy="os.path:join")

        fetched = []
        failures = []

        def access_item():
            # Record either the resolved object or the error text.
            try:
                fetched.append(registry.get("concurrent"))
            except Exception as exc:
                failures.append(str(exc))

        workers = []
        for _ in range(10):
            worker = threading.Thread(target=access_item)
            workers.append(worker)
            worker.start()

        for worker in workers:
            worker.join()

        assert len(failures) == 0
        assert len(fetched) == 10
        # Every thread must have received an equal result.
        reference = fetched[0]
        assert all(item == reference for item in fetched)

    def test_concurrent_registration(self):
        """Ten threads registering distinct aliases all succeed."""
        registry = Registry("test")
        failures = []

        def register_item(name, value):
            try:
                registry.register(name, lazy=value)
            except Exception as exc:
                failures.append(str(exc))

        workers = []
        for i in range(10):
            # Each thread gets its own alias and its own placeholder string.
            call_args = (f"item_{i}", f"module{i}:Class{i}")
            worker = threading.Thread(target=register_item, args=call_args)
            workers.append(worker)
            worker.start()

        for worker in workers:
            worker.join()

        assert len(failures) == 0
        assert len(registry) == 10
class TestMetricRegistry:
    """Tests for ``MetricSpec`` and the metric registration helpers."""

    def test_metric_spec(self):
        """``MetricSpec`` stores the values it was constructed with."""

        def compute_fn(items):
            return [1 for _ in items]

        def agg_fn(values):
            return sum(values) / len(values)

        metric_spec = MetricSpec(
            compute=compute_fn,
            aggregate=agg_fn,
            higher_is_better=True,
            output_type="probability",
        )

        assert metric_spec.compute == compute_fn
        assert metric_spec.aggregate == agg_fn
        assert metric_spec.higher_is_better
        assert metric_spec.output_type == "probability"

    def test_register_metric_decorator(self):
        """``@register_metric`` stores a ``MetricSpec`` wired to the named aggregation."""

        # The aggregation has to be registered before the metric refers to it.
        @metric_agg_registry.register("test_mean")
        def test_mean(values):
            return sum(values) / len(values) if values else 0.0

        @register_metric(
            metric="test_accuracy",
            aggregation="test_mean",
            higher_is_better=True,
            output_type="accuracy",
        )
        def compute_accuracy(items):
            return [1 if item["pred"] == item["gold"] else 0 for item in items]

        # The metric is now visible in the global metric registry.
        assert "test_accuracy" in metric_registry
        metric_spec = metric_registry.get("test_accuracy")
        assert isinstance(metric_spec, MetricSpec)
        assert metric_spec.higher_is_better
        assert metric_spec.output_type == "accuracy"

        # Per-item scoring: two matches followed by one mismatch.
        samples = [
            {"pred": "a", "gold": "a"},
            {"pred": "b", "gold": "b"},
            {"pred": "c", "gold": "d"},
        ]
        scores = metric_spec.compute(samples)
        assert scores == [1, 1, 0]

        # Aggregating the scores gives their mean.
        aggregated = metric_spec.aggregate(scores)
        assert aggregated == 2 / 3

    def test_metric_without_aggregation(self):
        """Calling ``aggregate`` on a metric registered without one raises."""

        @register_metric(metric="no_agg", higher_is_better=False)
        def compute_something(items):
            return [len(item) for item in items]

        metric_spec = metric_registry.get("no_agg")

        with pytest.raises(NotImplementedError) as raised:
            metric_spec.aggregate([1, 2, 3])

        assert "No aggregation function specified" in str(raised.value)

    def test_get_metric_helper(self):
        """``get_metric`` hands back a callable that computes the metric."""

        # Relies on an aggregation named 'mean' already being registered
        # in metric_agg_registry.
        @register_metric(
            metric="helper_test",
            aggregation="mean",
        )
        def compute_helper(items):
            return items

        helper = get_metric("helper_test")
        assert callable(helper)
        assert helper([1, 2, 3]) == [1, 2, 3]
class TestRegistryUtilities:
    """Tests for ``freeze``, ``_clear`` and ``origin``."""

    def test_freeze(self):
        """After ``freeze`` the backing mapping rejects writes but lookups work."""
        registry = Registry("test")
        for alias, target in (("item1", "os:getcwd"), ("item2", "os:getenv")):
            registry.register(alias, lazy=target)

        registry.freeze()

        # Direct item assignment on the underlying mapping must now fail.
        with pytest.raises(TypeError):
            registry._objs["new"] = "value"

        # Existing entries remain reachable.
        assert "item1" in registry
        assert callable(registry.get("item1"))

    def test_clear(self):
        """``_clear`` removes every registered entry."""
        registry = Registry("test")
        for alias, target in (("item1", "os:getcwd"), ("item2", "os:getenv")):
            registry.register(alias, lazy=target)
        assert len(registry) == 2

        registry._clear()

        assert len(registry) == 0
        assert list(registry) == []

    def test_origin(self):
        """``origin`` is ``None`` for lazy entries and a location for classes."""
        registry = Registry("test")

        # A string placeholder has no recorded origin.
        registry.register("lazy", lazy="os:getcwd")
        assert registry.origin("lazy") is None

        # A decorated class reports where it was defined.
        @registry.register("concrete")
        class ConcreteClass:
            pass

        location = registry.origin("concrete")
        assert location is not None
        assert "test_registry.py" in location
        # The colon separates the file from the line number.
        assert ":" in location
class TestBackwardCompatibility:
    """Legacy names exported by ``lm_eval.api.registry`` keep working."""

    def test_model_registry_alias(self):
        """``MODEL_REGISTRY`` is the same object as ``model_registry``."""
        from lm_eval.api.registry import MODEL_REGISTRY

        assert MODEL_REGISTRY is model_registry

        # Remember the size so the effect of one registration can be measured.
        size_before = len(MODEL_REGISTRY)

        @model_registry.register("test_model_compat")
        class TestModelCompat(LM):
            pass

        # The legacy name sees the new entry immediately.
        assert len(MODEL_REGISTRY) == size_before + 1
        assert "test_model_compat" in MODEL_REGISTRY

    def test_legacy_functions(self):
        """``register_model`` / ``get_model`` and the legacy registry names work."""
        from lm_eval.api.registry import (
            AGGREGATION_REGISTRY,
            DEFAULT_METRIC_REGISTRY,
            get_model,
            register_model,
        )

        @register_model("legacy_model")
        class LegacyModel(LM):
            pass

        # A model registered through the legacy decorator is retrievable.
        assert get_model("legacy_model") == LegacyModel

        # The remaining legacy names point at the current registries.
        assert DEFAULT_METRIC_REGISTRY is metric_registry
        assert AGGREGATION_REGISTRY is metric_agg_registry
class TestEdgeCases:
    """Error paths for malformed or unresolvable registrations."""

    def test_invalid_lazy_format(self):
        """A lazy string without a colon fails with ``ValueError`` on lookup."""
        registry = Registry("test")
        registry.register("bad", lazy="no_colon_here")

        with pytest.raises(ValueError) as raised:
            registry.get("bad")

        assert "expected 'module:object'" in str(raised.value)

    def test_lazy_module_not_found(self):
        """A lazy path naming a missing module raises ``ModuleNotFoundError``."""
        registry = Registry("test")
        registry.register("missing", lazy="nonexistent_module:Class")

        with pytest.raises(ModuleNotFoundError):
            registry.get("missing")

    def test_lazy_attribute_not_found(self):
        """A lazy path naming a missing attribute raises ``AttributeError``."""
        registry = Registry("test")
        registry.register("missing_attr", lazy="os:nonexistent_function")

        with pytest.raises(AttributeError):
            registry.get("missing_attr")

    def test_multiple_aliases_with_lazy(self):
        """Passing several aliases together with ``lazy=`` is rejected."""
        registry = Registry("test")

        with pytest.raises(ValueError) as raised:
            registry.register("alias1", "alias2", lazy="os:getcwd")

        assert "Exactly one alias required" in str(raised.value)
if __name__ == "__main__":
    # pytest.main() returns the session's exit code. Pass it on so that running
    # this file directly reports test failures to the calling shell or CI job
    # instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment