Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
48eabc04
Commit
48eabc04
authored
Jul 28, 2025
by
Baber
Browse files
add better type safety
parent
93b2ab37
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
227 additions
and
133 deletions
+227
-133
lm_eval/api/registry.py
lm_eval/api/registry.py
+227
-133
No files found.
lm_eval/api/registry.py
View file @
48eabc04
...
...
@@ -3,6 +3,7 @@ from __future__ import annotations
import
importlib
import
inspect
import
threading
import
warnings
from
collections.abc
import
Iterable
,
Mapping
,
MutableMapping
from
dataclasses
import
dataclass
from
functools
import
lru_cache
...
...
@@ -12,6 +13,7 @@ from typing import (
Callable
,
Generic
,
TypeVar
,
overload
,
)
...
...
@@ -92,104 +94,138 @@ class Registry(Generic[T]):
# Registration helpers (decorator or direct call)
# ------------------------------------------------------------------
@
overload
def
register
(
self
,
*
aliases
:
str
,
lazy
:
str
|
md
.
EntryPoint
|
None
=
None
,
lazy
:
None
=
None
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Callable
[[
T
],
T
]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
def
_do_register
(
target
:
T
|
str
|
md
.
EntryPoint
)
->
None
:
if
not
aliases
:
_aliases
=
(
getattr
(
target
,
"__name__"
,
str
(
target
)),)
else
:
_aliases
=
aliases
with
self
.
_lock
:
for
alias
in
_aliases
:
if
alias
in
self
.
_objects
:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing
=
self
.
_objects
[
alias
]
if
isinstance
(
existing
,
(
str
,
md
.
EntryPoint
))
and
isinstance
(
target
,
type
):
# Allow replacing lazy placeholder with concrete class
pass
else
:
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
'
{
alias
}
' already registered "
f
"(
{
self
.
_objects
[
alias
]
}
)"
)
# Eager type check only when we have a concrete class
if
self
.
_base_cls
is
not
None
and
isinstance
(
target
,
type
):
if
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
"
f
"to be registered as a
{
self
.
_name
}
"
)
self
.
_objects
[
alias
]
=
target
# Store metadata if provided
if
metadata
:
self
.
_metadata
[
alias
]
=
metadata
"""Register as decorator: @registry.register("foo")."""
...
# ─── decorator path ───
def
decorator
(
obj
:
T
)
->
T
:
# type: ignore[valid-type]
_do_register
(
obj
)
return
obj
# ─── direct‑call path with lazy placeholder ───
if
lazy
is
not
None
:
_do_register
(
lazy
)
return
lambda
x
:
x
# no‑op decorator for accidental use
return
decorator
def
register_bulk
(
@
overload
def
register
(
self
,
items
:
dict
[
str
,
T
|
str
|
md
.
EntryPoint
],
metadata
:
dict
[
str
,
dict
[
str
,
Any
]]
|
None
=
None
,
*
aliases
:
str
,
lazy
:
str
|
md
.
EntryPoint
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Callable
[[
Any
],
Any
]:
"""Register lazy: registry.register("foo", lazy="a.b:C")(None)."""
...
def
_resolve_aliases
(
self
,
target
:
T
|
str
|
md
.
EntryPoint
,
aliases
:
tuple
[
str
,
...]
)
->
tuple
[
str
,
...]:
"""Resolve aliases for registration."""
if
not
aliases
:
return
(
getattr
(
target
,
"__name__"
,
str
(
target
)),)
return
aliases
def
_check_and_store
(
self
,
alias
:
str
,
target
:
T
|
str
|
md
.
EntryPoint
,
metadata
:
dict
[
str
,
Any
]
|
None
,
)
->
None
:
"""
Register multiple items at once
.
"""
Check constraints and store the target with optional metadata
.
Args:
items: Dictionary mapping aliases to objects/lazy paths
metadata: Optional dictionary mapping aliases to metadata
Collision policy:
1. If alias doesn't exist → store it
2. If identical value → silently succeed (idempotent)
3. If lazy placeholder + matching concrete class → replace with concrete
4. Otherwise → raise ValueError
Type checking:
- Eager for concrete classes at registration time
- Deferred for lazy placeholders until materialization
"""
with
self
.
_lock
:
for
alias
,
target
in
items
.
items
():
if
alias
in
self
.
_objects
:
# If it's a lazy placeholder being replaced by the concrete object, allow it
existing
=
self
.
_objects
[
alias
]
if
isinstance
(
existing
,
(
str
,
md
.
EntryPoint
))
and
isinstance
(
target
,
type
):
# Allow replacing lazy placeholder with concrete class
pass
else
:
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
'
{
alias
}
' already registered "
f
"(
{
self
.
_objects
[
alias
]
}
)"
)
# Eager type check only when we have a concrete class
# Case 1: New alias
if
alias
not
in
self
.
_objects
:
# Type check concrete classes before storing
if
self
.
_base_cls
is
not
None
and
isinstance
(
target
,
type
):
if
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
"
f
"to be registered as a
{
self
.
_name
}
"
)
self
.
_objects
[
alias
]
=
target
if
metadata
:
self
.
_metadata
[
alias
]
=
metadata
return
existing
=
self
.
_objects
[
alias
]
# Case 2: Identical value - idempotent
if
existing
==
target
:
return
# Case 3: Lazy placeholder being replaced by its concrete class
if
isinstance
(
existing
,
str
)
and
isinstance
(
target
,
type
):
mod_path
,
_
,
cls_name
=
existing
.
partition
(
":"
)
if
(
cls_name
and
hasattr
(
target
,
"__module__"
)
and
hasattr
(
target
,
"__name__"
)
):
expected_path
=
f
"
{
target
.
__module__
}
:
{
target
.
__name__
}
"
if
existing
==
expected_path
:
self
.
_objects
[
alias
]
=
target
if
metadata
:
self
.
_metadata
[
alias
]
=
metadata
return
# Case 4: Collision - different values
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
'
{
alias
}
' already registered "
f
"(existing:
{
existing
}
, new:
{
target
}
)"
)
def
register
(
self
,
*
aliases
:
str
,
lazy
:
str
|
md
.
EntryPoint
|
None
=
None
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Callable
[[
T
],
T
]:
"""``@registry.register("foo")`` **or** ``registry.register("foo", lazy="a.b:C")``.
* If called as a **decorator**, supply an object and *no* ``lazy``.
* If called as a **plain function** and you want lazy import, leave the
object out and pass ``lazy=``.
"""
# ─── direct‑call path with lazy placeholder ───
if
lazy
is
not
None
:
for
alias
in
self
.
_resolve_aliases
(
lazy
,
aliases
):
self
.
_check_and_store
(
alias
,
lazy
,
metadata
)
return
lambda
x
:
x
# no‑op decorator for accidental use
# Store metadata if provided
if
metadata
and
alias
in
metadata
:
self
.
_metadata
[
alias
]
=
metadata
[
alias
]
# ─── decorator path ───
def
decorator
(
obj
:
T
)
->
T
:
# type: ignore[valid-type]
for
alias
in
self
.
_resolve_aliases
(
obj
,
aliases
):
self
.
_check_and_store
(
alias
,
obj
,
metadata
)
return
obj
return
decorator
# def register_bulk(
# self,
# items: dict[str, T | str | md.EntryPoint],
# metadata: dict[str, dict[str, Any]] | None = None,
# ) -> None:
# """Register multiple items at once.
#
# Args:
# items: Dictionary mapping aliases to objects/lazy paths
# metadata: Optional dictionary mapping aliases to metadata
# """
# for alias, target in items.items():
# meta = metadata.get(alias, {}) if metadata else {}
# # For lazy registration, check if it's a string or EntryPoint
# if isinstance(target, (str, md.EntryPoint)):
# self.register(alias, lazy=target, metadata=meta)(None)
# else:
# self.register(alias, metadata=meta)(target)
# ------------------------------------------------------------------
# Lookup & materialisation
...
...
@@ -211,6 +247,13 @@ class Registry(Generic[T]):
return
target
# concrete already
def
get
(
self
,
alias
:
str
)
->
T
:
# Fast path: check if already materialized without lock
target
=
self
.
_objects
.
get
(
alias
)
if
target
is
not
None
and
not
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
# Already materialized and validated, return immediately
return
target
# Slow path: acquire lock for materialization
with
self
.
_lock
:
try
:
target
=
self
.
_objects
[
alias
]
...
...
@@ -220,15 +263,23 @@ class Registry(Generic[T]):
f
"
{
', '
.
join
(
self
.
_objects
)
}
"
)
from
exc
# Only materialize if it's a string or EntryPoint (lazy placeholder)
if
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
concrete
:
T
=
self
.
_materialise
(
target
)
# First‑touch: swap placeholder with concrete obj for future calls
if
concrete
is
not
target
:
# Double-check after acquiring lock (may have been materialized by another thread)
if
not
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
return
target
# Materialize the lazy placeholder
concrete
:
T
=
self
.
_materialise
(
target
)
# Swap placeholder with concrete object (with race condition check)
if
concrete
is
not
target
:
# Final check: another thread might have materialized while we were working
current
=
self
.
_objects
.
get
(
alias
)
if
isinstance
(
current
,
(
str
,
md
.
EntryPoint
)):
# Still a placeholder, safe to replace
self
.
_objects
[
alias
]
=
concrete
else
:
# A
lready materialized,
j
us
t return i
t
concrete
=
target
else
:
# Another thread a
lready materialized
it
, us
e their resul
t
concrete
=
current
# type: ignore[assignment]
# Late type check (for placeholders)
if
self
.
_base_cls
is
not
None
and
not
issubclass
(
concrete
,
self
.
_base_cls
):
# type: ignore[arg-type]
...
...
@@ -237,8 +288,8 @@ class Registry(Generic[T]):
f
"(registered under alias '
{
alias
}
')"
)
# Custom validation
if
self
.
_validator
is
not
None
and
not
self
.
_validator
(
concrete
):
# Custom validation
- run on materialization
if
self
.
_validator
and
not
self
.
_validator
(
concrete
):
raise
ValueError
(
f
"
{
concrete
}
failed custom validation for
{
self
.
_name
}
registry "
f
"(registered under alias '
{
alias
}
')"
...
...
@@ -301,7 +352,7 @@ class Registry(Generic[T]):
raise
RuntimeError
(
"Cannot clear a frozen registry"
)
self
.
_objects
.
clear
()
self
.
_metadata
.
clear
()
self
.
_materialise
.
cache_clear
()
# type: ignore[attr-defined]
# Added by lru_cache
self
.
_materialise
.
cache_clear
()
# type: ignore[attr-defined]
# ────────────────────────────────────────────────────────────────────────
...
...
@@ -327,7 +378,7 @@ class MetricSpec:
from
lm_eval.api.model
import
LM
# noqa: E402
model_registry
:
Registry
[
type
[
LM
]
]
=
Registry
(
"model"
,
base_cls
=
LM
)
model_registry
:
Registry
[
LM
]
=
Registry
(
"model"
,
base_cls
=
LM
)
task_registry
:
Registry
[
Callable
[...,
Any
]]
=
Registry
(
"task"
)
metric_registry
:
Registry
[
MetricSpec
]
=
Registry
(
"metric"
)
metric_agg_registry
:
Registry
[
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]]
=
(
...
...
@@ -347,8 +398,31 @@ DEFAULT_METRIC_REGISTRY = {
"generate_until"
:
[
"exact_match"
],
}
# Aggregation registry (will be populated by register_aggregation)
AGGREGATION_REGISTRY
:
dict
[
str
,
Callable
]
=
{}
def
default_metrics_for
(
output_type
:
str
)
->
list
[
str
]:
"""Get default metrics for a given output type dynamically.
This walks the metric registry to find metrics that match the output type.
Falls back to DEFAULT_METRIC_REGISTRY if no dynamic matches found.
"""
# First check static defaults
if
output_type
in
DEFAULT_METRIC_REGISTRY
:
return
DEFAULT_METRIC_REGISTRY
[
output_type
]
# Walk metric registry for matching output types
matching_metrics
=
[]
for
name
,
metric_spec
in
metric_registry
.
items
():
if
(
isinstance
(
metric_spec
,
MetricSpec
)
and
metric_spec
.
output_type
==
output_type
):
matching_metrics
.
append
(
name
)
return
matching_metrics
if
matching_metrics
else
[]
# Aggregation registry - alias to the canonical registry for backward compatibility
AGGREGATION_REGISTRY
=
metric_agg_registry
# The registry itself is dict-like
# ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API)
...
...
@@ -373,35 +447,39 @@ def register_metric(**kwargs):
if
not
metric_name
:
raise
ValueError
(
"metric name is required"
)
# Determine aggregation function
aggregate_fn
=
None
if
"aggregation"
in
kwargs
:
agg_name
=
kwargs
[
"aggregation"
]
try
:
aggregate_fn
=
metric_agg_registry
.
get
(
agg_name
)
except
KeyError
:
raise
ValueError
(
f
"Unknown aggregation:
{
agg_name
}
"
)
else
:
# No aggregation specified - use a function that raises NotImplementedError
def
not_implemented_agg
(
values
):
raise
NotImplementedError
(
f
"No aggregation function specified for metric '
{
metric_name
}
'. "
"Please specify an 'aggregation' parameter."
)
aggregate_fn
=
not_implemented_agg
# Create MetricSpec with the function and metadata
spec
=
MetricSpec
(
compute
=
fn
,
aggregate
=
lambda
x
:
{},
# Default aggregation returns empty dict
aggregate
=
aggregate_fn
,
higher_is_better
=
kwargs
.
get
(
"higher_is_better"
,
True
),
output_type
=
kwargs
.
get
(
"output_type"
),
requires
=
kwargs
.
get
(
"requires"
),
)
#
Register in metric registry
metric_registry
.
_objects
[
metric_name
]
=
spec
#
Use proper registry API with metadata
metric_registry
.
register
(
metric_name
,
metadata
=
kwargs
)(
spec
)
# Also handle aggregation if specified
if
"aggregation"
in
kwargs
:
agg_name
=
kwargs
[
"aggregation"
]
# Try to get aggregation from AGGREGATION_REGISTRY
if
agg_name
in
AGGREGATION_REGISTRY
:
spec
=
MetricSpec
(
compute
=
fn
,
aggregate
=
AGGREGATION_REGISTRY
[
agg_name
],
higher_is_better
=
kwargs
.
get
(
"higher_is_better"
,
True
),
output_type
=
kwargs
.
get
(
"output_type"
),
requires
=
kwargs
.
get
(
"requires"
),
)
metric_registry
.
_objects
[
metric_name
]
=
spec
# Handle higher_is_better registry
# Also register in higher_is_better registry if specified
if
"higher_is_better"
in
kwargs
:
higher_is_better_registry
.
_objects
[
metric_name
]
=
kwargs
[
"higher_is_better"
]
higher_is_better_registry
.
register
(
metric_name
)(
kwargs
[
"higher_is_better"
]
)
return
fn
...
...
@@ -444,18 +522,22 @@ register_metric_aggregation = metric_agg_registry.register
def
get_metric_aggregation
(
metric_name
:
str
):
"""Get the aggregation function for a metric."""
# First try to get from metric registry (for metrics registered with aggregation)
if
metric_name
in
metric_registry
.
_objects
:
metric_spec
=
metric_registry
.
_objects
[
metric_name
]
try
:
metric_spec
=
metric_registry
.
get
(
metric_name
)
if
isinstance
(
metric_spec
,
MetricSpec
)
and
metric_spec
.
aggregate
:
return
metric_spec
.
aggregate
except
KeyError
:
pass
# Try next registry
# Fall back to metric_agg_registry (for standalone aggregations)
if
metric_name
in
metric_agg_registry
.
_objects
:
return
metric_agg_registry
.
_objects
[
metric_name
]
try
:
return
metric_agg_registry
.
get
(
metric_name
)
except
KeyError
:
pass
# If not found, raise error
raise
KeyError
(
f
"Unknown metric aggregation '
{
metric_name
}
'. Available:
{
list
(
AGGREGATION_REGISTRY
.
keys
()
)
}
"
f
"Unknown metric aggregation '
{
metric_name
}
'. Available:
{
list
(
metric_agg_registry
)
}
"
)
...
...
@@ -468,20 +550,30 @@ get_filter = filter_registry.get
# Special handling for AGGREGATION_REGISTRY which works differently
def
register_aggregation
(
name
:
str
):
"""@deprecated Use metric_agg_registry.register() instead."""
warnings
.
warn
(
"register_aggregation() is deprecated. Use metric_agg_registry.register() instead."
,
DeprecationWarning
,
stacklevel
=
2
,
)
def
decorate
(
fn
):
if
name
in
AGGREGATION_REGISTRY
:
# Use the canonical registry as single source of truth
if
name
in
metric_agg_registry
:
raise
ValueError
(
f
"aggregation named '
{
name
}
' conflicts with existing registered aggregation!"
)
AGGREGATION_REGISTRY
[
name
]
=
fn
metric_agg_registry
.
register
(
name
)(
fn
)
return
fn
return
decorate
def
get_aggregation
(
name
:
str
)
->
Callable
[[],
dict
[
str
,
Callable
]]:
def
get_aggregation
(
name
:
str
)
->
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]
|
None
:
"""@deprecated Use metric_agg_registry.get() instead."""
try
:
return
AGGREGATION_REGISTRY
[
name
]
# Use the canonical registry
return
metric_agg_registry
.
get
(
name
)
except
KeyError
:
import
logging
...
...
@@ -526,15 +618,17 @@ def freeze_all() -> None: # pragma: no cover
# Backwards‑compatibility read‑only globals
# ────────────────────────────────────────────────────────────────────────
MODEL_REGISTRY
:
Mapping
[
str
,
type
[
LM
]]
=
MappingProxyType
(
model_registry
.
_objects
)
# type: ignore[attr-defined]
TASK_REGISTRY
:
Mapping
[
str
,
Callable
[...,
Any
]]
=
MappingProxyType
(
task_registry
.
_objects
)
# type: ignore[attr-defined]
METRIC_REGISTRY
:
Mapping
[
str
,
MetricSpec
]
=
MappingProxyType
(
metric_registry
.
_objects
)
# type: ignore[attr-defined]
METRIC_AGGREGATION_REGISTRY
:
Mapping
[
str
,
Callable
]
=
MappingProxyType
(
metric_agg_registry
.
_objects
)
# type: ignore[attr-defined]
HIGHER_IS_BETTER_REGISTRY
:
Mapping
[
str
,
bool
]
=
MappingProxyType
(
higher_is_better_registry
.
_objects
)
# type: ignore[attr-defined]
FILTER_REGISTRY
:
Mapping
[
str
,
Callable
]
=
MappingProxyType
(
filter_registry
.
_objects
)
# type: ignore[attr-defined]
# These are direct aliases to the registries themselves, which already implement
# the Mapping protocol and provide read-only access to users (since _objects is private).
# This ensures they always reflect the current state of the registries, including
# items registered after module import.
#
# Note: We use type: ignore because Registry doesn't formally inherit from Mapping,
# but it implements all required methods (__getitem__, __iter__, __len__, items)
MODEL_REGISTRY
:
Mapping
[
str
,
LM
]
=
model_registry
# type: ignore[assignment]
TASK_REGISTRY
:
Mapping
[
str
,
Callable
[...,
Any
]]
=
task_registry
# type: ignore[assignment]
METRIC_REGISTRY
:
Mapping
[
str
,
MetricSpec
]
=
metric_registry
# type: ignore[assignment]
METRIC_AGGREGATION_REGISTRY
:
Mapping
[
str
,
Callable
]
=
metric_agg_registry
# type: ignore[assignment]
HIGHER_IS_BETTER_REGISTRY
:
Mapping
[
str
,
bool
]
=
higher_is_better_registry
# type: ignore[assignment]
FILTER_REGISTRY
:
Mapping
[
str
,
Callable
]
=
filter_registry
# type: ignore[assignment]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment