Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
e9451269
Commit
e9451269
authored
Jul 28, 2025
by
Baber
Browse files
cleanup and and add types
parent
48eabc04
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
85 deletions
+70
-85
lm_eval/api/registry.py
lm_eval/api/registry.py
+69
-84
lm_eval/models/__init__.py
lm_eval/models/__init__.py
+1
-1
No files found.
lm_eval/api/registry.py
View file @
e9451269
from
__future__
import
annotations
from
__future__
import
annotations
import
functools
import
importlib
import
importlib
import
inspect
import
inspect
import
threading
import
threading
...
@@ -13,7 +14,7 @@ from typing import (
...
@@ -13,7 +14,7 @@ from typing import (
Callable
,
Callable
,
Generic
,
Generic
,
TypeVar
,
TypeVar
,
overload
,
cast
,
)
)
...
@@ -22,19 +23,8 @@ try: # Python≥3.10
...
@@ -22,19 +23,8 @@ try: # Python≥3.10
except
ImportError
:
# pragma: no cover - fallback for 3.8/3.9 runtimes
except
ImportError
:
# pragma: no cover - fallback for 3.8/3.9 runtimes
import
importlib_metadata
as
md
# type: ignore
import
importlib_metadata
as
md
# type: ignore
__all__
=
[
# Legacy exports (keep for one release, then drop)
"Registry"
,
LEGACY_EXPORTS
=
[
"MetricSpec"
,
# concrete registries
"model_registry"
,
"task_registry"
,
"metric_registry"
,
"metric_agg_registry"
,
"higher_is_better_registry"
,
"filter_registry"
,
# helper
"freeze_all"
,
# Legacy compatibility
"DEFAULT_METRIC_REGISTRY"
,
"DEFAULT_METRIC_REGISTRY"
,
"AGGREGATION_REGISTRY"
,
"AGGREGATION_REGISTRY"
,
"register_model"
,
"register_model"
,
...
@@ -59,6 +49,21 @@ __all__ = [
...
@@ -59,6 +49,21 @@ __all__ = [
"FILTER_REGISTRY"
,
"FILTER_REGISTRY"
,
]
]
__all__
=
[
# canonical
"Registry"
,
"MetricSpec"
,
"model_registry"
,
"task_registry"
,
"metric_registry"
,
"metric_agg_registry"
,
"higher_is_better_registry"
,
"filter_registry"
,
"freeze_all"
,
# legacy
*
LEGACY_EXPORTS
,
]
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
...
@@ -94,25 +99,25 @@ class Registry(Generic[T]):
...
@@ -94,25 +99,25 @@ class Registry(Generic[T]):
# Registration helpers (decorator or direct call)
# Registration helpers (decorator or direct call)
# ------------------------------------------------------------------
# ------------------------------------------------------------------
@
overload
#
@overload
def
register
(
#
def register(
self
,
#
self,
*
aliases
:
str
,
#
*aliases: str,
lazy
:
None
=
None
,
#
lazy: None = None,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
#
metadata: dict[str, Any] | None = None,
)
->
Callable
[[
T
],
T
]:
#
) -> Callable[[T], T]:
"""Register as decorator: @registry.register("foo")."""
#
"""Register as decorator: @registry.register("foo")."""
...
#
...
#
@
overload
#
@overload
def
register
(
#
def register(
self
,
#
self,
*
aliases
:
str
,
#
*aliases: str,
lazy
:
str
|
md
.
EntryPoint
,
#
lazy: str | md.EntryPoint,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
#
metadata: dict[str, Any] | None = None,
)
->
Callable
[[
Any
],
Any
]:
#
) -> Callable[[Any], Any]:
"""Register lazy: registry.register("foo", lazy="a.b:C")
(None).
"""
#
"""Register lazy: registry.register("foo", lazy="a.b:C")"""
...
#
...
def
_resolve_aliases
(
def
_resolve_aliases
(
self
,
target
:
T
|
str
|
md
.
EntryPoint
,
aliases
:
tuple
[
str
,
...]
self
,
target
:
T
|
str
|
md
.
EntryPoint
,
aliases
:
tuple
[
str
,
...]
...
@@ -185,47 +190,25 @@ class Registry(Generic[T]):
...
@@ -185,47 +190,25 @@ class Registry(Generic[T]):
def
register
(
def
register
(
self
,
self
,
*
aliases
:
str
,
*
aliases
:
str
,
obj
:
T
|
None
=
None
,
lazy
:
str
|
md
.
EntryPoint
|
None
=
None
,
lazy
:
str
|
md
.
EntryPoint
|
None
=
None
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Callable
[[
T
],
T
]:
):
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
if
obj
and
lazy
:
raise
ValueError
(
"pass obj *or* lazy"
)
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
# ─── direct‑call path with lazy placeholder ───
if
lazy
is
not
None
:
for
alias
in
self
.
_resolve_aliases
(
lazy
,
aliases
):
self
.
_check_and_store
(
alias
,
lazy
,
metadata
)
return
lambda
x
:
x
# no‑op decorator for accidental use
# ─── decorator path ───
@
functools
.
wraps
(
self
.
register
)
def
decorator
(
obj
:
T
)
->
T
:
# type: ignore[valid-type]
def
_impl
(
target
:
T
|
str
|
md
.
EntryPoint
):
for
a
lias
in
self
.
_resolve_aliases
(
obj
,
aliases
):
for
a
in
aliases
or
(
getattr
(
target
,
"__name__"
,
str
(
target
)),
):
self
.
_check_and_store
(
a
lias
,
obj
,
metadata
)
self
.
_check_and_store
(
a
,
target
,
metadata
)
return
obj
return
target
return
decorator
# imperative call → immediately registers and returns the target
if
obj
is
not
None
or
lazy
is
not
None
:
return
_impl
(
obj
if
obj
is
not
None
else
lazy
)
# type: ignore[arg-type]
# def register_bulk(
# decorator call → return function that will later receive the object
# self,
return
_impl
# items: dict[str, T | str | md.EntryPoint],
# metadata: dict[str, dict[str, Any]] | None = None,
# ) -> None:
# """Register multiple items at once.
#
# Args:
# items: Dictionary mapping aliases to objects/lazy paths
# metadata: Optional dictionary mapping aliases to metadata
# """
# for alias, target in items.items():
# meta = metadata.get(alias, {}) if metadata else {}
# # For lazy registration, check if it's a string or EntryPoint
# if isinstance(target, (str, md.EntryPoint)):
# self.register(alias, lazy=target, metadata=meta)(None)
# else:
# self.register(alias, metadata=meta)(target)
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Lookup & materialisation
# Lookup & materialisation
...
@@ -241,9 +224,9 @@ class Registry(Generic[T]):
...
@@ -241,9 +224,9 @@ class Registry(Generic[T]):
f
"Lazy path '
{
target
}
' must be in 'module:object' form"
f
"Lazy path '
{
target
}
' must be in 'module:object' form"
)
)
module
=
importlib
.
import_module
(
mod
)
module
=
importlib
.
import_module
(
mod
)
return
getattr
(
module
,
obj_name
)
return
cast
(
T
,
getattr
(
module
,
obj_name
)
)
if
isinstance
(
target
,
md
.
EntryPoint
):
if
isinstance
(
target
,
md
.
EntryPoint
):
return
target
.
load
()
return
cast
(
T
,
target
.
load
()
)
return
target
# concrete already
return
target
# concrete already
def
get
(
self
,
alias
:
str
)
->
T
:
def
get
(
self
,
alias
:
str
)
->
T
:
...
@@ -263,14 +246,14 @@ class Registry(Generic[T]):
...
@@ -263,14 +246,14 @@ class Registry(Generic[T]):
f
"
{
', '
.
join
(
self
.
_objects
)
}
"
f
"
{
', '
.
join
(
self
.
_objects
)
}
"
)
from
exc
)
from
exc
# Double-check after acquiring lock (may have been materialized by another thread)
# Double-check after acquiring
a
lock (may have been materialized by another thread)
if
not
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
if
not
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
return
target
return
target
# Materialize the lazy placeholder
# Materialize the lazy placeholder
concrete
:
T
=
self
.
_materialise
(
target
)
concrete
:
T
=
self
.
_materialise
(
target
)
# Swap placeholder with concrete object (with race condition check)
# Swap placeholder with
a
concrete object (with race condition check)
if
concrete
is
not
target
:
if
concrete
is
not
target
:
# Final check: another thread might have materialized while we were working
# Final check: another thread might have materialized while we were working
current
=
self
.
_objects
.
get
(
alias
)
current
=
self
.
_objects
.
get
(
alias
)
...
@@ -405,7 +388,7 @@ def default_metrics_for(output_type: str) -> list[str]:
...
@@ -405,7 +388,7 @@ def default_metrics_for(output_type: str) -> list[str]:
This walks the metric registry to find metrics that match the output type.
This walks the metric registry to find metrics that match the output type.
Falls back to DEFAULT_METRIC_REGISTRY if no dynamic matches found.
Falls back to DEFAULT_METRIC_REGISTRY if no dynamic matches found.
"""
"""
# First check static defaults
# First
,
check static defaults
if
output_type
in
DEFAULT_METRIC_REGISTRY
:
if
output_type
in
DEFAULT_METRIC_REGISTRY
:
return
DEFAULT_METRIC_REGISTRY
[
output_type
]
return
DEFAULT_METRIC_REGISTRY
[
output_type
]
...
@@ -448,7 +431,7 @@ def register_metric(**kwargs):
...
@@ -448,7 +431,7 @@ def register_metric(**kwargs):
raise
ValueError
(
"metric name is required"
)
raise
ValueError
(
"metric name is required"
)
# Determine aggregation function
# Determine aggregation function
aggregate_fn
=
None
aggregate_fn
:
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]
|
None
=
None
if
"aggregation"
in
kwargs
:
if
"aggregation"
in
kwargs
:
agg_name
=
kwargs
[
"aggregation"
]
agg_name
=
kwargs
[
"aggregation"
]
try
:
try
:
...
@@ -474,12 +457,12 @@ def register_metric(**kwargs):
...
@@ -474,12 +457,12 @@ def register_metric(**kwargs):
requires
=
kwargs
.
get
(
"requires"
),
requires
=
kwargs
.
get
(
"requires"
),
)
)
# Use proper registry API with metadata
# Use
a
proper registry API with metadata
metric_registry
.
register
(
metric_name
,
metadata
=
kwargs
)(
spec
)
metric_registry
.
register
(
metric_name
,
metadata
=
kwargs
)(
spec
)
# type: ignore[misc]
# Also register in higher_is_better registry if specified
# Also register in higher_is_better registry if specified
if
"higher_is_better"
in
kwargs
:
if
"higher_is_better"
in
kwargs
:
higher_is_better_registry
.
register
(
metric_name
)(
kwargs
[
"higher_is_better"
])
higher_is_better_registry
.
register
(
metric_name
)(
kwargs
[
"higher_is_better"
])
# type: ignore[misc]
return
fn
return
fn
...
@@ -519,15 +502,17 @@ def get_metric(name: str, hf_evaluate_metric=False):
...
@@ -519,15 +502,17 @@ def get_metric(name: str, hf_evaluate_metric=False):
register_metric_aggregation
=
metric_agg_registry
.
register
register_metric_aggregation
=
metric_agg_registry
.
register
def
get_metric_aggregation
(
metric_name
:
str
):
def
get_metric_aggregation
(
metric_name
:
str
,
)
->
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]:
"""Get the aggregation function for a metric."""
"""Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation)
# First
,
try to get from
the
metric registry (for metrics registered with aggregation)
try
:
try
:
metric_spec
=
metric_registry
.
get
(
metric_name
)
metric_spec
=
metric_registry
.
get
(
metric_name
)
if
isinstance
(
metric_spec
,
MetricSpec
)
and
metric_spec
.
aggregate
:
if
isinstance
(
metric_spec
,
MetricSpec
)
and
metric_spec
.
aggregate
:
return
metric_spec
.
aggregate
return
metric_spec
.
aggregate
except
KeyError
:
except
KeyError
:
pass
# Try next registry
pass
# Try
the
next registry
# Fall back to metric_agg_registry (for standalone aggregations)
# Fall back to metric_agg_registry (for standalone aggregations)
try
:
try
:
...
@@ -535,7 +520,7 @@ def get_metric_aggregation(metric_name: str):
...
@@ -535,7 +520,7 @@ def get_metric_aggregation(metric_name: str):
except
KeyError
:
except
KeyError
:
pass
pass
# If not found, raise error
# If not found, raise
an
error
raise
KeyError
(
raise
KeyError
(
f
"Unknown metric aggregation '
{
metric_name
}
'. Available:
{
list
(
metric_agg_registry
)
}
"
f
"Unknown metric aggregation '
{
metric_name
}
'. Available:
{
list
(
metric_agg_registry
)
}
"
)
)
...
@@ -558,12 +543,12 @@ def register_aggregation(name: str):
...
@@ -558,12 +543,12 @@ def register_aggregation(name: str):
)
)
def
decorate
(
fn
):
def
decorate
(
fn
):
# Use the canonical registry as single source of truth
# Use the canonical registry as
a
single source of truth
if
name
in
metric_agg_registry
:
if
name
in
metric_agg_registry
:
raise
ValueError
(
raise
ValueError
(
f
"aggregation named '
{
name
}
' conflicts with existing registered aggregation!"
f
"aggregation named '
{
name
}
' conflicts with existing registered aggregation!"
)
)
metric_agg_registry
.
register
(
name
)(
fn
)
metric_agg_registry
.
register
(
name
)(
fn
)
# type: ignore[misc]
return
fn
return
fn
return
decorate
return
decorate
...
...
lm_eval/models/__init__.py
View file @
e9451269
...
@@ -42,7 +42,7 @@ def _register_all_models():
...
@@ -42,7 +42,7 @@ def _register_all_models():
# Only register if not already present (avoids conflicts when modules are imported)
# Only register if not already present (avoids conflicts when modules are imported)
if
name
not
in
model_registry
:
if
name
not
in
model_registry
:
# Call register with the lazy parameter, returns a decorator
# Call register with the lazy parameter, returns a decorator
model_registry
.
register
(
name
,
lazy
=
path
)
(
None
)
model_registry
.
register
(
name
,
lazy
=
path
)
# Call registration on module import
# Call registration on module import
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment