Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
93b2ab37
Commit
93b2ab37
authored
Jul 27, 2025
by
Baber
Browse files
refactor registry
parent
de496b80
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
536 additions
and
168 deletions
+536
-168
lm_eval/api/metrics.py
lm_eval/api/metrics.py
+2
-2
lm_eval/api/registry.py
lm_eval/api/registry.py
+469
-125
lm_eval/api/task.py
lm_eval/api/task.py
+2
-5
lm_eval/models/__init__.py
lm_eval/models/__init__.py
+51
-25
lm_eval/models/hf_steered.py
lm_eval/models/hf_steered.py
+2
-1
lm_eval/models/ibm_watsonx_ai.py
lm_eval/models/ibm_watsonx_ai.py
+2
-2
lm_eval/models/vllm_causallms.py
lm_eval/models/vllm_causallms.py
+1
-1
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
+3
-3
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
+3
-3
scripts/build_benchmark.py
scripts/build_benchmark.py
+1
-1
No files found.
lm_eval/api/metrics.py
View file @
93b2ab37
...
...
@@ -4,8 +4,8 @@ import os
import
random
import
re
import
string
from
collections.abc
import
Iterable
from
typing
import
Callable
,
List
,
Optional
,
Sequence
,
TypeVar
from
collections.abc
import
Iterable
,
Sequence
from
typing
import
Callable
,
List
,
Optional
,
TypeVar
import
numpy
as
np
import
sacrebleu
...
...
lm_eval/api/registry.py
View file @
93b2ab37
import
logging
from
typing
import
Callable
,
Dict
,
Union
from
__future__
import
annotations
import
importlib
import
inspect
import
threading
from
collections.abc
import
Iterable
,
Mapping
,
MutableMapping
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
types
import
MappingProxyType
from
typing
import
(
Any
,
Callable
,
Generic
,
TypeVar
,
)
try
:
# Python≥3.10
import
importlib.metadata
as
md
except
ImportError
:
# pragma: no cover - fallback for 3.8/3.9 runtimes
import
importlib_metadata
as
md
# type: ignore
__all__
=
[
"Registry"
,
"MetricSpec"
,
# concrete registries
"model_registry"
,
"task_registry"
,
"metric_registry"
,
"metric_agg_registry"
,
"higher_is_better_registry"
,
"filter_registry"
,
# helper
"freeze_all"
,
# Legacy compatibility
"DEFAULT_METRIC_REGISTRY"
,
"AGGREGATION_REGISTRY"
,
"register_model"
,
"get_model"
,
"register_task"
,
"get_task"
,
"register_metric"
,
"get_metric"
,
"register_metric_aggregation"
,
"get_metric_aggregation"
,
"register_higher_is_better"
,
"is_higher_better"
,
"register_filter"
,
"get_filter"
,
"register_aggregation"
,
"get_aggregation"
,
"MODEL_REGISTRY"
,
"TASK_REGISTRY"
,
"METRIC_REGISTRY"
,
"METRIC_AGGREGATION_REGISTRY"
,
"HIGHER_IS_BETTER_REGISTRY"
,
"FILTER_REGISTRY"
,
]
T
=
TypeVar
(
"T"
)
# ────────────────────────────────────────────────────────────────────────
# Generic Registry
# ────────────────────────────────────────────────────────────────────────
class
Registry
(
Generic
[
T
]):
"""Name -> object mapping with decorator helpers and **lazy import** support."""
#: The underlying mutable mapping (might turn into MappingProxy on freeze)
_objects
:
MutableMapping
[
str
,
T
|
str
|
md
.
EntryPoint
]
def
__init__
(
self
,
name
:
str
,
*
,
base_cls
:
type
[
T
]
|
None
=
None
,
store
:
MutableMapping
[
str
,
T
|
str
|
md
.
EntryPoint
]
|
None
=
None
,
validator
:
Callable
[[
T
],
bool
]
|
None
=
None
,
)
->
None
:
self
.
_name
:
str
=
name
self
.
_base_cls
:
type
[
T
]
|
None
=
base_cls
self
.
_objects
=
store
if
store
is
not
None
else
{}
self
.
_metadata
:
dict
[
str
,
dict
[
str
,
Any
]
]
=
{}
# Store metadata for each registered item
self
.
_validator
=
validator
# Custom validation function
self
.
_lock
=
threading
.
RLock
()
# ------------------------------------------------------------------
# Registration helpers (decorator or direct call)
# ------------------------------------------------------------------
def
register
(
self
,
*
aliases
:
str
,
lazy
:
str
|
md
.
EntryPoint
|
None
=
None
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Callable
[[
T
],
T
]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
def
_do_register
(
target
:
T
|
str
|
md
.
EntryPoint
)
->
None
:
if
not
aliases
:
_aliases
=
(
getattr
(
target
,
"__name__"
,
str
(
target
)),)
else
:
_aliases
=
aliases
with
self
.
_lock
:
for
alias
in
_aliases
:
if
alias
in
self
.
_objects
:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing
=
self
.
_objects
[
alias
]
if
isinstance
(
existing
,
(
str
,
md
.
EntryPoint
))
and
isinstance
(
target
,
type
):
# Allow replacing lazy placeholder with concrete class
pass
else
:
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
'
{
alias
}
' already registered "
f
"(
{
self
.
_objects
[
alias
]
}
)"
)
# Eager type check only when we have a concrete class
if
self
.
_base_cls
is
not
None
and
isinstance
(
target
,
type
):
if
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
"
f
"to be registered as a
{
self
.
_name
}
"
)
self
.
_objects
[
alias
]
=
target
# Store metadata if provided
if
metadata
:
self
.
_metadata
[
alias
]
=
metadata
# ─── decorator path ───
def
decorator
(
obj
:
T
)
->
T
:
# type: ignore[valid-type]
_do_register
(
obj
)
return
obj
# ─── direct‑call path with lazy placeholder ───
if
lazy
is
not
None
:
_do_register
(
lazy
)
return
lambda
x
:
x
# no‑op decorator for accidental use
return
decorator
def
register_bulk
(
self
,
items
:
dict
[
str
,
T
|
str
|
md
.
EntryPoint
],
metadata
:
dict
[
str
,
dict
[
str
,
Any
]]
|
None
=
None
,
)
->
None
:
"""Register multiple items at once.
Args:
items: Dictionary mapping aliases to objects/lazy paths
metadata: Optional dictionary mapping aliases to metadata
"""
with
self
.
_lock
:
for
alias
,
target
in
items
.
items
():
if
alias
in
self
.
_objects
:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing
=
self
.
_objects
[
alias
]
if
isinstance
(
existing
,
(
str
,
md
.
EntryPoint
))
and
isinstance
(
target
,
type
):
# Allow replacing lazy placeholder with concrete class
pass
else
:
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
'
{
alias
}
' already registered "
f
"(
{
self
.
_objects
[
alias
]
}
)"
)
# Eager type check only when we have a concrete class
if
self
.
_base_cls
is
not
None
and
isinstance
(
target
,
type
):
if
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
"
f
"to be registered as a
{
self
.
_name
}
"
)
self
.
_objects
[
alias
]
=
target
# Store metadata if provided
if
metadata
and
alias
in
metadata
:
self
.
_metadata
[
alias
]
=
metadata
[
alias
]
# ------------------------------------------------------------------
# Lookup & materialisation
# ------------------------------------------------------------------
@
lru_cache
(
maxsize
=
256
)
# Bounded cache to prevent memory growth
def
_materialise
(
self
,
target
:
T
|
str
|
md
.
EntryPoint
)
->
T
:
"""Import *target* if it is a dotted‑path string or EntryPoint."""
if
isinstance
(
target
,
str
):
mod
,
_
,
obj_name
=
target
.
partition
(
":"
)
if
not
_
:
raise
ValueError
(
f
"Lazy path '
{
target
}
' must be in 'module:object' form"
)
module
=
importlib
.
import_module
(
mod
)
return
getattr
(
module
,
obj_name
)
if
isinstance
(
target
,
md
.
EntryPoint
):
return
target
.
load
()
return
target
# concrete already
def
get
(
self
,
alias
:
str
)
->
T
:
with
self
.
_lock
:
try
:
target
=
self
.
_objects
[
alias
]
except
KeyError
as
exc
:
raise
KeyError
(
f
"Unknown
{
self
.
_name
}
'
{
alias
}
'. Available: "
f
"
{
', '
.
join
(
self
.
_objects
)
}
"
)
from
exc
# Only materialize if it's a string or EntryPoint (lazy placeholder)
if
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
concrete
:
T
=
self
.
_materialise
(
target
)
# First‑touch: swap placeholder with concrete obj for future calls
if
concrete
is
not
target
:
self
.
_objects
[
alias
]
=
concrete
else
:
# Already materialized, just return it
concrete
=
target
# Late type check (for placeholders)
if
self
.
_base_cls
is
not
None
and
not
issubclass
(
concrete
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
concrete
}
does not inherit from
{
self
.
_base_cls
}
"
f
"(registered under alias '
{
alias
}
')"
)
import
evaluate
as
hf_evaluate
# Custom validation
if
self
.
_validator
is
not
None
and
not
self
.
_validator
(
concrete
):
raise
ValueError
(
f
"
{
concrete
}
failed custom validation for
{
self
.
_name
}
registry "
f
"(registered under alias '
{
alias
}
')"
)
from
lm_eval.api.model
import
LM
return
concrete
# Mapping / dunder helpers -------------------------------------------------
eval_logger
=
logging
.
getLogger
(
__name__
)
def
__getitem__
(
self
,
alias
:
str
)
->
T
:
# noqa
return
self
.
get
(
alias
)
MODEL_REGISTRY
=
{}
def
__iter__
(
self
):
# noqa
return
iter
(
self
.
_objects
)
def
__len__
(
self
)
->
int
:
# noqa
return
len
(
self
.
_objects
)
def
register_model
(
*
names
):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def
items
(
self
):
# noqa
return
self
.
_objects
.
items
()
def
decorate
(
cls
):
for
name
in
names
:
assert
issubclass
(
cls
,
LM
),
(
f
"Model '
{
name
}
' (
{
cls
.
__name__
}
) must extend LM class"
)
# Introspection -----------------------------------------------------------
assert
name
not
in
MODEL_REGISTRY
,
(
f
"Model named '
{
name
}
' conflicts with existing model! Please register with a non-conflicting alias instead."
)
def
origin
(
self
,
alias
:
str
)
->
str
|
None
:
obj
=
self
.
_objects
.
get
(
alias
)
try
:
if
isinstance
(
obj
,
str
)
or
isinstance
(
obj
,
md
.
EntryPoint
):
return
None
# placeholder - unknown until imported
file
=
inspect
.
getfile
(
obj
)
# type: ignore[arg-type]
line
=
inspect
.
getsourcelines
(
obj
)[
1
]
# type: ignore[arg-type]
return
f
"
{
file
}
:
{
line
}
"
except
(
TypeError
,
OSError
,
AttributeError
,
):
# pragma: no cover - best-effort only
# TypeError: object not suitable for inspect
# OSError: file not found or accessible
# AttributeError: object lacks expected attributes
return
None
MODEL_REGISTRY
[
name
]
=
cls
return
cls
def
get_metadata
(
self
,
alias
:
str
)
->
dict
[
str
,
Any
]
|
None
:
"""Get metadata for a registered item."""
with
self
.
_lock
:
return
self
.
_metadata
.
get
(
alias
)
return
decorate
# Mutability --------------------------------------------------------------
def
freeze
(
self
):
"""Make the registry *names* immutable (materialisation still works)."""
with
self
.
_lock
:
if
isinstance
(
self
.
_objects
,
MappingProxyType
):
return
# already frozen
self
.
_objects
=
MappingProxyType
(
dict
(
self
.
_objects
))
# type: ignore[assignment]
def
get_model
(
model_name
):
try
:
return
MODEL_REGISTRY
[
model_name
]
except
KeyError
:
raise
ValueError
(
f
"Attempted to load model '
{
model_name
}
', but no model for this name found! Supported model names:
{
', '
.
join
(
MODEL_REGISTRY
.
keys
())
}
"
)
def
clear
(
self
):
"""Clear the registry (useful for tests). Cannot be called on frozen registries."""
with
self
.
_lock
:
if
isinstance
(
self
.
_objects
,
MappingProxyType
):
raise
RuntimeError
(
"Cannot clear a frozen registry"
)
self
.
_objects
.
clear
()
self
.
_metadata
.
clear
()
self
.
_materialise
.
cache_clear
()
# type: ignore[attr-defined] # Added by lru_cache
TASK_REGISTRY
=
{}
GROUP_REGISTRY
=
{}
ALL_TASKS
=
set
()
func2task_index
=
{}
# ────────────────────────────────────────────────────────────────────────
# Structured objects stored in registries
# ────────────────────────────────────────────────────────────────────────
def
register_task
(
name
):
def
decorate
(
fn
):
assert
name
not
in
TASK_REGISTRY
,
(
f
"task named '
{
name
}
' conflicts with existing registered task!"
)
@
dataclass
(
frozen
=
True
)
class
MetricSpec
:
"""Bundle compute fn, aggregator, and *higher‑is‑better* flag."""
TASK_REGISTRY
[
name
]
=
fn
ALL_TASKS
.
add
(
name
)
func2task_index
[
fn
.
__name__
]
=
name
return
fn
compute
:
Callable
[[
Any
,
Any
],
Any
]
aggregate
:
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]
higher_is_better
:
bool
=
True
output_type
:
str
|
None
=
None
# e.g., "probability", "string", "numeric"
requires
:
list
[
str
]
|
None
=
None
# Dependencies on other metrics/data
return
decorate
# ────────────────────────────────────────────────────────────────────────
# Concrete registries used by lm_eval
# ────────────────────────────────────────────────────────────────────────
def
register_group
(
name
):
def
decorate
(
fn
):
func_name
=
func2task_index
[
fn
.
__name__
]
if
name
in
GROUP_REGISTRY
:
GROUP_REGISTRY
[
name
].
append
(
func_name
)
else
:
GROUP_REGISTRY
[
name
]
=
[
func_name
]
ALL_TASKS
.
add
(
name
)
return
fn
return
decorate
from
lm_eval.api.model
import
LM
# noqa: E402
OUTPUT_TYPE_REGISTRY
=
{}
METRIC_REGISTRY
=
{}
METRIC_AGGREGATION_REGISTRY
=
{}
AGGREGATION_REGISTRY
:
Dict
[
str
,
Callable
[[],
Dict
[
str
,
Callable
]]]
=
{}
HIGHER_IS_BETTER_REGISTRY
=
{}
FILTER_REGISTRY
=
{}
model_registry
:
Registry
[
type
[
LM
]]
=
Registry
(
"model"
,
base_cls
=
LM
)
task_registry
:
Registry
[
Callable
[...,
Any
]]
=
Registry
(
"task"
)
metric_registry
:
Registry
[
MetricSpec
]
=
Registry
(
"metric"
)
metric_agg_registry
:
Registry
[
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]]
=
(
Registry
(
"metric aggregation"
)
)
higher_is_better_registry
:
Registry
[
bool
]
=
Registry
(
"higher‑is‑better flag"
)
filter_registry
:
Registry
[
Callable
]
=
Registry
(
"filter"
)
# Default metric registry for output types
DEFAULT_METRIC_REGISTRY
=
{
"loglikelihood"
:
[
"perplexity"
,
...
...
@@ -90,107 +347,194 @@ DEFAULT_METRIC_REGISTRY = {
"generate_until"
:
[
"exact_match"
],
}
# Aggregation registry (will be populated by register_aggregation)
AGGREGATION_REGISTRY
:
dict
[
str
,
Callable
]
=
{}
# ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API)
# ────────────────────────────────────────────────────────────────────────
register_model
=
model_registry
.
register
get_model
=
model_registry
.
get
register_task
=
task_registry
.
register
get_task
=
task_registry
.
get
# Special handling for metric registration which uses different API
def
register_metric
(
**
kwargs
):
"""Register a metric with metadata.
Compatible with old registry API that used keyword arguments.
"""
def
register_metric
(
**
args
):
# TODO: do we want to enforce a certain interface to registered metrics?
def
decorate
(
fn
):
assert
"metric"
in
args
name
=
args
[
"metric"
]
for
key
,
registry
in
[
(
"metric"
,
METRIC_REGISTRY
),
(
"higher_is_better"
,
HIGHER_IS_BETTER_REGISTRY
),
(
"aggregation"
,
METRIC_AGGREGATION_REGISTRY
),
]:
if
key
in
args
:
value
=
args
[
key
]
assert
value
not
in
registry
,
(
f
"
{
key
}
named '
{
value
}
' conflicts with existing registered
{
key
}
!"
metric_name
=
kwargs
.
get
(
"metric"
)
if
not
metric_name
:
raise
ValueError
(
"metric name is required"
)
# Create MetricSpec with the function and metadata
spec
=
MetricSpec
(
compute
=
fn
,
aggregate
=
lambda
x
:
{},
# Default aggregation returns empty dict
higher_is_better
=
kwargs
.
get
(
"higher_is_better"
,
True
),
output_type
=
kwargs
.
get
(
"output_type"
),
requires
=
kwargs
.
get
(
"requires"
),
)
# Register in metric registry
metric_registry
.
_objects
[
metric_name
]
=
spec
# Also handle aggregation if specified
if
"aggregation"
in
kwargs
:
agg_name
=
kwargs
[
"aggregation"
]
# Try to get aggregation from AGGREGATION_REGISTRY
if
agg_name
in
AGGREGATION_REGISTRY
:
spec
=
MetricSpec
(
compute
=
fn
,
aggregate
=
AGGREGATION_REGISTRY
[
agg_name
],
higher_is_better
=
kwargs
.
get
(
"higher_is_better"
,
True
),
output_type
=
kwargs
.
get
(
"output_type"
),
requires
=
kwargs
.
get
(
"requires"
),
)
metric_registry
.
_objects
[
metric_name
]
=
spec
if
key
==
"metric"
:
registry
[
name
]
=
fn
elif
key
==
"aggregation"
:
registry
[
name
]
=
AGGREGATION_REGISTRY
[
value
]
else
:
registry
[
name
]
=
value
# Handle higher_is_better registry
if
"higher_is_better"
in
kwargs
:
higher_is_better_registry
.
_objects
[
metric_name
]
=
kwargs
[
"higher_is_better"
]
return
fn
return
decorate
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Callable
:
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
):
"""Get a metric by name, with fallback to HF evaluate."""
if
not
hf_evaluate_metric
:
if
name
in
METRIC_REGISTRY
:
return
METRIC_REGISTRY
[
name
]
else
:
eval_logger
.
warning
(
try
:
spec
=
metric_registry
.
get
(
name
)
if
isinstance
(
spec
,
MetricSpec
):
return
spec
.
compute
return
spec
except
KeyError
:
import
logging
logging
.
getLogger
(
__name__
).
warning
(
f
"Could not find registered metric '
{
name
}
' in lm-eval, searching in HF Evaluate library..."
)
# Fallback to HF evaluate
try
:
import
evaluate
as
hf_evaluate
metric_object
=
hf_evaluate
.
load
(
name
)
return
metric_object
.
compute
except
Exception
:
eval_logger
.
error
(
import
logging
logging
.
getLogger
(
__name__
).
error
(
f
"
{
name
}
not found in the evaluate library! Please check https://huggingface.co/evaluate-metric"
,
)
return
None
def
register_aggregation
(
name
:
str
):
def
decorate
(
fn
):
assert
name
not
in
AGGREGATION_REGISTRY
,
(
f
"aggregation named '
{
name
}
' conflicts with existing registered aggregation!"
)
register_metric_aggregation
=
metric_agg_registry
.
register
AGGREGATION_REGISTRY
[
name
]
=
fn
return
fn
return
decorate
def
get_metric_aggregation
(
metric_name
:
str
):
"""Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation)
if
metric_name
in
metric_registry
.
_objects
:
metric_spec
=
metric_registry
.
_objects
[
metric_name
]
if
isinstance
(
metric_spec
,
MetricSpec
)
and
metric_spec
.
aggregate
:
return
metric_spec
.
aggregate
# Fall back to metric_agg_registry (for standalone aggregations)
if
metric_name
in
metric_agg_registry
.
_objects
:
return
metric_agg_registry
.
_objects
[
metric_name
]
def
get_aggregation
(
name
:
str
)
->
Callable
[[],
Dict
[
str
,
Callable
]]:
try
:
return
AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
# If not found, raise error
raise
KeyError
(
f
"Unknown metric aggregation '
{
metric_name
}
'. Available:
{
list
(
AGGREGATION_REGISTRY
.
keys
())
}
"
)
def
get_metric_aggregation
(
name
:
str
)
->
Callable
[[],
Dict
[
str
,
Callable
]]:
try
:
return
METRIC_AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
register_higher_is_better
=
higher_is_better_registry
.
register
is_higher_better
=
higher_is_better_registry
.
get
register_filter
=
filter_registry
.
register
get_filter
=
filter_registry
.
get
def
is_higher_better
(
metric_name
)
->
bool
:
try
:
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
eval_logger
.
warning
(
f
"higher_is_better not specified for metric '
{
metric_name
}
'!"
)
def
register_
filter
(
name
):
def
decorate
(
cls
):
if
name
in
FILTER
_REGISTRY
:
eval_logger
.
info
(
f
"
Registering filter `
{
name
}
` that is already in Registry
{
FILTER_REGISTRY
}
"
# Special handling for AGGREGATION_REGISTRY which works differently
def
register_
aggregation
(
name
:
str
):
def
decorate
(
fn
):
if
name
in
AGGREGATION
_REGISTRY
:
raise
ValueError
(
f
"
aggregation named '
{
name
}
' conflicts with existing registered aggregation!
"
)
FILTER
_REGISTRY
[
name
]
=
cls
return
cls
AGGREGATION
_REGISTRY
[
name
]
=
fn
return
fn
return
decorate
def
get_
filter
(
filter_name
:
Union
[
str
,
Callable
])
->
Callable
:
def
get_
aggregation
(
name
:
str
)
->
Callable
[[],
dict
[
str
,
Callable
]]
:
try
:
return
FILTER_REGISTRY
[
filter_name
]
except
KeyError
as
e
:
if
callable
(
filter_name
):
return
filter_name
else
:
eval_logger
.
warning
(
f
"filter `
{
filter_name
}
` is not registered!"
)
raise
e
return
AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
import
logging
logging
.
getLogger
(
__name__
).
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
return
None
# ────────────────────────────────────────────────────────────────────────
# Optional PyPI entry‑point discovery - uncomment if desired
# ────────────────────────────────────────────────────────────────────────
# for _group, _reg in {
# "lm_eval.models": model_registry,
# "lm_eval.tasks": task_registry,
# "lm_eval.metrics": metric_registry,
# }.items():
# for _ep in md.entry_points(group=_group):
# _reg.register(_ep.name, lazy=_ep)
# ────────────────────────────────────────────────────────────────────────
# Convenience
# ────────────────────────────────────────────────────────────────────────
def
freeze_all
()
->
None
:
# pragma: no cover
"""Freeze every global registry (idempotent)."""
for
_reg
in
(
model_registry
,
task_registry
,
metric_registry
,
metric_agg_registry
,
higher_is_better_registry
,
filter_registry
,
):
_reg
.
freeze
()
# ────────────────────────────────────────────────────────────────────────
# Backwards‑compatibility read‑only globals
# ────────────────────────────────────────────────────────────────────────
MODEL_REGISTRY
:
Mapping
[
str
,
type
[
LM
]]
=
MappingProxyType
(
model_registry
.
_objects
)
# type: ignore[attr-defined]
TASK_REGISTRY
:
Mapping
[
str
,
Callable
[...,
Any
]]
=
MappingProxyType
(
task_registry
.
_objects
)
# type: ignore[attr-defined]
METRIC_REGISTRY
:
Mapping
[
str
,
MetricSpec
]
=
MappingProxyType
(
metric_registry
.
_objects
)
# type: ignore[attr-defined]
METRIC_AGGREGATION_REGISTRY
:
Mapping
[
str
,
Callable
]
=
MappingProxyType
(
metric_agg_registry
.
_objects
)
# type: ignore[attr-defined]
HIGHER_IS_BETTER_REGISTRY
:
Mapping
[
str
,
bool
]
=
MappingProxyType
(
higher_is_better_registry
.
_objects
)
# type: ignore[attr-defined]
FILTER_REGISTRY
:
Mapping
[
str
,
Callable
]
=
MappingProxyType
(
filter_registry
.
_objects
)
# type: ignore[attr-defined]
lm_eval/api/task.py
View file @
93b2ab37
...
...
@@ -3,18 +3,15 @@ import ast
import
logging
import
random
import
re
from
collections.abc
import
Callable
from
collections.abc
import
Callable
,
Iterable
,
Iterator
,
Mapping
from
copy
import
deepcopy
from
dataclasses
import
asdict
,
dataclass
from
inspect
import
getsource
from
typing
import
(
Any
,
Dict
,
Iterable
,
Iterator
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
Union
,
...
...
@@ -1774,7 +1771,7 @@ class MultipleChoiceTask(Task):
Instance
(
request_type
=
"loglikelihood"
,
doc
=
doc
,
arguments
=
(
ctx
,
" {
}"
.
format
(
choice
)
),
arguments
=
(
ctx
,
f
"
{
choice
}
"
),
idx
=
i
,
**
kwargs
,
)
...
...
lm_eval/models/__init__.py
View file @
93b2ab37
from
.
import
(
anthropic_llms
,
api_models
,
dummy
,
gguf
,
hf_audiolm
,
hf_steered
,
hf_vlms
,
huggingface
,
ibm_watsonx_ai
,
mamba_lm
,
nemo_lm
,
neuron_optimum
,
openai_completions
,
optimum_ipex
,
optimum_lm
,
sglang_causallms
,
sglang_generate_API
,
textsynth
,
vllm_causallms
,
vllm_vlms
,
)
# TODO: implement __all__
# Models are now lazily loaded via the registry system
# No need to import them all at once
# Define model mappings for lazy registration
MODEL_MAPPING
=
{
"anthropic-completions"
:
"lm_eval.models.anthropic_llms:AnthropicLM"
,
"anthropic-chat"
:
"lm_eval.models.anthropic_llms:AnthropicChatLM"
,
"anthropic-chat-completions"
:
"lm_eval.models.anthropic_llms:AnthropicCompletionsLM"
,
"local-completions"
:
"lm_eval.models.openai_completions:LocalCompletionsAPI"
,
"local-chat-completions"
:
"lm_eval.models.openai_completions:LocalChatCompletion"
,
"openai-completions"
:
"lm_eval.models.openai_completions:OpenAICompletionsAPI"
,
"openai-chat-completions"
:
"lm_eval.models.openai_completions:OpenAIChatCompletion"
,
"dummy"
:
"lm_eval.models.dummy:DummyLM"
,
"gguf"
:
"lm_eval.models.gguf:GGUFLM"
,
"ggml"
:
"lm_eval.models.gguf:GGUFLM"
,
"hf-audiolm-qwen"
:
"lm_eval.models.hf_audiolm:HFAudioLM"
,
"steered"
:
"lm_eval.models.hf_steered:SteeredHF"
,
"hf-multimodal"
:
"lm_eval.models.hf_vlms:HFMultimodalLM"
,
"hf-auto"
:
"lm_eval.models.huggingface:HFLM"
,
"hf"
:
"lm_eval.models.huggingface:HFLM"
,
"huggingface"
:
"lm_eval.models.huggingface:HFLM"
,
"watsonx_llm"
:
"lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI"
,
"mamba_ssm"
:
"lm_eval.models.mamba_lm:MambaLMWrapper"
,
"nemo_lm"
:
"lm_eval.models.nemo_lm:NeMoLM"
,
"neuronx"
:
"lm_eval.models.neuron_optimum:NeuronModelForCausalLM"
,
"ipex"
:
"lm_eval.models.optimum_ipex:IPEXForCausalLM"
,
"openvino"
:
"lm_eval.models.optimum_lm:OptimumLM"
,
"sglang"
:
"lm_eval.models.sglang_causallms:SGLANG"
,
"sglang-generate"
:
"lm_eval.models.sglang_generate_API:SGAPI"
,
"textsynth"
:
"lm_eval.models.textsynth:TextSynthLM"
,
"vllm"
:
"lm_eval.models.vllm_causallms:VLLM"
,
"vllm-vlm"
:
"lm_eval.models.vllm_vlms:VLLM_VLM"
,
}
# Register all models lazily
def
_register_all_models
():
"""Register all known models lazily in the registry."""
from
lm_eval.api.registry
import
model_registry
for
name
,
path
in
MODEL_MAPPING
.
items
():
# Only register if not already present (avoids conflicts when modules are imported)
if
name
not
in
model_registry
:
# Call register with the lazy parameter, returns a decorator
model_registry
.
register
(
name
,
lazy
=
path
)(
None
)
# Call registration on module import
_register_all_models
()
__all__
=
[
"MODEL_MAPPING"
]
try
:
...
...
lm_eval/models/hf_steered.py
View file @
93b2ab37
from
collections.abc
import
Generator
from
contextlib
import
contextmanager
from
functools
import
partial
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Generator
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
from
peft.peft_model
import
PeftModel
...
...
lm_eval/models/ibm_watsonx_ai.py
View file @
93b2ab37
...
...
@@ -3,7 +3,7 @@ import json
import
logging
import
os
import
warnings
from
functools
import
lru_
cache
from
functools
import
cache
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
cast
from
tqdm
import
tqdm
...
...
@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise
ValueError
(
error_msg
)
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
get_watsonx_credentials
()
->
Dict
[
str
,
str
]:
"""
Retrieves Watsonx API credentials from environmental variables.
...
...
lm_eval/models/vllm_causallms.py
View file @
93b2ab37
...
...
@@ -40,7 +40,7 @@ try:
if
parse_version
(
version
(
"vllm"
))
>=
parse_version
(
"0.8.3"
):
from
vllm.entrypoints.chat_utils
import
resolve_hf_chat_template
except
ModuleNotFoundError
:
p
ass
p
rint
(
"njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd"
)
if
TYPE_CHECKING
:
pass
...
...
lm_eval/tasks/acpbench/gen_2shot/acp_utils.py
View file @
93b2ab37
...
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
...
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
lm_eval/tasks/acpbench/gen_2shot_with_pddl/acp_utils.py
View file @
93b2ab37
...
...
@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self
.
indexes
=
None
class
ACPGrammarParser
(
object
)
:
class
ACPGrammarParser
:
def
__init__
(
self
,
task
)
->
None
:
self
.
task
=
task
with
open
(
GRAMMAR_FILE
)
as
f
:
...
...
@@ -556,8 +556,8 @@ class STRIPS:
return
set
([
fix_name
(
str
(
x
))
for
x
in
ret
])
def
PDDL_replace_init_pddl_parser
(
self
,
s
):
d
=
DomainParser
()(
open
(
self
.
domain_file
,
"r"
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
,
"r"
).
read
().
lower
())
d
=
DomainParser
()(
open
(
self
.
domain_file
).
read
().
lower
())
p
=
ProblemParser
()(
open
(
self
.
problem_file
).
read
().
lower
())
new_state
=
get_atoms_pddl
(
d
,
p
,
s
|
self
.
get_static
())
...
...
scripts/build_benchmark.py
View file @
93b2ab37
...
...
@@ -7,7 +7,7 @@ from promptsource.templates import DatasetTemplates
from
tqdm
import
tqdm
# from lm_eval.api.registry import ALL_TASKS
# from lm_eval.api.registry
v2
import ALL_TASKS
eval_logger
=
logging
.
getLogger
(
__name__
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment