Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
93b2ab37
Commit
93b2ab37
authored
Jul 27, 2025
by
Baber
Browse files
refactor registry
parent
de496b80
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
536 additions
and
168 deletions
+536
-168
lm_eval/api/metrics.py
lm_eval/api/metrics.py
+2
-2
lm_eval/api/registry.py
lm_eval/api/registry.py
+469
-125
lm_eval/api/task.py
lm_eval/api/task.py
+2
-5
lm_eval/models/__init__.py
lm_eval/models/__init__.py
+51
-25
lm_eval/models/hf_steered.py
lm_eval/models/hf_steered.py
+2
-1
lm_eval/models/ibm_watsonx_ai.py
lm_eval/models/ibm_watsonx_ai.py
+2
-2
lm_eval/models/vllm_causallms.py
lm_eval/models/vllm_causallms.py
+1
-1
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
+3
-3
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
+3
-3
scripts/build_benchmark.py
scripts/build_benchmark.py
+1
-1
No files found.
lm_eval/api/metrics.py
View file @
93b2ab37
...
@@ -4,8 +4,8 @@ import os
...
@@ -4,8 +4,8 @@ import os
import
random
import
random
import
re
import
re
import
string
import
string
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
,
Sequence
from
typing
import
Callable
,
List
,
Optional
,
Sequence
,
TypeVar
from
typing
import
Callable
,
List
,
Optional
,
TypeVar
import
numpy
as
np
import
numpy
as
np
import
sacrebleu
import
sacrebleu
...
...
lm_eval/api/registry.py
View file @
93b2ab37
import
logging
from
__future__
import
annotations
from
typing
import
Callable
,
Dict
,
Union
import
importlib
import
evaluate
as
hf_evaluate
import
inspect
import
threading
from
lm_eval.api.model
import
LM
from
collections.abc
import
Iterable
,
Mapping
,
MutableMapping
from
dataclasses
import
dataclass
from
functools
import
lru_cache
eval_logger
=
logging
.
getLogger
(
__name__
)
from
types
import
MappingProxyType
from
typing
import
(
MODEL_REGISTRY
=
{}
Any
,
Callable
,
Generic
,
def
register_model
(
*
names
):
TypeVar
,
# either pass a list or a single alias.
)
# function receives them as a tuple of strings
def
decorate
(
cls
):
try
:
# Python≥3.10
for
name
in
names
:
import
importlib.metadata
as
md
assert
issubclass
(
cls
,
LM
),
(
except
ImportError
:
# pragma: no cover - fallback for 3.8/3.9 runtimes
f
"Model '
{
name
}
' (
{
cls
.
__name__
}
) must extend LM class"
import
importlib_metadata
as
md
# type: ignore
__all__
=
[
"Registry"
,
"MetricSpec"
,
# concrete registries
"model_registry"
,
"task_registry"
,
"metric_registry"
,
"metric_agg_registry"
,
"higher_is_better_registry"
,
"filter_registry"
,
# helper
"freeze_all"
,
# Legacy compatibility
"DEFAULT_METRIC_REGISTRY"
,
"AGGREGATION_REGISTRY"
,
"register_model"
,
"get_model"
,
"register_task"
,
"get_task"
,
"register_metric"
,
"get_metric"
,
"register_metric_aggregation"
,
"get_metric_aggregation"
,
"register_higher_is_better"
,
"is_higher_better"
,
"register_filter"
,
"get_filter"
,
"register_aggregation"
,
"get_aggregation"
,
"MODEL_REGISTRY"
,
"TASK_REGISTRY"
,
"METRIC_REGISTRY"
,
"METRIC_AGGREGATION_REGISTRY"
,
"HIGHER_IS_BETTER_REGISTRY"
,
"FILTER_REGISTRY"
,
]
T
=
TypeVar
(
"T"
)
# ────────────────────────────────────────────────────────────────────────
# Generic Registry
# ────────────────────────────────────────────────────────────────────────
class
Registry
(
Generic
[
T
]):
"""Name -> object mapping with decorator helpers and **lazy import** support."""
#: The underlying mutable mapping (might turn into MappingProxy on freeze)
_objects
:
MutableMapping
[
str
,
T
|
str
|
md
.
EntryPoint
]
def
__init__
(
self
,
name
:
str
,
*
,
base_cls
:
type
[
T
]
|
None
=
None
,
store
:
MutableMapping
[
str
,
T
|
str
|
md
.
EntryPoint
]
|
None
=
None
,
validator
:
Callable
[[
T
],
bool
]
|
None
=
None
,
)
->
None
:
self
.
_name
:
str
=
name
self
.
_base_cls
:
type
[
T
]
|
None
=
base_cls
self
.
_objects
=
store
if
store
is
not
None
else
{}
self
.
_metadata
:
dict
[
str
,
dict
[
str
,
Any
]
]
=
{}
# Store metadata for each registered item
self
.
_validator
=
validator
# Custom validation function
self
.
_lock
=
threading
.
RLock
()
# ------------------------------------------------------------------
# Registration helpers (decorator or direct call)
# ------------------------------------------------------------------
def
register
(
self
,
*
aliases
:
str
,
lazy
:
str
|
md
.
EntryPoint
|
None
=
None
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Callable
[[
T
],
T
]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
def
_do_register
(
target
:
T
|
str
|
md
.
EntryPoint
)
->
None
:
if
not
aliases
:
_aliases
=
(
getattr
(
target
,
"__name__"
,
str
(
target
)),)
else
:
_aliases
=
aliases
with
self
.
_lock
:
for
alias
in
_aliases
:
if
alias
in
self
.
_objects
:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing
=
self
.
_objects
[
alias
]
if
isinstance
(
existing
,
(
str
,
md
.
EntryPoint
))
and
isinstance
(
target
,
type
):
# Allow replacing lazy placeholder with concrete class
pass
else
:
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
'
{
alias
}
' already registered "
f
"(
{
self
.
_objects
[
alias
]
}
)"
)
# Eager type check only when we have a concrete class
if
self
.
_base_cls
is
not
None
and
isinstance
(
target
,
type
):
if
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
"
f
"to be registered as a
{
self
.
_name
}
"
)
self
.
_objects
[
alias
]
=
target
# Store metadata if provided
if
metadata
:
self
.
_metadata
[
alias
]
=
metadata
# ─── decorator path ───
def
decorator
(
obj
:
T
)
->
T
:
# type: ignore[valid-type]
_do_register
(
obj
)
return
obj
# ─── direct‑call path with lazy placeholder ───
if
lazy
is
not
None
:
_do_register
(
lazy
)
return
lambda
x
:
x
# no‑op decorator for accidental use
return
decorator
def
register_bulk
(
self
,
items
:
dict
[
str
,
T
|
str
|
md
.
EntryPoint
],
metadata
:
dict
[
str
,
dict
[
str
,
Any
]]
|
None
=
None
,
)
->
None
:
"""Register multiple items at once.
Args:
items: Dictionary mapping aliases to objects/lazy paths
metadata: Optional dictionary mapping aliases to metadata
"""
with
self
.
_lock
:
for
alias
,
target
in
items
.
items
():
if
alias
in
self
.
_objects
:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing
=
self
.
_objects
[
alias
]
if
isinstance
(
existing
,
(
str
,
md
.
EntryPoint
))
and
isinstance
(
target
,
type
):
# Allow replacing lazy placeholder with concrete class
pass
else
:
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
'
{
alias
}
' already registered "
f
"(
{
self
.
_objects
[
alias
]
}
)"
)
)
assert
name
not
in
MODEL_REGISTRY
,
(
# Eager type check only when we have a concrete class
f
"Model named '
{
name
}
' conflicts with existing model! Please register with a non-conflicting alias instead."
if
self
.
_base_cls
is
not
None
and
isinstance
(
target
,
type
):
if
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
"
f
"to be registered as a
{
self
.
_name
}
"
)
)
MODEL_REGISTRY
[
name
]
=
cls
self
.
_objects
[
alias
]
=
target
return
cls
return
decorate
# Store metadata if provided
if
metadata
and
alias
in
metadata
:
self
.
_metadata
[
alias
]
=
metadata
[
alias
]
# ------------------------------------------------------------------
# Lookup & materialisation
# ------------------------------------------------------------------
def
get_model
(
model_name
):
@
lru_cache
(
maxsize
=
256
)
# Bounded cache to prevent memory growth
try
:
def
_materialise
(
self
,
target
:
T
|
str
|
md
.
EntryPoint
)
->
T
:
return
MODEL_REGISTRY
[
model_name
]
"""Import *target* if it is a dotted‑path string or EntryPoint."""
except
KeyError
:
if
isinstance
(
target
,
str
):
mod
,
_
,
obj_name
=
target
.
partition
(
":"
)
if
not
_
:
raise
ValueError
(
raise
ValueError
(
f
"Attempted to load model '
{
model_name
}
', but no model for this name found! Supported model names:
{
', '
.
join
(
MODEL_REGISTRY
.
keys
())
}
"
f
"Lazy path '
{
target
}
' must be in 'module:object' form"
)
module
=
importlib
.
import_module
(
mod
)
return
getattr
(
module
,
obj_name
)
if
isinstance
(
target
,
md
.
EntryPoint
):
return
target
.
load
()
return
target
# concrete already
def
get
(
self
,
alias
:
str
)
->
T
:
with
self
.
_lock
:
try
:
target
=
self
.
_objects
[
alias
]
except
KeyError
as
exc
:
raise
KeyError
(
f
"Unknown
{
self
.
_name
}
'
{
alias
}
'. Available: "
f
"
{
', '
.
join
(
self
.
_objects
)
}
"
)
from
exc
# Only materialize if it's a string or EntryPoint (lazy placeholder)
if
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
concrete
:
T
=
self
.
_materialise
(
target
)
# First‑touch: swap placeholder with concrete obj for future calls
if
concrete
is
not
target
:
self
.
_objects
[
alias
]
=
concrete
else
:
# Already materialized, just return it
concrete
=
target
# Late type check (for placeholders)
if
self
.
_base_cls
is
not
None
and
not
issubclass
(
concrete
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
concrete
}
does not inherit from
{
self
.
_base_cls
}
"
f
"(registered under alias '
{
alias
}
')"
)
)
# Custom validation
TASK_REGISTRY
=
{}
if
self
.
_validator
is
not
None
and
not
self
.
_validator
(
concrete
):
GROUP_REGISTRY
=
{}
raise
ValueError
(
ALL_TASKS
=
set
()
f
"
{
concrete
}
failed custom validation for
{
self
.
_name
}
registry "
func2task_index
=
{}
f
"(registered under alias '
{
alias
}
')"
def
register_task
(
name
):
def
decorate
(
fn
):
assert
name
not
in
TASK_REGISTRY
,
(
f
"task named '
{
name
}
' conflicts with existing registered task!"
)
)
TASK_REGISTRY
[
name
]
=
fn
return
concrete
ALL_TASKS
.
add
(
name
)
func2task_index
[
fn
.
__name__
]
=
name
return
fn
return
decorate
# Mapping / dunder helpers -------------------------------------------------
def
__getitem__
(
self
,
alias
:
str
)
->
T
:
# noqa
return
self
.
get
(
alias
)
def
register_group
(
name
):
def
__iter__
(
self
):
# noqa
def
decorate
(
fn
):
return
iter
(
self
.
_objects
)
func_name
=
func2task_index
[
fn
.
__name__
]
if
name
in
GROUP_REGISTRY
:
GROUP_REGISTRY
[
name
].
append
(
func_name
)
else
:
GROUP_REGISTRY
[
name
]
=
[
func_name
]
ALL_TASKS
.
add
(
name
)
return
fn
return
decorate
def
__len__
(
self
)
->
int
:
# noqa
return
len
(
self
.
_objects
)
def
items
(
self
):
# noqa
return
self
.
_objects
.
items
()
OUTPUT_TYPE_REGISTRY
=
{}
# Introspection -----------------------------------------------------------
METRIC_REGISTRY
=
{}
METRIC_AGGREGATION_REGISTRY
=
{}
AGGREGATION_REGISTRY
:
Dict
[
str
,
Callable
[[],
Dict
[
str
,
Callable
]]]
=
{}
HIGHER_IS_BETTER_REGISTRY
=
{}
FILTER_REGISTRY
=
{}
def
origin
(
self
,
alias
:
str
)
->
str
|
None
:
obj
=
self
.
_objects
.
get
(
alias
)
try
:
if
isinstance
(
obj
,
str
)
or
isinstance
(
obj
,
md
.
EntryPoint
):
return
None
# placeholder - unknown until imported
file
=
inspect
.
getfile
(
obj
)
# type: ignore[arg-type]
line
=
inspect
.
getsourcelines
(
obj
)[
1
]
# type: ignore[arg-type]
return
f
"
{
file
}
:
{
line
}
"
except
(
TypeError
,
OSError
,
AttributeError
,
):
# pragma: no cover - best-effort only
# TypeError: object not suitable for inspect
# OSError: file not found or accessible
# AttributeError: object lacks expected attributes
return
None
def
get_metadata
(
self
,
alias
:
str
)
->
dict
[
str
,
Any
]
|
None
:
"""Get metadata for a registered item."""
with
self
.
_lock
:
return
self
.
_metadata
.
get
(
alias
)
# Mutability --------------------------------------------------------------
def
freeze
(
self
):
"""Make the registry *names* immutable (materialisation still works)."""
with
self
.
_lock
:
if
isinstance
(
self
.
_objects
,
MappingProxyType
):
return
# already frozen
self
.
_objects
=
MappingProxyType
(
dict
(
self
.
_objects
))
# type: ignore[assignment]
def
clear
(
self
):
"""Clear the registry (useful for tests). Cannot be called on frozen registries."""
with
self
.
_lock
:
if
isinstance
(
self
.
_objects
,
MappingProxyType
):
raise
RuntimeError
(
"Cannot clear a frozen registry"
)
self
.
_objects
.
clear
()
self
.
_metadata
.
clear
()
self
.
_materialise
.
cache_clear
()
# type: ignore[attr-defined] # Added by lru_cache
# ────────────────────────────────────────────────────────────────────────
# Structured objects stored in registries
# ────────────────────────────────────────────────────────────────────────
@
dataclass
(
frozen
=
True
)
class
MetricSpec
:
"""Bundle compute fn, aggregator, and *higher‑is‑better* flag."""
compute
:
Callable
[[
Any
,
Any
],
Any
]
aggregate
:
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]
higher_is_better
:
bool
=
True
output_type
:
str
|
None
=
None
# e.g., "probability", "string", "numeric"
requires
:
list
[
str
]
|
None
=
None
# Dependencies on other metrics/data
# ────────────────────────────────────────────────────────────────────────
# Concrete registries used by lm_eval
# ────────────────────────────────────────────────────────────────────────
from
lm_eval.api.model
import
LM
# noqa: E402
model_registry
:
Registry
[
type
[
LM
]]
=
Registry
(
"model"
,
base_cls
=
LM
)
task_registry
:
Registry
[
Callable
[...,
Any
]]
=
Registry
(
"task"
)
metric_registry
:
Registry
[
MetricSpec
]
=
Registry
(
"metric"
)
metric_agg_registry
:
Registry
[
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]]
=
(
Registry
(
"metric aggregation"
)
)
higher_is_better_registry
:
Registry
[
bool
]
=
Registry
(
"higher‑is‑better flag"
)
filter_registry
:
Registry
[
Callable
]
=
Registry
(
"filter"
)
# Default metric registry for output types
DEFAULT_METRIC_REGISTRY
=
{
DEFAULT_METRIC_REGISTRY
=
{
"loglikelihood"
:
[
"loglikelihood"
:
[
"perplexity"
,
"perplexity"
,
...
@@ -90,107 +347,194 @@ DEFAULT_METRIC_REGISTRY = {
...
@@ -90,107 +347,194 @@ DEFAULT_METRIC_REGISTRY = {
"generate_until"
:
[
"exact_match"
],
"generate_until"
:
[
"exact_match"
],
}
}
# Aggregation registry (will be populated by register_aggregation)
AGGREGATION_REGISTRY
:
dict
[
str
,
Callable
]
=
{}
# ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API)
# ────────────────────────────────────────────────────────────────────────
register_model
=
model_registry
.
register
get_model
=
model_registry
.
get
register_task
=
task_registry
.
register
get_task
=
task_registry
.
get
# Special handling for metric registration which uses different API
def
register_metric
(
**
kwargs
):
"""Register a metric with metadata.
Compatible with old registry API that used keyword arguments.
"""
def
register_metric
(
**
args
):
# TODO: do we want to enforce a certain interface to registered metrics?
def
decorate
(
fn
):
def
decorate
(
fn
):
assert
"metric"
in
args
metric_name
=
kwargs
.
get
(
"metric"
)
name
=
args
[
"metric"
]
if
not
metric_name
:
raise
ValueError
(
"metric name is required"
)
for
key
,
registry
in
[
(
"metric"
,
METRIC_REGISTRY
),
# Create MetricSpec with the function and metadata
(
"higher_is_better"
,
HIGHER_IS_BETTER_REGISTRY
),
spec
=
MetricSpec
(
(
"aggregation"
,
METRIC_AGGREGATION_REGISTRY
),
compute
=
fn
,
]:
aggregate
=
lambda
x
:
{},
# Default aggregation returns empty dict
if
key
in
args
:
higher_is_better
=
kwargs
.
get
(
"higher_is_better"
,
True
),
value
=
args
[
key
]
output_type
=
kwargs
.
get
(
"output_type"
),
assert
value
not
in
registry
,
(
requires
=
kwargs
.
get
(
"requires"
),
f
"
{
key
}
named '
{
value
}
' conflicts with existing registered
{
key
}
!"
)
)
if
key
==
"metric"
:
# Register in metric registry
registry
[
name
]
=
fn
metric_registry
.
_objects
[
metric_name
]
=
spec
elif
key
==
"aggregation"
:
registry
[
name
]
=
AGGREGATION_REGISTRY
[
value
]
# Also handle aggregation if specified
else
:
if
"aggregation"
in
kwargs
:
registry
[
name
]
=
value
agg_name
=
kwargs
[
"aggregation"
]
# Try to get aggregation from AGGREGATION_REGISTRY
if
agg_name
in
AGGREGATION_REGISTRY
:
spec
=
MetricSpec
(
compute
=
fn
,
aggregate
=
AGGREGATION_REGISTRY
[
agg_name
],
higher_is_better
=
kwargs
.
get
(
"higher_is_better"
,
True
),
output_type
=
kwargs
.
get
(
"output_type"
),
requires
=
kwargs
.
get
(
"requires"
),
)
metric_registry
.
_objects
[
metric_name
]
=
spec
# Handle higher_is_better registry
if
"higher_is_better"
in
kwargs
:
higher_is_better_registry
.
_objects
[
metric_name
]
=
kwargs
[
"higher_is_better"
]
return
fn
return
fn
return
decorate
return
decorate
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Callable
:
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
):
"""Get a metric by name, with fallback to HF evaluate."""
if
not
hf_evaluate_metric
:
if
not
hf_evaluate_metric
:
if
name
in
METRIC_REGISTRY
:
try
:
return
METRIC_REGISTRY
[
name
]
spec
=
metric_registry
.
get
(
name
)
else
:
if
isinstance
(
spec
,
MetricSpec
):
eval_logger
.
warning
(
return
spec
.
compute
return
spec
except
KeyError
:
import
logging
logging
.
getLogger
(
__name__
).
warning
(
f
"Could not find registered metric '
{
name
}
' in lm-eval, searching in HF Evaluate library..."
f
"Could not find registered metric '
{
name
}
' in lm-eval, searching in HF Evaluate library..."
)
)
# Fallback to HF evaluate
try
:
try
:
import
evaluate
as
hf_evaluate
metric_object
=
hf_evaluate
.
load
(
name
)
metric_object
=
hf_evaluate
.
load
(
name
)
return
metric_object
.
compute
return
metric_object
.
compute
except
Exception
:
except
Exception
:
eval_logger
.
error
(
import
logging
logging
.
getLogger
(
__name__
).
error
(
f
"
{
name
}
not found in the evaluate library! Please check https://huggingface.co/evaluate-metric"
,
f
"
{
name
}
not found in the evaluate library! Please check https://huggingface.co/evaluate-metric"
,
)
)
return
None
def
register_aggregation
(
name
:
str
):
register_metric_aggregation
=
metric_agg_registry
.
register
def
decorate
(
fn
):
assert
name
not
in
AGGREGATION_REGISTRY
,
(
f
"aggregation named '
{
name
}
' conflicts with existing registered aggregation!"
)
AGGREGATION_REGISTRY
[
name
]
=
fn
return
fn
return
decorate
def
get_metric_aggregation
(
metric_name
:
str
):
"""Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation)
if
metric_name
in
metric_registry
.
_objects
:
metric_spec
=
metric_registry
.
_objects
[
metric_name
]
if
isinstance
(
metric_spec
,
MetricSpec
)
and
metric_spec
.
aggregate
:
return
metric_spec
.
aggregate
def
get_aggregation
(
name
:
str
)
->
Callable
[[],
Dict
[
str
,
Callable
]]:
# Fall back to metric_agg_registry (for standalone aggregations)
try
:
if
metric_name
in
metric_agg_registry
.
_objects
:
return
AGGREGATION_REGISTRY
[
name
]
return
metric_agg_registry
.
_objects
[
metric_name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
# If not found, raise error
raise
KeyError
(
f
"Unknown metric aggregation '
{
metric_name
}
'. Available:
{
list
(
AGGREGATION_REGISTRY
.
keys
())
}
"
)
def
get_metric_aggregation
(
name
:
str
)
->
Callable
[[],
Dict
[
str
,
Callable
]]:
try
:
return
METRIC_AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
register_higher_is_better
=
higher_is_better_registry
.
register
is_higher_better
=
higher_is_better_registry
.
get
def
is_higher_better
(
metric_name
)
->
bool
:
register_filter
=
filter_registry
.
register
try
:
get_filter
=
filter_registry
.
get
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
eval_logger
.
warning
(
f
"higher_is_better not specified for metric '
{
metric_name
}
'!"
)
def
register_filter
(
name
):
# Special handling for AGGREGATION_REGISTRY which works differently
def
decorate
(
cls
):
def
register_aggregation
(
name
:
str
):
if
name
in
FILTER_REGISTRY
:
def
decorate
(
fn
):
eval_logger
.
info
(
if
name
in
AGGREGATION_REGISTRY
:
f
"Registering filter `
{
name
}
` that is already in Registry
{
FILTER_REGISTRY
}
"
raise
ValueError
(
f
"aggregation named '
{
name
}
' conflicts with existing registered aggregation!"
)
)
FILTER
_REGISTRY
[
name
]
=
cls
AGGREGATION
_REGISTRY
[
name
]
=
fn
return
cls
return
fn
return
decorate
return
decorate
def
get_
filter
(
filter_name
:
Union
[
str
,
Callable
])
->
Callable
:
def
get_
aggregation
(
name
:
str
)
->
Callable
[[],
dict
[
str
,
Callable
]]
:
try
:
try
:
return
FILTER_REGISTRY
[
filter_name
]
return
AGGREGATION_REGISTRY
[
name
]
except
KeyError
as
e
:
except
KeyError
:
if
callable
(
filter_name
):
import
logging
return
filter_name
else
:
logging
.
getLogger
(
__name__
).
warning
(
eval_logger
.
warning
(
f
"filter `
{
filter_name
}
` is not registered!"
)
f
"
{
name
}
not a registered aggregation metric!"
raise
e
)
return
None
# ────────────────────────────────────────────────────────────────────────
# Optional PyPI entry‑point discovery - uncomment if desired
# ────────────────────────────────────────────────────────────────────────
# for _group, _reg in {
# "lm_eval.models": model_registry,
# "lm_eval.tasks": task_registry,
# "lm_eval.metrics": metric_registry,
# }.items():
# for _ep in md.entry_points(group=_group):
# _reg.register(_ep.name, lazy=_ep)
# ────────────────────────────────────────────────────────────────────────
# Convenience
# ────────────────────────────────────────────────────────────────────────
def
freeze_all
()
->
None
:
# pragma: no cover
"""Freeze every global registry (idempotent)."""
for
_reg
in
(
model_registry
,
task_registry
,
metric_registry
,
metric_agg_registry
,
higher_is_better_registry
,
filter_registry
,
):
_reg
.
freeze
()
# ────────────────────────────────────────────────────────────────────────
# Backwards‑compatibility read‑only globals
# ────────────────────────────────────────────────────────────────────────
MODEL_REGISTRY
:
Mapping
[
str
,
type
[
LM
]]
=
MappingProxyType
(
model_registry
.
_objects
)
# type: ignore[attr-defined]
TASK_REGISTRY
:
Mapping
[
str
,
Callable
[...,
Any
]]
=
MappingProxyType
(
task_registry
.
_objects
)
# type: ignore[attr-defined]
METRIC_REGISTRY
:
Mapping
[
str
,
MetricSpec
]
=
MappingProxyType
(
metric_registry
.
_objects
)
# type: ignore[attr-defined]
METRIC_AGGREGATION_REGISTRY
:
Mapping
[
str
,
Callable
]
=
MappingProxyType
(
metric_agg_registry
.
_objects
)
# type: ignore[attr-defined]
HIGHER_IS_BETTER_REGISTRY
:
Mapping
[
str
,
bool
]
=
MappingProxyType
(
higher_is_better_registry
.
_objects
)
# type: ignore[attr-defined]
FILTER_REGISTRY
:
Mapping
[
str
,
Callable
]
=
MappingProxyType
(
filter_registry
.
_objects
)
# type: ignore[attr-defined]
lm_eval/api/task.py
View file @
93b2ab37
...
@@ -3,18 +3,15 @@ import ast
...
@@ -3,18 +3,15 @@ import ast
import
logging
import
logging
import
random
import
random
import
re
import
re
from
collections.abc
import
Callable
from
collections.abc
import
Callable
,
Iterable
,
Iterator
,
Mapping
from
copy
import
deepcopy
from
copy
import
deepcopy
from
dataclasses
import
asdict
,
dataclass
from
dataclasses
import
asdict
,
dataclass
from
inspect
import
getsource
from
inspect
import
getsource
from
typing
import
(
from
typing
import
(
Any
,
Any
,
Dict
,
Dict
,
Iterable
,
Iterator
,
List
,
List
,
Literal
,
Literal
,
Mapping
,
Optional
,
Optional
,
Tuple
,
Tuple
,
Union
,
Union
,
...
@@ -1774,7 +1771,7 @@ class MultipleChoiceTask(Task):
...
@@ -1774,7 +1771,7 @@ class MultipleChoiceTask(Task):
Instance
(
Instance
(
request_type
=
"loglikelihood"
,
request_type
=
"loglikelihood"
,
doc
=
doc
,
doc
=
doc
,
arguments
=
(
ctx
,
" {
}"
.
format
(
choice
)
),
arguments
=
(
ctx
,
f
"
{
choice
}
"
),
idx
=
i
,
idx
=
i
,
**
kwargs
,
**
kwargs
,
)
)
...
...
lm_eval/models/__init__.py
View file @
93b2ab37
from
.
import
(
# Models are now lazily loaded via the registry system
anthropic_llms
,
# No need to import them all at once
api_models
,
dummy
,
# Define model mappings for lazy registration
gguf
,
MODEL_MAPPING
=
{
hf_audiolm
,
"anthropic-completions"
:
"lm_eval.models.anthropic_llms:AnthropicLM"
,
hf_steered
,
"anthropic-chat"
:
"lm_eval.models.anthropic_llms:AnthropicChatLM"
,
hf_vlms
,
"anthropic-chat-completions"
:
"lm_eval.models.anthropic_llms:AnthropicCompletionsLM"
,
huggingface
,
"local-completions"
:
"lm_eval.models.openai_completions:LocalCompletionsAPI"
,
ibm_watsonx_ai
,
"local-chat-completions"
:
"lm_eval.models.openai_completions:LocalChatCompletion"
,
mamba_lm
,
"openai-completions"
:
"lm_eval.models.openai_completions:OpenAICompletionsAPI"
,
nemo_lm
,
"openai-chat-completions"
:
"lm_eval.models.openai_completions:OpenAIChatCompletion"
,
neuron_optimum
,
"dummy"
:
"lm_eval.models.dummy:DummyLM"
,
openai_completions
,
"gguf"
:
"lm_eval.models.gguf:GGUFLM"
,
optimum_ipex
,
"ggml"
:
"lm_eval.models.gguf:GGUFLM"
,
optimum_lm
,
"hf-audiolm-qwen"
:
"lm_eval.models.hf_audiolm:HFAudioLM"
,
sglang_causallms
,
"steered"
:
"lm_eval.models.hf_steered:SteeredHF"
,
sglang_generate_API
,
"hf-multimodal"
:
"lm_eval.models.hf_vlms:HFMultimodalLM"
,
textsynth
,
"hf-auto"
:
"lm_eval.models.huggingface:HFLM"
,
vllm_causallms
,
"hf"
:
"lm_eval.models.huggingface:HFLM"
,
vllm_vlms
,
"huggingface"
:
"lm_eval.models.huggingface:HFLM"
,
)
"watsonx_llm"
:
"lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI"
,
"mamba_ssm"
:
"lm_eval.models.mamba_lm:MambaLMWrapper"
,
"nemo_lm"
:
"lm_eval.models.nemo_lm:NeMoLM"
,
# TODO: implement __all__
"neuronx"
:
"lm_eval.models.neuron_optimum:NeuronModelForCausalLM"
,
"ipex"
:
"lm_eval.models.optimum_ipex:IPEXForCausalLM"
,
"openvino"
:
"lm_eval.models.optimum_lm:OptimumLM"
,
"sglang"
:
"lm_eval.models.sglang_causallms:SGLANG"
,
"sglang-generate"
:
"lm_eval.models.sglang_generate_API:SGAPI"
,
"textsynth"
:
"lm_eval.models.textsynth:TextSynthLM"
,
"vllm"
:
"lm_eval.models.vllm_causallms:VLLM"
,
"vllm-vlm"
:
"lm_eval.models.vllm_vlms:VLLM_VLM"
,
}
# Register all models lazily
def
_register_all_models
():
"""Register all known models lazily in the registry."""
from
lm_eval.api.registry
import
model_registry
for
name
,
path
in
MODEL_MAPPING
.
items
():
# Only register if not already present (avoids conflicts when modules are imported)
if
name
not
in
model_registry
:
# Call register with the lazy parameter, returns a decorator
model_registry
.
register
(
name
,
lazy
=
path
)(
None
)
# Call registration on module import
_register_all_models
()
__all__
=
[
"MODEL_MAPPING"
]
try
:
try
:
...
...
lm_eval/models/hf_steered.py
View file @
93b2ab37
from
collections.abc
import
Generator
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
functools
import
partial
from
functools
import
partial
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Generator
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
from
peft.peft_model
import
PeftModel
from
peft.peft_model
import
PeftModel
...
...
lm_eval/models/ibm_watsonx_ai.py
View file @
93b2ab37
...
@@ -3,7 +3,7 @@ import json
...
@@ -3,7 +3,7 @@ import json
import
logging
import
logging
import
os
import
os
import
warnings
import
warnings
from
functools
import
lru_
cache
from
functools
import
cache
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
cast
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
cast
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
...
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise
ValueError
(
error_msg
)
raise
ValueError
(
error_msg
)
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
get_watsonx_credentials
()
->
Dict
[
str
,
str
]:
def
get_watsonx_credentials
()
->
Dict
[
str
,
str
]:
"""
"""
Retrieves Watsonx API credentials from environmental variables.
Retrieves Watsonx API credentials from environmental variables.
...
...
lm_eval/models/vllm_causallms.py
View file @
93b2ab37
...
@@ -40,7 +40,7 @@ try:
...
@@ -40,7 +40,7 @@ try:
if
parse_version
(
version
(
"vllm"
))
>=
parse_version
(
"0.8.3"
):
if
parse_version
(
version
(
"vllm"
))
>=
parse_version
(
"0.8.3"
):
from
vllm.entrypoints.chat_utils
import
resolve_hf_chat_template
from
vllm.entrypoints.chat_utils
import
resolve_hf_chat_template
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
p
ass
p
rint
(
"njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd"
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
pass
pass
...
...
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
View file @
93b2ab37
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
with
open
(
GRAMMAR_FILE
)
as
f
:
...
@@ -556,8 +556,8 @@ class STRIPS:
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
View file @
93b2ab37
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
with
open
(
GRAMMAR_FILE
)
as
f
:
...
@@ -556,8 +556,8 @@ class STRIPS:
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
scripts/build_benchmark.py
View file @
93b2ab37
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from
tqdm
import
tqdm
from
tqdm
import
tqdm
# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registry
v2
import ALL_TASKS
eval_logger
=
logging
.
getLogger
(
__name__
)
eval_logger
=
logging
.
getLogger
(
__name__
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment