Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
70314843
Unverified
Commit
70314843
authored
Sep 26, 2025
by
Baber Abbasi
Committed by
GitHub
Sep 26, 2025
Browse files
Merge pull request #3189 from EleutherAI/lazy_reg
refactor registry
parents
73202a2e
930b4253
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1165 additions
and
210 deletions
+1165
-210
lm_eval/__init__.py
lm_eval/__init__.py
+4
-0
lm_eval/api/registry.py
lm_eval/api/registry.py
+532
-170
lm_eval/filters/__init__.py
lm_eval/filters/__init__.py
+12
-3
lm_eval/models/__init__.py
lm_eval/models/__init__.py
+51
-25
lm_eval/models/hf_steered.py
lm_eval/models/hf_steered.py
+2
-1
lm_eval/models/ibm_watsonx_ai.py
lm_eval/models/ibm_watsonx_ai.py
+2
-2
lm_eval/models/vllm_causallms.py
lm_eval/models/vllm_causallms.py
+1
-1
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
+3
-3
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
+3
-3
pyproject.toml
pyproject.toml
+1
-1
scripts/build_benchmark.py
scripts/build_benchmark.py
+1
-1
test_registry.py
test_registry.py
+553
-0
No files found.
lm_eval/__init__.py
View file @
70314843
from
.api
import
metrics
,
model
,
registry
# initializes the registries
from
.filters
import
*
__version__
=
"0.4.9.1"
__version__
=
"0.4.9.1"
...
...
lm_eval/api/registry.py
View file @
70314843
"""Registry system for lm_eval components.
This module provides a centralized registration system for models, tasks, metrics,
filters, and other components in the lm_eval framework. The registry supports:
- Lazy loading with placeholders to improve startup time
- Type checking and validation
- Thread-safe registration and lookup
- Plugin discovery via entry points
- Backwards compatibility with legacy registration patterns
## Usage Examples
### Registering a Model
```python
from lm_eval.api.registry import register_model
from lm_eval.api.model import LM
@register_model("my-model")
class MyModel(LM):
def __init__(self, **kwargs):
...
```
### Registering a Metric
```python
from lm_eval.api.registry import register_metric
@register_metric(
metric="my_accuracy",
aggregation="mean",
higher_is_better=True
)
def my_accuracy_fn(items):
...
```
### Registering with Lazy Loading
```python
# Register without importing the actual implementation
model_registry.register("lazy-model", lazy="my_package.models:LazyModel")
```
### Looking up Components
```python
from lm_eval.api.registry import get_model, get_metric
# Get a model class
model_cls = get_model("gpt-j")
model = model_cls(**config)
# Get a metric function
metric_fn = get_metric("accuracy")
```
"""
from
__future__
import
annotations
from
__future__
import
annotations
import
logging
import
importlib
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
import
inspect
import
threading
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
types
import
MappingProxyType
from
typing
import
Any
,
Callable
,
Generic
,
TypeVar
,
Union
,
cast
from
lm_eval.api.filter
import
Filter
try
:
import
importlib.metadata
as
md
# Python ≥3.10
except
ImportError
:
# pragma: no cover – fallback for 3.8/3.9
import
importlib_metadata
as
md
# type: ignore
LEGACY_EXPORTS
=
[
"DEFAULT_METRIC_REGISTRY"
,
"AGGREGATION_REGISTRY"
,
"register_model"
,
"get_model"
,
"register_task"
,
"get_task"
,
"register_metric"
,
"get_metric"
,
"register_metric_aggregation"
,
"get_metric_aggregation"
,
"register_higher_is_better"
,
"is_higher_better"
,
"register_filter"
,
"get_filter"
,
"register_aggregation"
,
"get_aggregation"
,
"MODEL_REGISTRY"
,
"TASK_REGISTRY"
,
"METRIC_REGISTRY"
,
"METRIC_AGGREGATION_REGISTRY"
,
"HIGHER_IS_BETTER_REGISTRY"
,
"FILTER_REGISTRY"
,
]
__all__
=
[
# canonical
"Registry"
,
"MetricSpec"
,
"model_registry"
,
"task_registry"
,
"metric_registry"
,
"metric_agg_registry"
,
"higher_is_better_registry"
,
"filter_registry"
,
"freeze_all"
,
*
LEGACY_EXPORTS
,
]
# type: ignore
T
=
TypeVar
(
"T"
)
Placeholder
=
Union
[
str
,
md
.
EntryPoint
]
@
lru_cache
(
maxsize
=
16
)
def
_materialise_placeholder
(
ph
:
Placeholder
)
->
Any
:
"""Materialize a lazy placeholder into the actual object.
This is at module level to avoid memory leaks from lru_cache on instance methods.
Args:
ph: Either a string path "module:object" or an EntryPoint instance
Returns:
The loaded object
Raises:
ValueError: If the string format is invalid
ImportError: If the module cannot be imported
AttributeError: If the object doesn't exist in the module
"""
if
isinstance
(
ph
,
str
):
mod
,
_
,
attr
=
ph
.
partition
(
":"
)
if
not
attr
:
raise
ValueError
(
f
"Invalid lazy path '
{
ph
}
', expected 'module:object'"
)
return
getattr
(
importlib
.
import_module
(
mod
),
attr
)
return
ph
.
load
()
# Metric-specific metadata storage --------------------------------------------
_metric_meta
:
dict
[
str
,
dict
[
str
,
Any
]]
=
{}
class
Registry
(
Generic
[
T
]):
"""A thread-safe registry for named objects with lazy loading support.
The Registry provides a central location for registering and retrieving
components by name. It supports:
- Direct registration of objects
- Lazy registration with placeholders (strings or entry points)
- Type checking against a base class
- Thread-safe operations
- Freezing to prevent further modifications
Example:
>>> from lm_eval.api.model import LM
>>> registry = Registry("models", base_cls=LM)
>>>
>>> # Direct registration
>>> @registry.register("my-model")
>>> class MyModel(LM):
... pass
>>>
>>> # Lazy registration
>>> registry.register("lazy-model", lazy="mypackage:LazyModel")
>>>
>>> # Retrieval (triggers lazy loading if needed)
>>> model_cls = registry.get("my-model")
>>> model = model_cls()
"""
def
__init__
(
self
,
name
:
str
,
*
,
base_cls
:
type
[
T
]
|
None
=
None
,
)
->
None
:
"""Initialize a new registry.
Args:
name: Human-readable name for error messages (e.g., "model", "metric")
base_cls: Optional base class that all registered objects must inherit from
"""
self
.
_name
=
name
self
.
_base_cls
=
base_cls
self
.
_objs
:
dict
[
str
,
T
|
Placeholder
]
=
{}
self
.
_lock
=
threading
.
RLock
()
# Registration (decorator or direct call) --------------------------------------
def
register
(
self
,
*
aliases
:
str
,
lazy
:
T
|
Placeholder
|
None
=
None
,
)
->
Callable
[[
T
],
T
]:
"""Register an object under one or more aliases.
Can be used as a decorator or called directly for lazy registration.
Args:
*aliases: Names to register the object under. If empty, uses object's __name__
lazy: For direct calls only - a placeholder string "module:object" or EntryPoint
Returns:
Decorator function (or no-op if lazy registration)
Examples:
>>> # As decorator
>>> @model_registry.register("name1", "name2")
>>> class MyModel(LM):
... pass
>>>
>>> # Direct lazy registration
>>> model_registry.register("lazy-name", lazy="mymodule:MyModel")
Raises:
ValueError: If alias already registered with different target
TypeError: If object doesn't inherit from base_cls (when specified)
"""
def
_store
(
alias
:
str
,
target
:
T
|
Placeholder
)
->
None
:
current
=
self
.
_objs
.
get
(
alias
)
# collision handling ------------------------------------------
if
current
is
not
None
and
current
!=
target
:
# allow placeholder → real object upgrade
if
isinstance
(
current
,
str
)
and
isinstance
(
target
,
type
):
# mod, _, cls = current.partition(":")
if
current
==
f
"
{
target
.
__module__
}
:
{
target
.
__name__
}
"
:
self
.
_objs
[
alias
]
=
target
return
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
alias '
{
alias
}
' already registered ("
f
"existing=
{
current
}
, new=
{
target
}
)"
)
# type check for concrete classes ----------------------------------------------
if
self
.
_base_cls
is
not
None
and
isinstance
(
target
,
type
):
if
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
to be a
{
self
.
_name
}
"
)
self
.
_objs
[
alias
]
=
target
def
decorator
(
obj
:
T
)
->
T
:
# type: ignore[valid-type]
names
=
aliases
or
(
getattr
(
obj
,
"__name__"
,
str
(
obj
)),)
with
self
.
_lock
:
for
name
in
names
:
_store
(
name
,
obj
)
return
obj
# Direct call with *lazy* placeholder
if
lazy
is
not
None
:
if
len
(
aliases
)
!=
1
:
raise
ValueError
(
"Exactly one alias required when using 'lazy='"
)
with
self
.
_lock
:
_store
(
aliases
[
0
],
lazy
)
# type: ignore[arg-type]
# return no‑op decorator for accidental use
return
lambda
x
:
x
# type: ignore[return-value]
return
decorator
# Lookup & materialisation --------------------------------------------------
def
_materialise
(
self
,
ph
:
Placeholder
)
->
T
:
"""Materialize a placeholder using the module-level cached function.
Args:
ph: Placeholder to materialize
Returns:
The materialized object, cast to type T
"""
return
cast
(
T
,
_materialise_placeholder
(
ph
))
def
get
(
self
,
alias
:
str
)
->
T
:
"""Retrieve an object by alias, materializing if needed.
Thread-safe lazy loading: if the alias points to a placeholder,
it will be loaded and cached before returning.
Args:
alias: The registered name to look up
Returns:
The registered object
Raises:
KeyError: If alias not found
TypeError: If materialized object doesn't match base_cls
ImportError/AttributeError: If lazy loading fails
"""
try
:
target
=
self
.
_objs
[
alias
]
except
KeyError
as
exc
:
raise
KeyError
(
f
"Unknown
{
self
.
_name
}
'
{
alias
}
'. Available:
{
', '
.
join
(
self
.
_objs
)
}
"
)
from
exc
if
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
with
self
.
_lock
:
# Re‑check under lock (another thread might have resolved it)
fresh
=
self
.
_objs
[
alias
]
if
isinstance
(
fresh
,
(
str
,
md
.
EntryPoint
)):
concrete
=
self
.
_materialise
(
fresh
)
# Only update if not frozen (MappingProxyType)
if
not
isinstance
(
self
.
_objs
,
MappingProxyType
):
self
.
_objs
[
alias
]
=
concrete
else
:
concrete
=
fresh
# another thread did the job
target
=
concrete
# Late type/validator checks
if
self
.
_base_cls
is
not
None
and
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
does not inherit from
{
self
.
_base_cls
}
(alias '
{
alias
}
')"
)
return
target
if
TYPE_CHECKING
:
def
__getitem__
(
self
,
alias
:
str
)
->
T
:
from
lm_eval.api.model
import
LM
"""Allow dict-style access: registry[alias]."""
return
self
.
get
(
alias
)
eval_logger
=
logging
.
getLogger
(
__name__
)
def
__iter__
(
self
):
"""Iterate over registered aliases."""
return
iter
(
self
.
_objs
)
MODEL_REGISTRY
=
{}
def
__len__
(
self
):
DEFAULTS
=
{
"""Return number of registered aliases."""
"model"
:
{
"max_length"
:
2048
},
return
len
(
self
.
_objs
)
"tasks"
:
{
"generate_until"
:
{
"max_gen_toks"
:
256
}},
}
def
items
(
self
):
"""Return (alias, object) pairs.
def
register_model
(
*
names
):
Note: Objects may be placeholders that haven't been materialized yet.
from
lm_eval.api.model
import
LM
"""
return
self
.
_objs
.
items
()
# either pass a list or a single alias.
# Utilities -------------------------------------------------------------
# function receives them as a tuple of strings
def
decorate
(
cls
):
def
origin
(
self
,
alias
:
str
)
->
str
|
None
:
for
name
in
names
:
"""Get the source location of a registered object.
assert
issubclass
(
cls
,
LM
),
(
f
"Model '
{
name
}
' (
{
cls
.
__name__
}
) must extend LM class"
)
assert
name
not
in
MODEL_REGISTRY
,
(
Args:
f
"Model named '
{
name
}
' conflicts with existing model! Please register with a non-conflicting alias instead."
alias: The registered name
)
MODEL_REGISTRY
[
name
]
=
cls
return
cls
return
decorate
Returns:
"path/to/file.py:line_number" or None if not available
"""
obj
=
self
.
_objs
.
get
(
alias
)
if
isinstance
(
obj
,
(
str
,
md
.
EntryPoint
)):
return
None
try
:
path
=
inspect
.
getfile
(
obj
)
# type: ignore[arg-type]
line
=
inspect
.
getsourcelines
(
obj
)[
1
]
# type: ignore[arg-type]
return
f
"
{
path
}
:
{
line
}
"
except
Exception
:
# pragma: no cover – best‑effort only
return
None
def
freeze
(
self
):
"""Make the registry read-only to prevent further modifications.
def
get_model
(
model_name
:
str
)
->
type
[
LM
]:
After freezing, attempts to register new objects will fail.
try
:
This is useful for ensuring registry contents don't change after
return
MODEL_REGISTRY
[
model_name
]
initialization.
except
KeyError
as
err
:
"""
available_models
=
", "
.
join
(
MODEL_REGISTRY
.
keys
())
with
self
.
_lock
:
raise
KeyError
(
self
.
_objs
=
MappingProxyType
(
dict
(
self
.
_objs
))
# type: ignore[assignment]
f
"Model '
{
model_name
}
' not found. Available models:
{
available_models
}
"
)
from
err
# Test helper --------------------------------
def
_clear
(
self
):
# pragma: no cover
"""Erase registry (for isolated tests).
TASK_REGISTRY
=
{}
Clears both the registry contents and the materialization cache.
GROUP_REGISTRY
=
{}
Only use this in test code to ensure clean state between tests.
ALL_TASKS
=
set
()
"""
func2task_index
=
{}
self
.
_objs
.
clear
()
_materialise_placeholder
.
cache_clear
()
def
register_task
(
name
:
str
):
# Structured object for metrics ------------------
def
decorate
(
fn
):
assert
name
not
in
TASK_REGISTRY
,
(
f
"task named '
{
name
}
' conflicts with existing registered task!"
@
dataclass
(
frozen
=
True
)
class
MetricSpec
:
"""Specification for a metric including computation and aggregation functions.
Attributes:
compute: Function to compute metric on individual items
aggregate: Function to aggregate multiple metric values into a single score
higher_is_better: Whether higher values indicate better performance
output_type: Optional type hint for the output (e.g., "generate_until" for perplexity)
requires: Optional list of other metrics this one depends on
"""
compute
:
Callable
[[
Any
,
Any
],
Any
]
aggregate
:
Callable
[[
Iterable
[
Any
]],
float
]
higher_is_better
:
bool
=
True
output_type
:
str
|
None
=
None
requires
:
list
[
str
]
|
None
=
None
# Canonical registries aliases ---------------------
from
lm_eval.api.model
import
LM
# noqa: E402
model_registry
:
Registry
[
type
[
LM
]]
=
cast
(
Registry
[
type
[
LM
]],
Registry
(
"model"
,
base_cls
=
LM
)
)
task_registry
:
Registry
[
Callable
[...,
Any
]]
=
Registry
(
"task"
)
metric_registry
:
Registry
[
MetricSpec
]
=
Registry
(
"metric"
)
metric_agg_registry
:
Registry
[
Callable
[[
Iterable
[
Any
]],
float
]]
=
Registry
(
"metric aggregation"
)
higher_is_better_registry
:
Registry
[
bool
]
=
Registry
(
"higher‑is‑better flag"
)
filter_registry
:
Registry
[
type
[
Filter
]]
=
Registry
(
"filter"
)
# Public helper aliases ------------------------------------------------------
register_model
=
model_registry
.
register
get_model
=
model_registry
.
get
register_task
=
task_registry
.
register
get_task
=
task_registry
.
get
register_filter
=
filter_registry
.
register
get_filter
=
filter_registry
.
get
# Metric helpers need thin wrappers to build MetricSpec ----------------------
def
_no_aggregation_fn
(
values
:
Iterable
[
Any
])
->
float
:
"""Default aggregation that raises NotImplementedError.
Args:
values: Metric values to aggregate (unused)
Raises:
NotImplementedError: Always - this is a placeholder for metrics
that haven't specified an aggregation function
"""
raise
NotImplementedError
(
"No aggregation function specified for this metric. "
"Please specify 'aggregation' parameter in @register_metric."
)
def
register_metric
(
**
kw
):
"""Decorator for registering metric functions.
Creates a MetricSpec from the decorated function and keyword arguments,
then registers it in the metric registry.
Args:
**kw: Keyword arguments including:
- metric: Name to register the metric under (required)
- aggregation: Name of aggregation function in metric_agg_registry
- higher_is_better: Whether higher scores are better (default: True)
- output_type: Optional output type hint
- requires: Optional list of required metrics
Returns:
Decorator function that registers the metric
Example:
>>> @register_metric(
... metric="my_accuracy",
... aggregation="mean",
... higher_is_better=True
... )
... def compute_accuracy(items):
... return sum(item["correct"] for item in items) / len(items)
"""
name
=
kw
[
"metric"
]
def
deco
(
fn
):
spec
=
MetricSpec
(
compute
=
fn
,
aggregate
=
(
metric_agg_registry
.
get
(
kw
[
"aggregation"
])
if
"aggregation"
in
kw
else
_no_aggregation_fn
),
higher_is_better
=
kw
.
get
(
"higher_is_better"
,
True
),
output_type
=
kw
.
get
(
"output_type"
),
requires
=
kw
.
get
(
"requires"
),
)
)
metric_registry
.
register
(
name
,
lazy
=
spec
)
TASK_REGISTRY
[
name
]
=
fn
_metric_meta
[
name
]
=
kw
ALL_TASKS
.
add
(
name
)
higher_is_better_registry
.
register
(
name
,
lazy
=
spec
.
higher_is_better
)
func2task_index
[
fn
.
__name__
]
=
name
return
fn
return
fn
return
deco
rate
return
deco
def
register_group
(
name
):
def
get_metric
(
name
,
hf_evaluate_metric
=
False
):
def
decorate
(
fn
):
"""Get a metric compute function by name.
func_name
=
func2task_index
[
fn
.
__name__
]
if
name
in
GROUP_REGISTRY
:
GROUP_REGISTRY
[
name
].
append
(
func_name
)
else
:
GROUP_REGISTRY
[
name
]
=
[
func_name
]
ALL_TASKS
.
add
(
name
)
return
fn
return
decorate
First checks the local metric registry, then optionally falls back
to HuggingFace evaluate library.
OUTPUT_TYPE_REGISTRY
=
{}
METRIC_REGISTRY
=
{}
METRIC_AGGREGATION_REGISTRY
=
{}
AGGREGATION_REGISTRY
:
dict
[
str
,
Callable
[[],
dict
[
str
,
Callable
]]]
=
{}
HIGHER_IS_BETTER_REGISTRY
=
{}
FILTER_REGISTRY
=
{}
DEFAULT_METRIC_REGISTRY
=
{
"loglikelihood"
:
[
"perplexity"
,
"acc"
,
],
"loglikelihood_rolling"
:
[
"word_perplexity"
,
"byte_perplexity"
,
"bits_per_byte"
],
"multiple_choice"
:
[
"acc"
,
"acc_norm"
],
"generate_until"
:
[
"exact_match"
],
}
def
register_metric
(
**
args
):
# TODO: do we want to enforce a certain interface to registered metrics?
def
decorate
(
fn
):
assert
"metric"
in
args
name
=
args
[
"metric"
]
for
key
,
registry
in
[
(
"metric"
,
METRIC_REGISTRY
),
(
"higher_is_better"
,
HIGHER_IS_BETTER_REGISTRY
),
(
"aggregation"
,
METRIC_AGGREGATION_REGISTRY
),
]:
if
key
in
args
:
value
=
args
[
key
]
assert
value
not
in
registry
,
(
f
"
{
key
}
named '
{
value
}
' conflicts with existing registered
{
key
}
!"
)
if
key
==
"metric"
:
Args:
registry
[
name
]
=
fn
name: Metric name to retrieve
elif
key
==
"aggregation"
:
hf_evaluate_metric: If True, suppress warning when falling back to HF
registry
[
name
]
=
AGGREGATION_REGISTRY
[
value
]
else
:
registry
[
name
]
=
value
return
fn
return
decorate
Returns:
The metric's compute function
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Callable
[...,
Any
]
|
None
:
if
not
hf_evaluate_metric
:
if
name
in
METRIC_REGISTRY
:
return
METRIC_REGISTRY
[
name
]
else
:
eval_logger
.
warning
(
f
"Could not find registered metric '
{
name
}
' in lm-eval, searching in HF Evaluate library..."
)
Raises:
KeyError: If metric not found in registry or HF evaluate
"""
try
:
try
:
import
evaluate
as
hf_evaluate
spec
=
metric_registry
.
get
(
name
)
return
spec
.
compute
# type: ignore[attr-defined]
metric_object
=
hf_evaluate
.
load
(
name
)
except
KeyError
:
return
metric_object
.
compute
if
not
hf_evaluate_metric
:
except
Exception
:
import
logging
eval_logger
.
error
(
f
"
{
name
}
not found in the evaluate library! Please check https://huggingface.co/evaluate-metric"
,
)
def
register_aggregation
(
name
:
str
):
def
decorate
(
fn
):
assert
name
not
in
AGGREGATION_REGISTRY
,
(
f
"aggregation named '
{
name
}
' conflicts with existing registered aggregation!"
)
AGGREGATION_REGISTRY
[
name
]
=
fn
return
fn
return
decorate
logging
.
getLogger
(
__name__
).
warning
(
f
"Metric '
{
name
}
' not in registry; trying HF evaluate…"
)
try
:
import
evaluate
as
hf
return
hf
.
load
(
name
).
compute
# type: ignore[attr-defined]
except
Exception
:
raise
KeyError
(
f
"Metric '
{
name
}
' not found anywhere"
)
def
get_aggregation
(
name
:
str
)
->
Callable
[...,
Any
]
|
None
:
try
:
return
AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
register_metric_aggregation
=
metric_agg_registry
.
register
get_metric_aggregation
=
metric_agg_registry
.
get
def
get_metric_aggregation
(
name
:
str
)
->
Callable
[[],
dict
[
str
,
Callable
[...,
Any
]]]:
register_higher_is_better
=
higher_is_better_registry
.
register
try
:
is_higher_better
=
higher_is_better_registry
.
get
return
METRIC_AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!. Using default aggregation mean"
)
return
AGGREGATION_REGISTRY
[
"mean"
]
# Legacy compatibility
register_aggregation
=
metric_agg_registry
.
register
get_aggregation
=
metric_agg_registry
.
get
DEFAULT_METRIC_REGISTRY
=
metric_registry
AGGREGATION_REGISTRY
=
metric_agg_registry
def
is_higher_better
(
metric_name
:
str
)
->
bool
:
try
:
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
eval_logger
.
warning
(
f
"higher_is_better not specified for metric '
{
metric_name
}
'!. Will default to True."
)
return
True
def
freeze_all
():
"""Freeze all registries to prevent further modifications.
def
register_filter
(
name
:
str
):
This is useful for ensuring registry contents are immutable after
def
decorate
(
cls
):
initialization, preventing accidental modifications during runtime.
if
name
in
FILTER_REGISTRY
:
"""
eval_logger
.
info
(
for
r
in
(
f
"Registering filter `
{
name
}
` that is already in Registry
{
FILTER_REGISTRY
}
"
model_registry
,
)
task_registry
,
FILTER_REGISTRY
[
name
]
=
cls
metric_registry
,
return
cls
metric_agg_registry
,
higher_is_better_registry
,
filter_registry
,
):
r
.
freeze
()
return
decorate
# Backwards‑compat aliases ----------------------------------------
def
get_filter
(
filter_name
:
str
|
Callable
)
->
Callable
:
MODEL_REGISTRY
=
model_registry
try
:
TASK_REGISTRY
=
task_registry
return
FILTER_REGISTRY
[
filter_name
]
METRIC_REGISTRY
=
metric_registry
except
KeyError
as
e
:
METRIC_AGGREGATION_REGISTRY
=
metric_agg_registry
if
callable
(
filter_name
):
HIGHER_IS_BETTER_REGISTRY
=
higher_is_better_registry
return
filter_name
FILTER_REGISTRY
=
filter_registry
else
:
eval_logger
.
warning
(
f
"filter `
{
filter_name
}
` is not registered!"
)
raise
e
lm_eval/filters/__init__.py
View file @
70314843
from
__future__
import
annotations
from
functools
import
partial
from
functools
import
partial
from
typing
import
Optional
,
Union
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.registry
import
get_filter
from
lm_eval.api.registry
import
filter_registry
,
get_filter
from
.
import
custom
,
extraction
,
selection
,
transformation
from
.
import
custom
,
extraction
,
selection
,
transformation
def
build_filter_ensemble
(
def
build_filter_ensemble
(
filter_name
:
str
,
filter_name
:
str
,
components
:
list
[
tuple
[
str
,
Optional
[
dict
[
str
,
Union
[
str
,
int
,
float
]
]]
]],
components
:
list
[
tuple
[
str
,
dict
[
str
,
str
|
int
|
float
]
|
None
]],
)
->
FilterEnsemble
:
)
->
FilterEnsemble
:
"""
"""
Create a filtering pipeline.
Create a filtering pipeline.
...
@@ -21,3 +21,12 @@ def build_filter_ensemble(
...
@@ -21,3 +21,12 @@ def build_filter_ensemble(
partial
(
get_filter
(
func
),
**
(
kwargs
or
{}))
for
func
,
kwargs
in
components
partial
(
get_filter
(
func
),
**
(
kwargs
or
{}))
for
func
,
kwargs
in
components
],
],
)
)
__all__
=
[
"custom"
,
"extraction"
,
"selection"
,
"transformation"
,
"build_filter_ensemble"
,
]
lm_eval/models/__init__.py
View file @
70314843
from
.
import
(
# Models are now lazily loaded via the registry system
anthropic_llms
,
# No need to import them all at once
api_models
,
dummy
,
# Define model mappings for lazy registration
gguf
,
MODEL_MAPPING
=
{
hf_audiolm
,
"anthropic-completions"
:
"lm_eval.models.anthropic_llms:AnthropicLM"
,
hf_steered
,
"anthropic-chat"
:
"lm_eval.models.anthropic_llms:AnthropicChatLM"
,
hf_vlms
,
"anthropic-chat-completions"
:
"lm_eval.models.anthropic_llms:AnthropicCompletionsLM"
,
huggingface
,
"local-completions"
:
"lm_eval.models.openai_completions:LocalCompletionsAPI"
,
ibm_watsonx_ai
,
"local-chat-completions"
:
"lm_eval.models.openai_completions:LocalChatCompletion"
,
mamba_lm
,
"openai-completions"
:
"lm_eval.models.openai_completions:OpenAICompletionsAPI"
,
nemo_lm
,
"openai-chat-completions"
:
"lm_eval.models.openai_completions:OpenAIChatCompletion"
,
neuron_optimum
,
"dummy"
:
"lm_eval.models.dummy:DummyLM"
,
openai_completions
,
"gguf"
:
"lm_eval.models.gguf:GGUFLM"
,
optimum_ipex
,
"ggml"
:
"lm_eval.models.gguf:GGUFLM"
,
optimum_lm
,
"hf-audiolm-qwen"
:
"lm_eval.models.hf_audiolm:HFAudioLM"
,
sglang_causallms
,
"steered"
:
"lm_eval.models.hf_steered:SteeredHF"
,
sglang_generate_API
,
"hf-multimodal"
:
"lm_eval.models.hf_vlms:HFMultimodalLM"
,
textsynth
,
"hf-auto"
:
"lm_eval.models.huggingface:HFLM"
,
vllm_causallms
,
"hf"
:
"lm_eval.models.huggingface:HFLM"
,
vllm_vlms
,
"huggingface"
:
"lm_eval.models.huggingface:HFLM"
,
)
"watsonx_llm"
:
"lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI"
,
"mamba_ssm"
:
"lm_eval.models.mamba_lm:MambaLMWrapper"
,
"nemo_lm"
:
"lm_eval.models.nemo_lm:NeMoLM"
,
# TODO: implement __all__
"neuronx"
:
"lm_eval.models.neuron_optimum:NeuronModelForCausalLM"
,
"ipex"
:
"lm_eval.models.optimum_ipex:IPEXForCausalLM"
,
"openvino"
:
"lm_eval.models.optimum_lm:OptimumLM"
,
"sglang"
:
"lm_eval.models.sglang_causallms:SGLANG"
,
"sglang-generate"
:
"lm_eval.models.sglang_generate_API:SGAPI"
,
"textsynth"
:
"lm_eval.models.textsynth:TextSynthLM"
,
"vllm"
:
"lm_eval.models.vllm_causallms:VLLM"
,
"vllm-vlm"
:
"lm_eval.models.vllm_vlms:VLLM_VLM"
,
}
# Register all models lazily
def
_register_all_models
():
"""Register all known models lazily in the registry."""
from
lm_eval.api.registry
import
model_registry
for
name
,
path
in
MODEL_MAPPING
.
items
():
# Only register if not already present (avoids conflicts when modules are imported)
if
name
not
in
model_registry
:
# Register the lazy placeholder using lazy parameter
model_registry
.
register
(
name
,
lazy
=
path
)
# Call registration on module import
_register_all_models
()
__all__
=
[
"MODEL_MAPPING"
]
try
:
try
:
...
...
lm_eval/models/hf_steered.py
View file @
70314843
from
collections.abc
import
Generator
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
functools
import
partial
from
functools
import
partial
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Generator
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
from
peft.peft_model
import
PeftModel
from
peft.peft_model
import
PeftModel
...
...
lm_eval/models/ibm_watsonx_ai.py
View file @
70314843
...
@@ -3,7 +3,7 @@ import json
...
@@ -3,7 +3,7 @@ import json
import
logging
import
logging
import
os
import
os
import
warnings
import
warnings
from
functools
import
lru_
cache
from
functools
import
cache
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
cast
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
cast
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
...
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise
ValueError
(
error_msg
)
raise
ValueError
(
error_msg
)
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
get_watsonx_credentials
()
->
Dict
[
str
,
str
]:
def
get_watsonx_credentials
()
->
Dict
[
str
,
str
]:
"""
"""
Retrieves Watsonx API credentials from environmental variables.
Retrieves Watsonx API credentials from environmental variables.
...
...
lm_eval/models/vllm_causallms.py
View file @
70314843
...
@@ -42,7 +42,7 @@ try:
...
@@ -42,7 +42,7 @@ try:
if
parse_version
(
version
(
"vllm"
))
>=
parse_version
(
"0.8.3"
):
if
parse_version
(
version
(
"vllm"
))
>=
parse_version
(
"0.8.3"
):
from
vllm.entrypoints.chat_utils
import
resolve_hf_chat_template
from
vllm.entrypoints.chat_utils
import
resolve_hf_chat_template
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
p
ass
p
rint
(
"njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd"
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
pass
pass
...
...
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
View file @
70314843
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
with
open
(
GRAMMAR_FILE
)
as
f
:
...
@@ -556,8 +556,8 @@ class STRIPS:
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
View file @
70314843
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
with
open
(
GRAMMAR_FILE
)
as
f
:
...
@@ -556,8 +556,8 @@ class STRIPS:
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
pyproject.toml
View file @
70314843
...
@@ -121,7 +121,7 @@ lint.fixable = ["I001", "F401", "UP"]
...
@@ -121,7 +121,7 @@ lint.fixable = ["I001", "F401", "UP"]
lint.ignore
=
[
"E402"
,
"E731"
,
"E501"
,
"E111"
,
"E114"
,
"E117"
,
"E741"
]
lint.ignore
=
[
"E402"
,
"E731"
,
"E501"
,
"E111"
,
"E114"
,
"E117"
,
"E741"
]
[tool.ruff.lint.extend-per-file-ignores]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py"
=
[
"F401"
,
"F402"
,
"F403"
]
"__init__.py"
=
[
"F401"
,
"F402"
,
"F403"
,
"F405"
]
[tool.ruff.lint.isort]
[tool.ruff.lint.isort]
combine-as-imports
=
true
combine-as-imports
=
true
...
...
scripts/build_benchmark.py
View file @
70314843
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from
tqdm
import
tqdm
from
tqdm
import
tqdm
# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registry
v2
import ALL_TASKS
eval_logger
=
logging
.
getLogger
(
__name__
)
eval_logger
=
logging
.
getLogger
(
__name__
)
...
...
test_registry.py
0 → 100644
View file @
70314843
#!/usr/bin/env python3
"""Comprehensive tests for the registry system."""
import
threading
import
pytest
from
lm_eval.api.model
import
LM
from
lm_eval.api.registry
import
(
MetricSpec
,
Registry
,
get_metric
,
metric_agg_registry
,
metric_registry
,
model_registry
,
register_metric
,
)
# Import metrics module to ensure decorators are executed
# import lm_eval.api.metrics
class TestBasicRegistry:
    """Exercise core Registry behaviour: registration forms, lazy loading,
    lookup errors, iteration, and the mapping protocol."""

    def test_create_registry(self):
        """A freshly created registry is empty."""
        registry = Registry("test")
        assert len(registry) == 0
        assert list(registry) == []

    def test_decorator_registration(self):
        """Classes register via the decorator form and resolve by alias."""
        registry = Registry("test")

        @registry.register("my_class")
        class MyClass:
            pass

        assert "my_class" in registry
        assert registry.get("my_class") == MyClass
        assert registry["my_class"] == MyClass

    def test_decorator_multiple_aliases(self):
        """One decorator call may bind several aliases to the same class."""
        registry = Registry("test")

        @registry.register("alias1", "alias2", "alias3")
        class MyClass:
            pass

        for alias in ("alias1", "alias2", "alias3"):
            assert registry.get(alias) == MyClass

    def test_decorator_auto_name(self):
        """With no explicit alias, the class's own name is used as the key."""
        registry = Registry("test")

        @registry.register()
        class AutoNamedClass:
            pass

        assert registry.get("AutoNamedClass") == AutoNamedClass

    def test_lazy_registration(self):
        """A 'module:object' string is materialised only on first access."""
        import os

        registry = Registry("test")
        registry.register("join", lazy="os.path:join")

        # Before any lookup the entry is still the placeholder string.
        assert isinstance(registry._objs["join"], str)

        # The first lookup resolves the placeholder to the real object.
        resolved = registry.get("join")
        assert resolved == os.path.join
        assert callable(resolved)

    def test_direct_registration(self):
        """Arbitrary objects may be registered directly via ``lazy=``."""
        registry = Registry("test")

        class DirectClass:
            pass

        instance = DirectClass()
        registry.register("direct", lazy=instance)
        assert registry.get("direct") == instance

    def test_metadata_removed(self):
        """The generic registry works without any metadata parameter."""
        registry = Registry("test")

        @registry.register("test_class")
        class TestClass:
            pass

        assert "test_class" in registry
        assert registry.get("test_class") == TestClass

    def test_unknown_key_error(self):
        """Looking up a missing key raises KeyError listing alternatives."""
        registry = Registry("test")
        with pytest.raises(KeyError) as excinfo:
            registry.get("unknown")
        message = str(excinfo.value)
        assert "Unknown test 'unknown'" in message
        assert "Available:" in message

    def test_iteration(self):
        """Iteration yields keys in insertion order without materialising."""
        registry = Registry("test")
        registry.register("a", lazy="os:getcwd")
        registry.register("b", lazy="os:getenv")
        registry.register("c", lazy="os:getpid")

        assert list(registry) == ["a", "b", "c"]
        assert len(registry) == 3

        entries = list(registry.items())
        assert len(entries) == 3
        assert entries[0][0] == "a"
        assert isinstance(entries[0][1], str)  # still a lazy placeholder

    def test_mapping_protocol(self):
        """The registry supports __getitem__ and __contains__."""
        registry = Registry("test")
        registry.register("test", lazy="os:getcwd")

        # __getitem__ delegates to get()
        assert registry["test"] == registry.get("test")

        # __contains__
        assert "test" in registry
        assert "missing" not in registry
class TestTypeConstraints:
    """Verify base-class validation on typed registries."""

    def test_base_class_constraint(self):
        """Eager registrations must subclass the registry's base class."""

        class BaseClass:
            pass

        class GoodSubclass(BaseClass):
            pass

        class BadClass:
            pass

        registry = Registry("typed", base_cls=BaseClass)

        # A proper subclass registers cleanly.
        @registry.register("good")
        class GoodInline(BaseClass):
            pass

        # An unrelated class is rejected at registration time.
        with pytest.raises(TypeError) as excinfo:

            @registry.register("bad")
            class BadInline:
                pass

        assert "must inherit from" in str(excinfo.value)

    def test_lazy_type_check(self):
        """For lazy entries, type checking is deferred to materialisation."""

        class BaseClass:
            pass

        registry = Registry("typed", base_cls=BaseClass)

        # Registering a lazy entry that violates the constraint succeeds...
        registry.register("bad_lazy", lazy="os.path:join")

        # ...but resolving it fails (the exact message varies).
        with pytest.raises(TypeError):
            registry.get("bad_lazy")
class TestCollisionHandling:
    """Behaviour when the same alias is registered more than once."""

    def test_identical_registration(self):
        """Re-registering the exact same object is a no-op, not an error."""
        registry = Registry("test")

        class MyClass:
            pass

        registry.register("test", lazy=MyClass)
        registry.register("test", lazy=MyClass)  # identical: allowed
        assert registry.get("test") == MyClass

    def test_different_registration_fails(self):
        """Re-registering a *different* object under the same alias fails."""
        registry = Registry("test")

        class Class1:
            pass

        class Class2:
            pass

        registry.register("test", lazy=Class1)
        with pytest.raises(ValueError) as excinfo:
            registry.register("test", lazy=Class2)
        assert "already registered" in str(excinfo.value)

    def test_lazy_to_concrete_upgrade(self):
        """A lazy placeholder may be replaced by the concrete object."""
        registry = Registry("test")
        registry.register("myclass", lazy="test_registry:MyUpgradeClass")

        # Registering the concrete class under the same alias upgrades it.
        @registry.register("myclass")
        class MyUpgradeClass:
            pass

        assert registry.get("myclass") == MyUpgradeClass
class TestThreadSafety:
    """Concurrency smoke tests for lookup and registration."""

    def test_concurrent_access(self):
        """Many threads resolving the same lazy entry all see one object."""
        registry = Registry("test")
        registry.register("concurrent", lazy="os.path:join")

        resolved = []
        failures = []

        def worker():
            try:
                resolved.append(registry.get("concurrent"))
            except Exception as exc:
                failures.append(str(exc))

        workers = []
        for _ in range(10):
            thread = threading.Thread(target=worker)
            workers.append(thread)
            thread.start()

        for thread in workers:
            thread.join()

        assert len(failures) == 0
        assert len(resolved) == 10
        # Every thread must observe the same materialised object.
        assert all(obj == resolved[0] for obj in resolved)

    def test_concurrent_registration(self):
        """Distinct registrations from many threads all land safely."""
        registry = Registry("test")
        failures = []

        def worker(alias, placeholder):
            try:
                registry.register(alias, lazy=placeholder)
            except Exception as exc:
                failures.append(str(exc))

        workers = []
        for i in range(10):
            thread = threading.Thread(
                target=worker, args=(f"item_{i}", f"module{i}:Class{i}")
            )
            workers.append(thread)
            thread.start()

        for thread in workers:
            thread.join()

        assert len(failures) == 0
        assert len(registry) == 10
class TestMetricRegistry:
    """Tests for the metric-specific registry helpers."""

    def test_metric_spec(self):
        """MetricSpec stores its four fields verbatim."""

        def compute_fn(items):
            return [1 for _ in items]

        def agg_fn(values):
            return sum(values) / len(values)

        spec = MetricSpec(
            compute=compute_fn,
            aggregate=agg_fn,
            higher_is_better=True,
            output_type="probability",
        )

        assert spec.compute == compute_fn
        assert spec.aggregate == agg_fn
        assert spec.higher_is_better
        assert spec.output_type == "probability"

    def test_register_metric_decorator(self):
        """@register_metric wires compute + aggregation into a MetricSpec."""

        # The aggregation must exist before the metric references it.
        @metric_agg_registry.register("test_mean")
        def test_mean(values):
            return sum(values) / len(values) if values else 0.0

        @register_metric(
            metric="test_accuracy",
            aggregation="test_mean",
            higher_is_better=True,
            output_type="accuracy",
        )
        def compute_accuracy(items):
            return [1 if item["pred"] == item["gold"] else 0 for item in items]

        assert "test_accuracy" in metric_registry
        spec = metric_registry.get("test_accuracy")
        assert isinstance(spec, MetricSpec)
        assert spec.higher_is_better
        assert spec.output_type == "accuracy"

        samples = [
            {"pred": "a", "gold": "a"},
            {"pred": "b", "gold": "b"},
            {"pred": "c", "gold": "d"},
        ]
        per_item = spec.compute(samples)
        assert per_item == [1, 1, 0]
        assert spec.aggregate(per_item) == 2 / 3

    def test_metric_without_aggregation(self):
        """Aggregating a metric registered without one raises."""

        @register_metric(metric="no_agg", higher_is_better=False)
        def compute_something(items):
            return [len(item) for item in items]

        spec = metric_registry.get("no_agg")
        with pytest.raises(NotImplementedError) as excinfo:
            spec.aggregate([1, 2, 3])
        assert "No aggregation function specified" in str(excinfo.value)

    def test_get_metric_helper(self):
        """get_metric returns just the bare compute callable."""

        @register_metric(
            metric="helper_test",
            aggregation="mean",  # assumes 'mean' exists in metric_agg_registry
        )
        def compute_helper(items):
            return items

        compute_fn = get_metric("helper_test")
        assert callable(compute_fn)
        assert compute_fn([1, 2, 3]) == [1, 2, 3]
class TestRegistryUtilities:
    """Tests for freeze(), _clear(), and origin tracking."""

    def test_freeze(self):
        """A frozen registry rejects writes but still serves reads."""
        registry = Registry("test")
        registry.register("item1", lazy="os:getcwd")
        registry.register("item2", lazy="os:getenv")

        registry.freeze()

        # Writes to the underlying mapping are now rejected...
        with pytest.raises(TypeError):
            registry._objs["new"] = "value"

        # ...while lookups keep working.
        assert "item1" in registry
        assert callable(registry.get("item1"))

    def test_clear(self):
        """_clear() empties the registry completely."""
        registry = Registry("test")
        registry.register("item1", lazy="os:getcwd")
        registry.register("item2", lazy="os:getenv")
        assert len(registry) == 2

        registry._clear()
        assert len(registry) == 0
        assert list(registry) == []

    def test_origin(self):
        """origin() reports where eager entries were registered."""
        registry = Registry("test")

        # Lazy placeholders carry no origin information.
        registry.register("lazy", lazy="os:getcwd")
        assert registry.origin("lazy") is None

        # Eagerly registered classes record their file and line.
        @registry.register("concrete")
        class ConcreteClass:
            pass

        where = registry.origin("concrete")
        assert where is not None
        assert "test_registry.py" in where
        assert ":" in where  # includes a line number
class TestBackwardCompatibility:
    """Legacy aliases and helpers must track the live registries."""

    def test_model_registry_alias(self):
        """MODEL_REGISTRY is the live model_registry, not a snapshot."""
        from lm_eval.api.registry import MODEL_REGISTRY

        assert MODEL_REGISTRY is model_registry

        count_before = len(MODEL_REGISTRY)

        @model_registry.register("test_model_compat")
        class TestModelCompat(LM):
            pass

        # The alias reflects the new registration immediately.
        assert len(MODEL_REGISTRY) == count_before + 1
        assert "test_model_compat" in MODEL_REGISTRY

    def test_legacy_functions(self):
        """register_model/get_model and the legacy registry names work."""
        from lm_eval.api.registry import (
            AGGREGATION_REGISTRY,
            DEFAULT_METRIC_REGISTRY,
            get_model,
            register_model,
        )

        @register_model("legacy_model")
        class LegacyModel(LM):
            pass

        assert get_model("legacy_model") == LegacyModel
        assert DEFAULT_METRIC_REGISTRY is metric_registry
        assert AGGREGATION_REGISTRY is metric_agg_registry
class TestEdgeCases:
    """Error conditions around malformed lazy entries."""

    def test_invalid_lazy_format(self):
        """A lazy string without ':' fails on resolution."""
        registry = Registry("test")
        registry.register("bad", lazy="no_colon_here")
        with pytest.raises(ValueError) as excinfo:
            registry.get("bad")
        assert "expected 'module:object'" in str(excinfo.value)

    def test_lazy_module_not_found(self):
        """Resolving a lazy entry with a missing module raises."""
        registry = Registry("test")
        registry.register("missing", lazy="nonexistent_module:Class")
        with pytest.raises(ModuleNotFoundError):
            registry.get("missing")

    def test_lazy_attribute_not_found(self):
        """Resolving a lazy entry with a missing attribute raises."""
        registry = Registry("test")
        registry.register("missing_attr", lazy="os:nonexistent_function")
        with pytest.raises(AttributeError):
            registry.get("missing_attr")

    def test_multiple_aliases_with_lazy(self):
        """Lazy registration accepts exactly one alias."""
        registry = Registry("test")
        with pytest.raises(ValueError) as excinfo:
            registry.register("alias1", "alias2", lazy="os:getcwd")
        assert "Exactly one alias required" in str(excinfo.value)
if __name__ == "__main__":
    # Propagate pytest's exit status: without this the script always
    # exits 0, hiding test failures from callers / CI.
    raise SystemExit(pytest.main([__file__, "-v"]))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment