Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
9af24b7e
Commit
9af24b7e
authored
Jul 28, 2025
by
Baber
Browse files
refactor registry for simplicity and improved maintainability
parent
907f5f28
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
180 additions
and
465 deletions
+180
-465
lm_eval/api/registry.py
lm_eval/api/registry.py
+178
-463
lm_eval/models/__init__.py
lm_eval/models/__init__.py
+2
-2
No files found.
lm_eval/api/registry.py
View file @
9af24b7e
...
@@ -3,23 +3,16 @@ from __future__ import annotations
...
@@ -3,23 +3,16 @@ from __future__ import annotations
import
importlib
import
importlib
import
inspect
import
inspect
import
threading
import
threading
import
warnings
from
collections.abc
import
Iterable
,
Mapping
from
collections.abc
import
Iterable
,
Mapping
,
MutableMapping
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
functools
import
lru_cache
from
types
import
MappingProxyType
from
types
import
MappingProxyType
from
typing
import
(
from
typing
import
Any
,
Callable
,
Generic
,
Type
,
TypeVar
,
Union
,
cast
Any
,
Callable
,
Generic
,
TypeVar
,
cast
,
)
try
:
# Python≥3.10
try
:
import
importlib.metadata
as
md
import
importlib.metadata
as
md
# Python ≥3.10
except
ImportError
:
# pragma: no cover
-
fallback for 3.8/3.9
runtimes
except
ImportError
:
# pragma: no cover
–
fallback for 3.8/3.9
import
importlib_metadata
as
md
# type: ignore
import
importlib_metadata
as
md
# type: ignore
# Legacy exports (keep for one release, then drop)
# Legacy exports (keep for one release, then drop)
...
@@ -64,6 +57,7 @@ __all__ = [
...
@@ -64,6 +57,7 @@ __all__ = [
]
]
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
Placeholder
=
Union
[
str
,
md
.
EntryPoint
]
# light‑weight lazy token
# ────────────────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────
...
@@ -72,533 +66,264 @@ T = TypeVar("T")
...
@@ -72,533 +66,264 @@ T = TypeVar("T")
class
Registry
(
Generic
[
T
]):
class
Registry
(
Generic
[
T
]):
"""Name -> object mapping with decorator helpers and **lazy import** support."""
"""Name → object registry with optional lazy placeholders."""
#: The underlying mutable mapping (might turn into MappingProxy on freeze)
_objects
:
MutableMapping
[
str
,
T
|
str
|
md
.
EntryPoint
]
def
__init__
(
def
__init__
(
self
,
self
,
name
:
str
,
name
:
str
,
*
,
*
,
base_cls
:
type
[
T
]
|
None
=
None
,
base_cls
:
Union
[
Type
[
T
],
None
]
=
None
,
store
:
MutableMapping
[
str
,
T
|
str
|
md
.
EntryPoint
]
|
None
=
None
,
validator
:
Callable
[[
T
],
bool
]
|
None
=
None
,
)
->
None
:
)
->
None
:
self
.
_name
:
str
=
name
self
.
_name
=
name
self
.
_base_cls
:
type
[
T
]
|
None
=
base_cls
self
.
_base_cls
=
base_cls
self
.
_objects
=
store
if
store
is
not
None
else
{}
self
.
_objs
:
dict
[
str
,
Union
[
T
,
Placeholder
]]
=
{}
self
.
_metadata
:
dict
[
self
.
_meta
:
dict
[
str
,
dict
[
str
,
Any
]]
=
{}
str
,
dict
[
str
,
Any
]
]
=
{}
# Store metadata for each registered item
self
.
_validator
=
validator
# Custom validation function
self
.
_lock
=
threading
.
RLock
()
self
.
_lock
=
threading
.
RLock
()
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Registration
helpers
(decorator or direct call)
# Registration (decorator or direct call)
# ------------------------------------------------------------------
# ------------------------------------------------------------------
def
_resolve_aliases
(
self
,
target
:
T
|
str
|
md
.
EntryPoint
,
aliases
:
tuple
[
str
,
...]
)
->
tuple
[
str
,
...]:
"""Resolve aliases for registration."""
if
not
aliases
:
return
(
getattr
(
target
,
"__name__"
,
str
(
target
)),)
return
aliases
def
_check_and_store
(
self
,
alias
:
str
,
target
:
T
|
str
|
md
.
EntryPoint
,
metadata
:
dict
[
str
,
Any
]
|
None
,
)
->
None
:
"""Check constraints and store the target with optional metadata.
Collision policy:
1. If alias doesn't exist → store it
2. If identical value → silently succeed (idempotent)
3. If lazy placeholder + matching concrete class → replace with concrete
4. Otherwise → raise ValueError
Type checking:
- Eager for concrete classes at registration time
- Deferred for lazy placeholders until materialization
"""
with
self
.
_lock
:
# Case 1: New alias
if
alias
not
in
self
.
_objects
:
# Type check concrete classes before storing
if
self
.
_base_cls
is
not
None
and
isinstance
(
target
,
type
):
if
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
"
f
"to be registered as a
{
self
.
_name
}
"
)
self
.
_objects
[
alias
]
=
target
if
metadata
:
self
.
_metadata
[
alias
]
=
metadata
return
existing
=
self
.
_objects
[
alias
]
# Case 2: Identical value - idempotent
if
existing
==
target
:
return
# Case 3: Lazy placeholder being replaced by its concrete class
if
isinstance
(
existing
,
str
)
and
isinstance
(
target
,
type
):
mod_path
,
_
,
cls_name
=
existing
.
partition
(
":"
)
if
(
cls_name
and
hasattr
(
target
,
"__module__"
)
and
hasattr
(
target
,
"__name__"
)
):
expected_path
=
f
"
{
target
.
__module__
}
:
{
target
.
__name__
}
"
if
existing
==
expected_path
:
self
.
_objects
[
alias
]
=
target
if
metadata
:
self
.
_metadata
[
alias
]
=
metadata
return
# Case 4: Collision - different values
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
'
{
alias
}
' already registered "
f
"(existing:
{
existing
}
, new:
{
target
}
)"
)
def
register
(
def
register
(
self
,
alias
:
str
,
target
:
T
|
str
|
md
.
EntryPoint
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
T
|
str
|
md
.
EntryPoint
:
"""Register a target (object or lazy placeholder) under the given alias.
Args:
alias: Name to register under
target: Object to register (can be concrete object or lazy string "module:Class")
metadata: Optional metadata to associate with this registration
Returns:
The target that was registered
Examples:
# Direct registration of concrete object
registry.register("mymodel", MyModelClass)
# Lazy registration with module path
registry.register("mymodel", "mypackage.models:MyModelClass")
"""
self
.
_check_and_store
(
alias
,
target
,
metadata
)
return
target
def
decorator
(
self
,
self
,
*
aliases
:
str
,
*
aliases
:
str
,
lazy
:
Union
[
T
,
Placeholder
,
None
]
=
None
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Callable
[[
T
],
T
]:
)
->
Callable
[[
T
],
T
]:
"""Create a decorator for registering objects.
"""``@reg.register('foo')`` or ``reg.register('foo', lazy='pkg.mod:Obj')``."""
Args:
def
_store
(
alias
:
str
,
target
:
Union
[
T
,
Placeholder
])
->
None
:
*aliases: Names to register under (if empty, uses object's __name__)
current
=
self
.
_objs
.
get
(
alias
)
metadata: Optional metadata to associate with this registration
# ─── collision handling ────────────────────────────────────
if
current
is
not
None
and
current
!=
target
:
Returns:
# allow placeholder → real object upgrade
Decorator function that registers its target
if
isinstance
(
current
,
str
)
and
isinstance
(
target
,
type
):
mod
,
_
,
cls
=
current
.
partition
(
":"
)
Example:
if
current
==
f
"
{
target
.
__module__
}
:
{
target
.
__name__
}
"
:
@registry.decorator("mymodel", "model-v2")
self
.
_objs
[
alias
]
=
target
class MyModel:
self
.
_meta
[
alias
]
=
metadata
or
{}
pass
return
"""
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
alias '
{
alias
}
' already registered ("
# noqa: B950
def
wrapper
(
obj
:
T
)
->
T
:
f
"existing=
{
current
}
, new=
{
target
}
)"
resolved_aliases
=
aliases
or
(
getattr
(
obj
,
"__name__"
,
str
(
obj
)),)
)
for
alias
in
resolved_aliases
:
# ─── type check for concrete classes ───────────────────────
self
.
register
(
alias
,
obj
,
metadata
)
if
self
.
_base_cls
is
not
None
and
isinstance
(
target
,
type
):
if
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
to be a
{
self
.
_name
}
"
)
self
.
_objs
[
alias
]
=
target
if
metadata
:
self
.
_meta
[
alias
]
=
metadata
def
decorator
(
obj
:
T
)
->
T
:
# type: ignore[valid-type]
names
=
aliases
or
(
getattr
(
obj
,
"__name__"
,
str
(
obj
)),)
with
self
.
_lock
:
for
name
in
names
:
_store
(
name
,
obj
)
return
obj
return
obj
return
wrapper
# Direct call with *lazy* placeholder
if
lazy
is
not
None
:
if
len
(
aliases
)
!=
1
:
raise
ValueError
(
"Exactly one alias required when using 'lazy='"
)
with
self
.
_lock
:
_store
(
aliases
[
0
],
lazy
)
# type: ignore[arg-type]
# return no‑op decorator for accidental use
return
lambda
x
:
x
# type: ignore[return-value]
return
decorator
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Lookup & materialisation
# Lookup & materialisation
# ------------------------------------------------------------------
# ------------------------------------------------------------------
@
lru_cache
(
maxsize
=
256
)
# Bounded cache to prevent memory growth
@
lru_cache
(
maxsize
=
256
)
def
_materialise
(
self
,
target
:
T
|
str
|
md
.
EntryPoint
)
->
T
:
def
_materialise
(
self
,
ph
:
Placeholder
)
->
T
:
"""Import *target* if it is a dotted‑path string or EntryPoint."""
if
isinstance
(
ph
,
str
):
if
isinstance
(
target
,
str
):
mod
,
_
,
attr
=
ph
.
partition
(
":"
)
mod
,
_
,
obj_name
=
target
.
partition
(
":"
)
if
not
attr
:
if
not
_
:
raise
ValueError
(
f
"Invalid lazy path '
{
ph
}
', expected 'module:object'"
)
raise
ValueError
(
return
cast
(
T
,
getattr
(
importlib
.
import_module
(
mod
),
attr
))
f
"Lazy path '
{
target
}
' must be in 'module:object' form"
return
cast
(
T
,
ph
.
load
())
)
module
=
importlib
.
import_module
(
mod
)
return
cast
(
T
,
getattr
(
module
,
obj_name
))
if
isinstance
(
target
,
md
.
EntryPoint
):
return
cast
(
T
,
target
.
load
())
return
target
# concrete already
def
get
(
self
,
alias
:
str
)
->
T
:
def
get
(
self
,
alias
:
str
)
->
T
:
# Fast path: check if already materialized without lock
try
:
target
=
self
.
_objects
.
get
(
alias
)
target
=
self
.
_objs
[
alias
]
if
target
is
not
None
and
not
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
except
KeyError
as
exc
:
# Already materialized and validated, return immediately
raise
KeyError
(
return
target
f
"Unknown
{
self
.
_name
}
'
{
alias
}
'. Available:
{
', '
.
join
(
self
.
_objs
)
}
"
)
from
exc
# Slow path: acquire lock for materialization
with
self
.
_lock
:
if
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
try
:
with
self
.
_lock
:
target
=
self
.
_objects
[
alias
]
# Re‑check under lock (another thread might have resolved it)
except
KeyError
as
exc
:
fresh
=
self
.
_objs
[
alias
]
raise
KeyError
(
if
isinstance
(
fresh
,
(
str
,
md
.
EntryPoint
)):
f
"Unknown
{
self
.
_name
}
'
{
alias
}
'. Available: "
concrete
=
self
.
_materialise
(
fresh
)
f
"
{
', '
.
join
(
self
.
_objects
)
}
"
self
.
_objs
[
alias
]
=
concrete
)
from
exc
# Double-check after acquiring a lock (may have been materialized by another thread)
if
not
isinstance
(
target
,
(
str
,
md
.
EntryPoint
)):
return
target
# Materialize the lazy placeholder
concrete
:
T
=
self
.
_materialise
(
target
)
# Swap placeholder with a concrete object (with race condition check)
if
concrete
is
not
target
:
# Final check: another thread might have materialized while we were working
current
=
self
.
_objects
.
get
(
alias
)
if
isinstance
(
current
,
(
str
,
md
.
EntryPoint
)):
# Still a placeholder, safe to replace
self
.
_objects
[
alias
]
=
concrete
else
:
else
:
# Another thread already materialized it, use their result
concrete
=
fresh
# another thread did the job
concrete
=
current
# type: ignore[assignment]
target
=
concrete
# Late type check (for placeholders)
if
self
.
_base_cls
is
not
None
and
not
issubclass
(
concrete
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
TypeError
(
f
"
{
concrete
}
does not inherit from
{
self
.
_base_cls
}
"
f
"(registered under alias '
{
alias
}
')"
)
# Custom validation - run on materialization
# Late type/validator checks
if
self
.
_validator
and
not
self
.
_validator
(
concrete
):
if
self
.
_base_cls
is
not
None
and
not
issubclass
(
target
,
self
.
_base_cls
):
# type: ignore[arg-type]
raise
ValueError
(
raise
TypeError
(
f
"
{
concrete
}
failed custom validation for
{
self
.
_name
}
registry "
f
"
{
target
}
does not inherit from
{
self
.
_base_cls
}
(alias '
{
alias
}
')"
f
"(registered under alias '
{
alias
}
')"
)
)
return
target
return
concrete
# Mapping / dunder helpers -------------------------------------------------
# ------------------------------------------------------------------
# Mapping helpers
# ------------------------------------------------------------------
def
__getitem__
(
self
,
alias
:
str
)
->
T
:
# noqa
def
__getitem__
(
self
,
alias
:
str
)
->
T
:
# noqa
: DunderImplemented
return
self
.
get
(
alias
)
return
self
.
get
(
alias
)
def
__iter__
(
self
):
# noqa
def
__iter__
(
self
):
# noqa
: DunderImplemented
return
iter
(
self
.
_obj
ect
s
)
return
iter
(
self
.
_objs
)
def
__len__
(
self
)
->
int
:
# noqa
def
__len__
(
self
):
# noqa
: DunderImplemented
return
len
(
self
.
_obj
ect
s
)
return
len
(
self
.
_objs
)
def
items
(
self
):
# noqa
def
items
(
self
):
# noqa
: DunderImplemented
return
self
.
_obj
ect
s
.
items
()
return
self
.
_objs
.
items
()
# Introspection -----------------------------------------------------------
# ------------------------------------------------------------------
# Utilities
# ------------------------------------------------------------------
def
origin
(
self
,
alias
:
str
)
->
str
|
None
:
def
metadata
(
self
,
alias
:
str
)
->
Union
[
Mapping
[
str
,
Any
],
None
]:
obj
=
self
.
_objects
.
get
(
alias
)
return
self
.
_meta
.
get
(
alias
)
def
origin
(
self
,
alias
:
str
)
->
Union
[
str
,
None
]:
obj
=
self
.
_objs
.
get
(
alias
)
if
isinstance
(
obj
,
(
str
,
md
.
EntryPoint
)):
return
None
try
:
try
:
if
isinstance
(
obj
,
str
)
or
isinstance
(
obj
,
md
.
EntryPoint
):
path
=
inspect
.
getfile
(
obj
)
# type: ignore[arg-type]
return
None
# placeholder - unknown until imported
file
=
inspect
.
getfile
(
obj
)
# type: ignore[arg-type]
line
=
inspect
.
getsourcelines
(
obj
)[
1
]
# type: ignore[arg-type]
line
=
inspect
.
getsourcelines
(
obj
)[
1
]
# type: ignore[arg-type]
return
f
"
{
file
}
:
{
line
}
"
return
f
"
{
path
}
:
{
line
}
"
except
(
except
Exception
:
# pragma: no cover – best‑effort only
TypeError
,
OSError
,
AttributeError
,
):
# pragma: no cover - best-effort only
# TypeError: object not suitable for inspect
# OSError: file not found or accessible
# AttributeError: object lacks expected attributes
return
None
return
None
def
get_metadata
(
self
,
alias
:
str
)
->
dict
[
str
,
Any
]
|
None
:
"""Get metadata for a registered item."""
with
self
.
_lock
:
return
self
.
_metadata
.
get
(
alias
)
# Mutability --------------------------------------------------------------
def
freeze
(
self
):
def
freeze
(
self
):
"""Make the registry *names* immutable (materialisation still works)."""
with
self
.
_lock
:
with
self
.
_lock
:
if
isinstance
(
self
.
_objects
,
MappingProxyType
):
self
.
_objs
=
MappingProxyType
(
dict
(
self
.
_objs
))
# type: ignore[assignment]
return
# already frozen
self
.
_meta
=
MappingProxyType
(
dict
(
self
.
_meta
))
# type: ignore[assignment]
self
.
_objects
=
MappingProxyType
(
dict
(
self
.
_objects
))
# type: ignore[assignment]
def
clear
(
self
):
# Test helper -------------------------------------------------------------
"""Clear the registry (useful for tests). Cannot be called on frozen registries."""
with
self
.
_lock
:
def
_clear
(
self
):
# pragma: no cover
if
isinstance
(
self
.
_objects
,
MappingProxyType
):
"""Erase registry (for isolated tests)."""
raise
RuntimeError
(
"Cannot clear a frozen registry"
)
self
.
_objs
.
clear
()
self
.
_objects
.
clear
()
self
.
_meta
.
clear
()
self
.
_metadata
.
clear
()
self
.
_materialise
.
cache_clear
()
self
.
_materialise
.
cache_clear
()
# type: ignore[attr-defined]
# ────────────────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────
# Structured object
s stored in regis
tri
e
s
# Structured object
for me
tri
c
s
# ────────────────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
MetricSpec
:
class
MetricSpec
:
"""Bundle compute fn, aggregator, and *higher‑is‑better* flag."""
compute
:
Callable
[[
Any
,
Any
],
Any
]
compute
:
Callable
[[
Any
,
Any
],
Any
]
aggregate
:
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]
]
aggregate
:
Callable
[[
Iterable
[
Any
]],
float
]
higher_is_better
:
bool
=
True
higher_is_better
:
bool
=
True
output_type
:
str
|
None
=
None
# e.g., "probability", "string", "numeric"
output_type
:
Union
[
str
,
None
]
=
None
requires
:
list
[
str
]
|
None
=
None
# Dependencies on other metrics/data
requires
:
Union
[
list
[
str
]
,
None
]
=
None
# ────────────────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────
# C
oncrete
registries
used by lm_eval
# C
anonical
registries
# ────────────────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────
from
lm_eval.api.model
import
LM
# noqa: E402
from
lm_eval.api.model
import
LM
# noqa: E402
model_registry
:
Registry
[
LM
]
=
Registry
(
"model"
,
base_cls
=
LM
)
model_registry
:
Registry
[
type
[
LM
]
]
=
Registry
(
"model"
,
base_cls
=
LM
)
task_registry
:
Registry
[
Callable
[...,
Any
]]
=
Registry
(
"task"
)
task_registry
:
Registry
[
Callable
[...,
Any
]]
=
Registry
(
"task"
)
metric_registry
:
Registry
[
MetricSpec
]
=
Registry
(
"metric"
)
metric_registry
:
Registry
[
MetricSpec
]
=
Registry
(
"metric"
)
metric_agg_registry
:
Registry
[
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]
]
=
(
metric_agg_registry
:
Registry
[
Callable
[[
Iterable
[
Any
]],
float
]]
=
Registry
(
Registry
(
"metric aggregation"
)
"metric aggregation"
)
)
higher_is_better_registry
:
Registry
[
bool
]
=
Registry
(
"higher‑is‑better flag"
)
higher_is_better_registry
:
Registry
[
bool
]
=
Registry
(
"higher‑is‑better flag"
)
filter_registry
:
Registry
[
Callable
]
=
Registry
(
"filter"
)
filter_registry
:
Registry
[
Callable
]
=
Registry
(
"filter"
)
# Default metric registry for output types
# Public helper aliases ------------------------------------------------------
DEFAULT_METRIC_REGISTRY
=
{
"loglikelihood"
:
[
"perplexity"
,
"acc"
,
],
"loglikelihood_rolling"
:
[
"word_perplexity"
,
"byte_perplexity"
,
"bits_per_byte"
],
"multiple_choice"
:
[
"acc"
,
"acc_norm"
],
"generate_until"
:
[
"exact_match"
],
}
def
default_metrics_for
(
output_type
:
str
)
->
list
[
str
]:
"""Get default metrics for a given output type dynamically.
This walks the metric registry to find metrics that match the output type.
Falls back to DEFAULT_METRIC_REGISTRY if no dynamic matches found.
"""
# First, check static defaults
if
output_type
in
DEFAULT_METRIC_REGISTRY
:
return
DEFAULT_METRIC_REGISTRY
[
output_type
]
# Walk metric registry for matching output types
register_model
=
model_registry
.
register
matching_metrics
=
[]
for
name
,
metric_spec
in
metric_registry
.
items
():
if
(
isinstance
(
metric_spec
,
MetricSpec
)
and
metric_spec
.
output_type
==
output_type
):
matching_metrics
.
append
(
name
)
return
matching_metrics
if
matching_metrics
else
[]
# Aggregation registry - alias to the canonical registry for backward compatibility
AGGREGATION_REGISTRY
=
metric_agg_registry
# The registry itself is dict-like
# ────────────────────────────────────────────────────────────────────────
# Public helper aliases (legacy API)
# ────────────────────────────────────────────────────────────────────────
register_model
=
model_registry
.
decorator
get_model
=
model_registry
.
get
get_model
=
model_registry
.
get
register_task
=
task_registry
.
decorato
r
register_task
=
task_registry
.
registe
r
get_task
=
task_registry
.
get
get_task
=
task_registry
.
get
register_filter
=
filter_registry
.
register
get_filter
=
filter_registry
.
get
# Metric helpers need thin wrappers to build MetricSpec ----------------------
# Special handling for metric registration which uses different API
def
register_metric
(
**
kwargs
):
"""Register a metric with metadata.
Compatible with old registry API that used keyword arguments.
"""
def
decorate
(
fn
):
metric_name
=
kwargs
.
get
(
"metric"
)
if
not
metric_name
:
raise
ValueError
(
"metric name is required"
)
# Determine aggregation function
aggregate_fn
:
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]
|
None
=
None
if
"aggregation"
in
kwargs
:
agg_name
=
kwargs
[
"aggregation"
]
try
:
aggregate_fn
=
metric_agg_registry
.
get
(
agg_name
)
except
KeyError
:
raise
ValueError
(
f
"Unknown aggregation:
{
agg_name
}
"
)
else
:
# No aggregation specified - use a function that raises NotImplementedError
def
not_implemented_agg
(
values
):
raise
NotImplementedError
(
f
"No aggregation function specified for metric '
{
metric_name
}
'. "
"Please specify an 'aggregation' parameter."
)
aggregate_fn
=
not_implemented_agg
def
register_metric
(
**
kw
):
name
=
kw
[
"metric"
]
# Create MetricSpec with the function and metadata
def
deco
(
fn
):
spec
=
MetricSpec
(
spec
=
MetricSpec
(
compute
=
fn
,
compute
=
fn
,
aggregate
=
aggregate_fn
,
aggregate
=
(
higher_is_better
=
kwargs
.
get
(
"higher_is_better"
,
True
),
metric_agg_registry
.
get
(
kw
[
"aggregation"
])
output_type
=
kwargs
.
get
(
"output_type"
),
if
"aggregation"
in
kw
requires
=
kwargs
.
get
(
"requires"
),
else
lambda
_
:
{}
),
higher_is_better
=
kw
.
get
(
"higher_is_better"
,
True
),
output_type
=
kw
.
get
(
"output_type"
),
requires
=
kw
.
get
(
"requires"
),
)
)
metric_registry
.
register
(
name
,
lazy
=
spec
,
metadata
=
kw
)
# Use a proper registry API with metadata
higher_is_better_registry
.
register
(
name
,
lazy
=
spec
.
higher_is_better
)
metric_registry
.
register
(
metric_name
,
spec
,
metadata
=
kwargs
)
# Also register in higher_is_better registry if specified
if
"higher_is_better"
in
kwargs
:
higher_is_better_registry
.
register
(
metric_name
,
kwargs
[
"higher_is_better"
])
return
fn
return
fn
return
deco
rate
return
deco
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
):
def
get_metric
(
name
,
hf_evaluate_metric
=
False
):
"""Get a metric by name, with fallback to HF evaluate."""
try
:
if
not
hf_evaluate_metric
:
spec
=
metric_registry
.
get
(
name
)
try
:
return
spec
.
compute
# type: ignore[attr-defined]
spec
=
metric_registry
.
get
(
name
)
except
KeyError
:
if
isinstance
(
spec
,
MetricSpec
):
if
not
hf_evaluate_metric
:
return
spec
.
compute
return
spec
except
KeyError
:
import
logging
import
logging
logging
.
getLogger
(
__name__
).
warning
(
logging
.
getLogger
(
__name__
).
warning
(
f
"
Could not find registered metric '
{
name
}
' in lm-eval, search
ing
in
HF
E
valuate
library...
"
f
"
Metric '
{
name
}
' not in registry; try
ing HF
e
valuate
…
"
)
)
try
:
import
evaluate
as
hf
# Fallback to HF evaluate
return
hf
.
load
(
name
).
compute
# type: ignore[attr-defined]
try
:
except
Exception
:
import
evaluate
as
hf_evaluate
raise
KeyError
(
f
"Metric '
{
name
}
' not found anywhere"
)
metric_object
=
hf_evaluate
.
load
(
name
)
return
metric_object
.
compute
except
Exception
:
import
logging
logging
.
getLogger
(
__name__
).
error
(
f
"
{
name
}
not found in the evaluate library! Please check https://huggingface.co/evaluate-metric"
,
)
return
None
register_metric_aggregation
=
metric_agg_registry
.
decorator
def
get_metric_aggregation
(
metric_name
:
str
,
)
->
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]:
"""Get the aggregation function for a metric."""
# First, try to get from the metric registry (for metrics registered with aggregation)
try
:
metric_spec
=
metric_registry
.
get
(
metric_name
)
if
isinstance
(
metric_spec
,
MetricSpec
)
and
metric_spec
.
aggregate
:
return
metric_spec
.
aggregate
except
KeyError
:
pass
# Try the next registry
# Fall back to metric_agg_registry (for standalone aggregations)
try
:
return
metric_agg_registry
.
get
(
metric_name
)
except
KeyError
:
pass
# If not found, raise an error
raise
KeyError
(
f
"Unknown metric aggregation '
{
metric_name
}
'. Available:
{
list
(
metric_agg_registry
)
}
"
)
register_metric_aggregation
=
metric_agg_registry
.
register
get_metric_aggregation
=
metric_agg_registry
.
get
register_higher_is_better
=
higher_is_better_registry
.
decorato
r
register_higher_is_better
=
higher_is_better_registry
.
registe
r
is_higher_better
=
higher_is_better_registry
.
get
is_higher_better
=
higher_is_better_registry
.
get
register_filter
=
filter_registry
.
decorator
# Legacy compatibility
get_filter
=
filter_registry
.
get
register_aggregation
=
metric_agg_registry
.
register
get_aggregation
=
metric_agg_registry
.
get
DEFAULT_METRIC_REGISTRY
=
metric_registry
# Special handling for AGGREGATION_REGISTRY which works differently
AGGREGATION_REGISTRY
=
metric_agg_registry
def
register_aggregation
(
name
:
str
):
"""@deprecated Use metric_agg_registry.register() instead."""
warnings
.
warn
(
"register_aggregation() is deprecated. Use metric_agg_registry.register() instead."
,
DeprecationWarning
,
stacklevel
=
2
,
)
def
decorate
(
fn
):
# Use the canonical registry as a single source of truth
if
name
in
metric_agg_registry
:
raise
ValueError
(
f
"aggregation named '
{
name
}
' conflicts with existing registered aggregation!"
)
metric_agg_registry
.
register
(
name
,
fn
)
return
fn
return
decorate
def
get_aggregation
(
name
:
str
)
->
Callable
[[
Iterable
[
Any
]],
Mapping
[
str
,
float
]]
|
None
:
"""@deprecated Use metric_agg_registry.get() instead."""
try
:
# Use the canonical registry
return
metric_agg_registry
.
get
(
name
)
except
KeyError
:
import
logging
logging
.
getLogger
(
__name__
).
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
return
None
# ────────────────────────────────────────────────────────────────────────
# Optional PyPI entry‑point discovery - uncomment if desired
# ────────────────────────────────────────────────────────────────────────
# for _group, _reg in {
# Convenience ----------------------------------------------------------------
# "lm_eval.models": model_registry,
# "lm_eval.tasks": task_registry,
# "lm_eval.metrics": metric_registry,
# }.items():
# for _ep in md.entry_points(group=_group):
# _reg.register(_ep.name, lazy=_ep)
# ────────────────────────────────────────────────────────────────────────
def
freeze_all
():
# Convenience
for
r
in
(
# ────────────────────────────────────────────────────────────────────────
def
freeze_all
()
->
None
:
# pragma: no cover
"""Freeze every global registry (idempotent)."""
for
_reg
in
(
model_registry
,
model_registry
,
task_registry
,
task_registry
,
metric_registry
,
metric_registry
,
...
@@ -606,24 +331,14 @@ def freeze_all() -> None: # pragma: no cover
...
@@ -606,24 +331,14 @@ def freeze_all() -> None: # pragma: no cover
higher_is_better_registry
,
higher_is_better_registry
,
filter_registry
,
filter_registry
,
):
):
_reg
.
freeze
()
r
.
freeze
()
# ────────────────────────────────────────────────────────────────────────
# Backwards‑compat read‑only aliases ----------------------------------------
# Backwards‑compatibility read‑only globals
# ────────────────────────────────────────────────────────────────────────
# These are direct aliases to the registries themselves, which already implement
MODEL_REGISTRY
=
model_registry
# type: ignore
# the Mapping protocol and provide read-only access to users (since _objects is private).
TASK_REGISTRY
=
task_registry
# type: ignore
# This ensures they always reflect the current state of the registries, including
METRIC_REGISTRY
=
metric_registry
# type: ignore
# items registered after module import.
METRIC_AGGREGATION_REGISTRY
=
metric_agg_registry
# type: ignore
#
HIGHER_IS_BETTER_REGISTRY
=
higher_is_better_registry
# type: ignore
# Note: We use type: ignore because Registry doesn't formally inherit from Mapping,
FILTER_REGISTRY
=
filter_registry
# type: ignore
# but it implements all required methods (__getitem__, __iter__, __len__, items)
MODEL_REGISTRY
:
Mapping
[
str
,
LM
]
=
model_registry
# type: ignore[assignment]
TASK_REGISTRY
:
Mapping
[
str
,
Callable
[...,
Any
]]
=
task_registry
# type: ignore[assignment]
METRIC_REGISTRY
:
Mapping
[
str
,
MetricSpec
]
=
metric_registry
# type: ignore[assignment]
METRIC_AGGREGATION_REGISTRY
:
Mapping
[
str
,
Callable
]
=
metric_agg_registry
# type: ignore[assignment]
HIGHER_IS_BETTER_REGISTRY
:
Mapping
[
str
,
bool
]
=
higher_is_better_registry
# type: ignore[assignment]
FILTER_REGISTRY
:
Mapping
[
str
,
Callable
]
=
filter_registry
# type: ignore[assignment]
lm_eval/models/__init__.py
View file @
9af24b7e
...
@@ -41,8 +41,8 @@ def _register_all_models():
...
@@ -41,8 +41,8 @@ def _register_all_models():
for
name
,
path
in
MODEL_MAPPING
.
items
():
for
name
,
path
in
MODEL_MAPPING
.
items
():
# Only register if not already present (avoids conflicts when modules are imported)
# Only register if not already present (avoids conflicts when modules are imported)
if
name
not
in
model_registry
:
if
name
not
in
model_registry
:
# Register the lazy placeholder
directly
# Register the lazy placeholder
using lazy parameter
model_registry
.
register
(
name
,
path
)
model_registry
.
register
(
name
,
lazy
=
path
)
# Call registration on module import
# Call registration on module import
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment