Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
124d3049
Commit
124d3049
authored
Jul 28, 2025
by
Baber
Browse files
better placeholder materialization
parent
9af24b7e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
42 deletions
+65
-42
lm_eval/api/metrics.py
lm_eval/api/metrics.py
+2
-2
lm_eval/api/registry.py
lm_eval/api/registry.py
+63
-40
No files found.
lm_eval/api/metrics.py
View file @
124d3049
...
@@ -5,7 +5,7 @@ import random
...
@@ -5,7 +5,7 @@ import random
import
re
import
re
import
string
import
string
from
collections.abc
import
Iterable
,
Sequence
from
collections.abc
import
Iterable
,
Sequence
from
typing
import
Callable
,
List
,
Optional
,
TypeVar
from
typing
import
Callable
,
Generic
,
List
,
Optional
,
TypeVar
import
numpy
as
np
import
numpy
as
np
import
sacrebleu
import
sacrebleu
...
@@ -451,7 +451,7 @@ def _sacreformat(refs, preds):
...
@@ -451,7 +451,7 @@ def _sacreformat(refs, preds):
# stderr stuff
# stderr stuff
class
_bootstrap_internal
:
class
_bootstrap_internal
(
Generic
[
T
])
:
"""
"""
Pool worker: `(i, xs)` → `n` bootstrap replicates
Pool worker: `(i, xs)` → `n` bootstrap replicates
of `f(xs)`using a RNG seeded with `i`.
of `f(xs)`using a RNG seeded with `i`.
...
...
lm_eval/api/registry.py
View file @
124d3049
...
@@ -3,11 +3,11 @@ from __future__ import annotations
...
@@ -3,11 +3,11 @@ from __future__ import annotations
import
importlib
import
importlib
import
inspect
import
inspect
import
threading
import
threading
from
collections.abc
import
Iterable
,
Mapping
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
functools
import
lru_cache
from
types
import
MappingProxyType
from
types
import
MappingProxyType
from
typing
import
Any
,
Callable
,
Generic
,
Type
,
TypeVar
,
Union
,
cast
from
typing
import
Any
,
Callable
,
Generic
,
TypeVar
,
Union
,
cast
try
:
try
:
...
@@ -15,7 +15,6 @@ try:
...
@@ -15,7 +15,6 @@ try:
except
ImportError
:
# pragma: no cover – fallback for 3.8/3.9
except
ImportError
:
# pragma: no cover – fallback for 3.8/3.9
import
importlib_metadata
as
md
# type: ignore
import
importlib_metadata
as
md
# type: ignore
# Legacy exports (keep for one release, then drop)
LEGACY_EXPORTS
=
[
LEGACY_EXPORTS
=
[
"DEFAULT_METRIC_REGISTRY"
,
"DEFAULT_METRIC_REGISTRY"
,
"AGGREGATION_REGISTRY"
,
"AGGREGATION_REGISTRY"
,
...
@@ -52,14 +51,40 @@ __all__ = [
...
@@ -52,14 +51,40 @@ __all__ = [
"higher_is_better_registry"
,
"higher_is_better_registry"
,
"filter_registry"
,
"filter_registry"
,
"freeze_all"
,
"freeze_all"
,
# legacy
*
LEGACY_EXPORTS
,
*
LEGACY_EXPORTS
,
]
]
# type: ignore
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
Placeholder
=
Union
[
str
,
md
.
EntryPoint
]
# light‑weight lazy token
Placeholder
=
Union
[
str
,
md
.
EntryPoint
]
# light‑weight lazy token
# ────────────────────────────────────────────────────────────────────────
# Module-level cache for materializing placeholders (prevents memory leak)
# ────────────────────────────────────────────────────────────────────────
@lru_cache(maxsize=16)
def _materialise_placeholder(ph: Placeholder) -> Any:
    """Materialize a lazy placeholder into the actual object.

    This is at module level to avoid memory leaks from lru_cache on instance methods.

    Args:
        ph: Either a ``"module:object"`` string or an object exposing
            ``load()`` (e.g. an entry point). For strings, the object part
            may be dotted (``"pkg.mod:Class.attr"``); each component is
            resolved with ``getattr`` in turn.

    Returns:
        The object named by the string path, or whatever ``ph.load()`` returns.

    Raises:
        ValueError: If a string placeholder is missing its module part or
            its object part.
        ImportError: If the module cannot be imported.
        AttributeError: If a component of the object path does not exist.
    """
    if not isinstance(ph, str):
        return ph.load()

    mod, _, attr = ph.partition(":")
    # Reject both "module" (no object) and ":object" (no module) with the same
    # message, rather than letting import_module("") fail with an unrelated
    # "Empty module name" error.
    if not mod or not attr:
        raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")

    # Walk dotted attribute paths so nested objects can be addressed.
    obj = importlib.import_module(mod)
    for part in attr.split("."):
        obj = getattr(obj, part)
    return obj
# ────────────────────────────────────────────────────────────────────────
# Metric-specific metadata storage
# ────────────────────────────────────────────────────────────────────────
_metric_meta
:
dict
[
str
,
dict
[
str
,
Any
]]
=
{}
# ────────────────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────
# Generic Registry
# Generic Registry
# ────────────────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────
...
@@ -72,12 +97,11 @@ class Registry(Generic[T]):
...
@@ -72,12 +97,11 @@ class Registry(Generic[T]):
self
,
self
,
name
:
str
,
name
:
str
,
*
,
*
,
base_cls
:
Union
[
T
ype
[
T
]
,
None
]
=
None
,
base_cls
:
t
ype
[
T
]
|
None
=
None
,
)
->
None
:
)
->
None
:
self
.
_name
=
name
self
.
_name
=
name
self
.
_base_cls
=
base_cls
self
.
_base_cls
=
base_cls
self
.
_objs
:
dict
[
str
,
Union
[
T
,
Placeholder
]]
=
{}
self
.
_objs
:
dict
[
str
,
T
|
Placeholder
]
=
{}
self
.
_meta
:
dict
[
str
,
dict
[
str
,
Any
]]
=
{}
self
.
_lock
=
threading
.
RLock
()
self
.
_lock
=
threading
.
RLock
()
# ------------------------------------------------------------------
# ------------------------------------------------------------------
...
@@ -87,24 +111,22 @@ class Registry(Generic[T]):
...
@@ -87,24 +111,22 @@ class Registry(Generic[T]):
def
register
(
def
register
(
self
,
self
,
*
aliases
:
str
,
*
aliases
:
str
,
lazy
:
Union
[
T
,
Placeholder
,
None
]
=
None
,
lazy
:
T
|
Placeholder
|
None
=
None
,
metadata
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Callable
[[
T
],
T
]:
)
->
Callable
[[
T
],
T
]:
"""``@reg.register('foo')`` or ``reg.register('foo', lazy='pkg.mod:Obj')``."""
"""``@reg.register('foo')`` or ``reg.register('foo', lazy='pkg.mod:Obj')``."""
def
_store
(
alias
:
str
,
target
:
Union
[
T
,
Placeholder
]
)
->
None
:
def
_store
(
alias
:
str
,
target
:
T
|
Placeholder
)
->
None
:
current
=
self
.
_objs
.
get
(
alias
)
current
=
self
.
_objs
.
get
(
alias
)
# ─── collision handling ────────────────────────────────────
# ─── collision handling ────────────────────────────────────
if
current
is
not
None
and
current
!=
target
:
if
current
is
not
None
and
current
!=
target
:
# allow placeholder → real object upgrade
# allow placeholder → real object upgrade
if
isinstance
(
current
,
str
)
and
isinstance
(
target
,
type
):
if
isinstance
(
current
,
str
)
and
isinstance
(
target
,
type
):
mod
,
_
,
cls
=
current
.
partition
(
":"
)
#
mod, _, cls = current.partition(":")
if
current
==
f
"
{
target
.
__module__
}
:
{
target
.
__name__
}
"
:
if
current
==
f
"
{
target
.
__module__
}
:
{
target
.
__name__
}
"
:
self
.
_objs
[
alias
]
=
target
self
.
_objs
[
alias
]
=
target
self
.
_meta
[
alias
]
=
metadata
or
{}
return
return
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
_name
!
r
}
alias '
{
alias
}
' already registered ("
# noqa: B950
f
"
{
self
.
_name
!
r
}
alias '
{
alias
}
' already registered ("
f
"existing=
{
current
}
, new=
{
target
}
)"
f
"existing=
{
current
}
, new=
{
target
}
)"
)
)
# ─── type check for concrete classes ───────────────────────
# ─── type check for concrete classes ───────────────────────
...
@@ -114,8 +136,6 @@ class Registry(Generic[T]):
...
@@ -114,8 +136,6 @@ class Registry(Generic[T]):
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
to be a
{
self
.
_name
}
"
f
"
{
target
}
must inherit from
{
self
.
_base_cls
}
to be a
{
self
.
_name
}
"
)
)
self
.
_objs
[
alias
]
=
target
self
.
_objs
[
alias
]
=
target
if
metadata
:
self
.
_meta
[
alias
]
=
metadata
def
decorator
(
obj
:
T
)
->
T
:
# type: ignore[valid-type]
def
decorator
(
obj
:
T
)
->
T
:
# type: ignore[valid-type]
names
=
aliases
or
(
getattr
(
obj
,
"__name__"
,
str
(
obj
)),)
names
=
aliases
or
(
getattr
(
obj
,
"__name__"
,
str
(
obj
)),)
...
@@ -139,14 +159,9 @@ class Registry(Generic[T]):
...
@@ -139,14 +159,9 @@ class Registry(Generic[T]):
# Lookup & materialisation
# Lookup & materialisation
# ------------------------------------------------------------------
# ------------------------------------------------------------------
@
lru_cache
(
maxsize
=
256
)
def
_materialise
(
self
,
ph
:
Placeholder
)
->
T
:
def
_materialise
(
self
,
ph
:
Placeholder
)
->
T
:
if
isinstance
(
ph
,
str
):
"""Materialize a placeholder using the module-level cached function."""
mod
,
_
,
attr
=
ph
.
partition
(
":"
)
return
cast
(
T
,
_materialise_placeholder
(
ph
))
if
not
attr
:
raise
ValueError
(
f
"Invalid lazy path '
{
ph
}
', expected 'module:object'"
)
return
cast
(
T
,
getattr
(
importlib
.
import_module
(
mod
),
attr
))
return
cast
(
T
,
ph
.
load
())
def
get
(
self
,
alias
:
str
)
->
T
:
def
get
(
self
,
alias
:
str
)
->
T
:
try
:
try
:
...
@@ -162,7 +177,9 @@ class Registry(Generic[T]):
...
@@ -162,7 +177,9 @@ class Registry(Generic[T]):
fresh
=
self
.
_objs
[
alias
]
fresh
=
self
.
_objs
[
alias
]
if
isinstance
(
fresh
,
(
str
,
md
.
EntryPoint
)):
if
isinstance
(
fresh
,
(
str
,
md
.
EntryPoint
)):
concrete
=
self
.
_materialise
(
fresh
)
concrete
=
self
.
_materialise
(
fresh
)
self
.
_objs
[
alias
]
=
concrete
# Only update if not frozen (MappingProxyType)
if
not
isinstance
(
self
.
_objs
,
MappingProxyType
):
self
.
_objs
[
alias
]
=
concrete
else
:
else
:
concrete
=
fresh
# another thread did the job
concrete
=
fresh
# another thread did the job
target
=
concrete
target
=
concrete
...
@@ -178,26 +195,23 @@ class Registry(Generic[T]):
...
@@ -178,26 +195,23 @@ class Registry(Generic[T]):
# Mapping helpers
# Mapping helpers
# ------------------------------------------------------------------
# ------------------------------------------------------------------
def
__getitem__
(
self
,
alias
:
str
)
->
T
:
# noqa: DunderImplemented
def
__getitem__
(
self
,
alias
:
str
)
->
T
:
return
self
.
get
(
alias
)
return
self
.
get
(
alias
)
def
__iter__
(
self
):
# noqa: DunderImplemented
def
__iter__
(
self
):
return
iter
(
self
.
_objs
)
return
iter
(
self
.
_objs
)
def
__len__
(
self
):
# noqa: DunderImplemented
def
__len__
(
self
):
return
len
(
self
.
_objs
)
return
len
(
self
.
_objs
)
def
items
(
self
):
# noqa: DunderImplemented
def
items
(
self
):
return
self
.
_objs
.
items
()
return
self
.
_objs
.
items
()
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Utilities
# Utilities
# ------------------------------------------------------------------
# ------------------------------------------------------------------
def
metadata
(
self
,
alias
:
str
)
->
Union
[
Mapping
[
str
,
Any
],
None
]:
def
origin
(
self
,
alias
:
str
)
->
str
|
None
:
return
self
.
_meta
.
get
(
alias
)
def
origin
(
self
,
alias
:
str
)
->
Union
[
str
,
None
]:
obj
=
self
.
_objs
.
get
(
alias
)
obj
=
self
.
_objs
.
get
(
alias
)
if
isinstance
(
obj
,
(
str
,
md
.
EntryPoint
)):
if
isinstance
(
obj
,
(
str
,
md
.
EntryPoint
)):
return
None
return
None
...
@@ -211,15 +225,13 @@ class Registry(Generic[T]):
...
@@ -211,15 +225,13 @@ class Registry(Generic[T]):
def
freeze
(
self
):
def
freeze
(
self
):
with
self
.
_lock
:
with
self
.
_lock
:
self
.
_objs
=
MappingProxyType
(
dict
(
self
.
_objs
))
# type: ignore[assignment]
self
.
_objs
=
MappingProxyType
(
dict
(
self
.
_objs
))
# type: ignore[assignment]
self
.
_meta
=
MappingProxyType
(
dict
(
self
.
_meta
))
# type: ignore[assignment]
# Test helper -------------------------------------------------------------
# Test helper -------------------------------------------------------------
def
_clear
(
self
):
# pragma: no cover
def
_clear
(
self
):
# pragma: no cover
"""Erase registry (for isolated tests)."""
"""Erase registry (for isolated tests)."""
self
.
_objs
.
clear
()
self
.
_objs
.
clear
()
self
.
_meta
.
clear
()
_materialise_placeholder
.
cache_clear
()
self
.
_materialise
.
cache_clear
()
# ────────────────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────
...
@@ -232,8 +244,8 @@ class MetricSpec:
...
@@ -232,8 +244,8 @@ class MetricSpec:
compute
:
Callable
[[
Any
,
Any
],
Any
]
compute
:
Callable
[[
Any
,
Any
],
Any
]
aggregate
:
Callable
[[
Iterable
[
Any
]],
float
]
aggregate
:
Callable
[[
Iterable
[
Any
]],
float
]
higher_is_better
:
bool
=
True
higher_is_better
:
bool
=
True
output_type
:
Union
[
str
,
None
]
=
None
output_type
:
str
|
None
=
None
requires
:
Union
[
list
[
str
]
,
None
]
=
None
requires
:
list
[
str
]
|
None
=
None
# ────────────────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────
...
@@ -243,7 +255,9 @@ class MetricSpec:
...
@@ -243,7 +255,9 @@ class MetricSpec:
from
lm_eval.api.model
import
LM
# noqa: E402
from
lm_eval.api.model
import
LM
# noqa: E402
model_registry
:
Registry
[
type
[
LM
]]
=
Registry
(
"model"
,
base_cls
=
LM
)
model_registry
:
Registry
[
type
[
LM
]]
=
cast
(
Registry
[
type
[
LM
]],
Registry
(
"model"
,
base_cls
=
LM
)
)
task_registry
:
Registry
[
Callable
[...,
Any
]]
=
Registry
(
"task"
)
task_registry
:
Registry
[
Callable
[...,
Any
]]
=
Registry
(
"task"
)
metric_registry
:
Registry
[
MetricSpec
]
=
Registry
(
"metric"
)
metric_registry
:
Registry
[
MetricSpec
]
=
Registry
(
"metric"
)
metric_agg_registry
:
Registry
[
Callable
[[
Iterable
[
Any
]],
float
]]
=
Registry
(
metric_agg_registry
:
Registry
[
Callable
[[
Iterable
[
Any
]],
float
]]
=
Registry
(
...
@@ -266,6 +280,14 @@ get_filter = filter_registry.get
...
@@ -266,6 +280,14 @@ get_filter = filter_registry.get
# Metric helpers need thin wrappers to build MetricSpec ----------------------
# Metric helpers need thin wrappers to build MetricSpec ----------------------
def _no_aggregation_fn(values: Iterable[Any]) -> float:
    """Placeholder aggregator for metrics registered without an aggregation.

    The argument is ignored; calling this function always fails.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "No aggregation function specified for this metric. "
        "Please specify 'aggregation' parameter in @register_metric."
    )
    raise NotImplementedError(message)
def
register_metric
(
**
kw
):
def
register_metric
(
**
kw
):
name
=
kw
[
"metric"
]
name
=
kw
[
"metric"
]
...
@@ -275,13 +297,14 @@ def register_metric(**kw):
...
@@ -275,13 +297,14 @@ def register_metric(**kw):
aggregate
=
(
aggregate
=
(
metric_agg_registry
.
get
(
kw
[
"aggregation"
])
metric_agg_registry
.
get
(
kw
[
"aggregation"
])
if
"aggregation"
in
kw
if
"aggregation"
in
kw
else
lambda
_
:
{}
else
_no_aggregation_fn
),
),
higher_is_better
=
kw
.
get
(
"higher_is_better"
,
True
),
higher_is_better
=
kw
.
get
(
"higher_is_better"
,
True
),
output_type
=
kw
.
get
(
"output_type"
),
output_type
=
kw
.
get
(
"output_type"
),
requires
=
kw
.
get
(
"requires"
),
requires
=
kw
.
get
(
"requires"
),
)
)
metric_registry
.
register
(
name
,
lazy
=
spec
,
metadata
=
kw
)
metric_registry
.
register
(
name
,
lazy
=
spec
)
_metric_meta
[
name
]
=
kw
higher_is_better_registry
.
register
(
name
,
lazy
=
spec
.
higher_is_better
)
higher_is_better_registry
.
register
(
name
,
lazy
=
spec
.
higher_is_better
)
return
fn
return
fn
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment