gaoqiong / lm-evaluation-harness · commit 28c78d30
authored Jun 30, 2025 by Baber
parent de496b80

add MetricConfig
Showing 10 changed files with 223 additions and 156 deletions:

    lm_eval/__main__.py           +4    -0
    lm_eval/api/group.py          +1    -2
    lm_eval/api/instance.py       +17   -4
    lm_eval/api/metrics.py        +6    -1
    lm_eval/api/registry.py       +13   -10
    lm_eval/api/samplers.py       +3    -3
    lm_eval/api/task.py           +174  -131
    lm_eval/evaluator.py          +1    -1
    lm_eval/evaluator_utils.py    +2    -2
    lm_eval/filters/__init__.py   +2    -2
lm_eval/__main__.py

@@ -485,6 +485,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     if results is not None:
         if args.log_samples:
             samples = results.pop("samples")
+        # TODO: fix this!
+        results["higher_is_better"] = {
+            k: True for k, v in results["higher_is_better"].items()
+        }
         dumped = json.dumps(
             results, indent=2, default=handle_non_serializable, ensure_ascii=False
         )
...
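Note: the four added lines (carrying the author's own TODO) coerce every entry of results["higher_is_better"] to True before the results are dumped to JSON. A standalone sketch of that coercion, on a made-up results dict:

    import json

    # Hypothetical results payload mimicking the shape the harness produces.
    results = {"higher_is_better": {"acc": True, "ter": False}}

    # Same coercion as the added lines in cli_evaluate: every metric is forced to True.
    results["higher_is_better"] = {k: True for k, v in results["higher_is_better"].items()}

    print(json.dumps(results, indent=2))  # both values now serialize as true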
lm_eval/api/group.py

-import abc
 from dataclasses import asdict, dataclass
 from inspect import getsource
 from typing import Any, Callable, List, Optional, Union
...
@@ -84,7 +83,7 @@ class GroupConfig(dict):
         return str(value)


-class ConfigurableGroup(abc.ABC):
+class ConfigurableGroup:
     def __init__(
         self,
         config: Optional[dict] = None,
...
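Note: dropping the abc.ABC base is behaviour-preserving because ConfigurableGroup declares no @abstractmethod; an ABC without abstract methods is freely instantiable, so the base class was inert. A minimal demonstration (class names are illustrative):

    import abc

    class WithBase(abc.ABC):  # no @abstractmethod anywhere
        pass

    class WithoutBase:  # equivalent after the change
        pass

    # ABCs only block instantiation when abstract methods exist, so both succeed:
    print(WithBase(), WithoutBase())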
lm_eval/api/instance.py

@@ -14,10 +14,23 @@ class Instance:
     arguments: tuple
     idx: int
     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
-        default_factory=lambda: (None, None, None)
+        default_factory=lambda: (None, None, None),
+        metadata=dict(
+            description="Metadata tuple containing task name, document ID, and number of repeats."
+        ),
+    )
+    resps: list = field(
+        default_factory=list,
+        metadata=dict(description="List of responses from the model for this instance."),
+    )
+    filtered_resps: dict = field(
+        default_factory=dict,
+        metadata=dict(
+            description="List of filtered responses for this instance, keyed by filter name."
+        ),
     )
-    resps: list = field(default_factory=list)
-    filtered_resps: dict = field(default_factory=dict)

     # initialized after init
     task_name: Optional[str] = None
...
@@ -29,7 +42,7 @@ class Instance:
         self.task_name, self.doc_id, self.repeats = self.metadata

     @property
-    def args(self):
+    def args(self) -> tuple:
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
...
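Note: the metadata=dict(description=...) arguments use the stdlib dataclass field-metadata mechanism: dataclasses stores the mapping verbatim without interpreting it, and tooling can read it back via dataclasses.fields(). A minimal sketch with an illustrative class:

    from dataclasses import dataclass, field, fields

    @dataclass
    class Demo:
        resps: list = field(
            default_factory=list,
            metadata=dict(description="List of responses from the model."),
        )

    # Field metadata is inert at runtime but retrievable for docs/tooling:
    for f in fields(Demo):
        print(f.name, "->", f.metadata.get("description"))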
lm_eval/api/metrics.py

@@ -8,7 +8,6 @@ from collections.abc import Iterable
 from typing import Callable, List, Optional, Sequence, TypeVar

 import numpy as np
-import sacrebleu

 from lm_eval.api.registry import register_aggregation, register_metric
...
@@ -92,6 +91,8 @@ def bleu(items):
     Higher is better
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
...
@@ -107,6 +108,8 @@ def chrf(items):
     Higher is better # TODO I think
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
...
@@ -123,6 +126,8 @@ def ter(items):
     Lower is better
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
...
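Note: moving import sacrebleu from module level into bleu, chrf, and ter defers the dependency until one of those metrics is actually computed, so importing lm_eval.api.metrics no longer requires sacrebleu to be installed. The pattern in isolation (function name is illustrative; sacrebleu is needed only at call time):

    def bleu_score(refs, preds):
        # Deferred import: sacrebleu is only needed when BLEU is computed.
        import sacrebleu

        # corpus_bleu takes a list of hypotheses and a list of reference streams.
        return sacrebleu.corpus_bleu(preds, refs).score

    # One reference stream covering both predictions:
    print(bleu_score([["the cat sat", "hello world"]], ["the cat sat", "hello world"]))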
lm_eval/api/registry.py

 import logging
-from typing import Callable, Dict, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

-import evaluate as hf_evaluate

-from lm_eval.api.model import LM
+if TYPE_CHECKING:
+    from lm_eval.api.model import LM

 eval_logger = logging.getLogger(__name__)
...
@@ -12,6 +11,8 @@ MODEL_REGISTRY = {}


 def register_model(*names):
+    from lm_eval.api.model import LM
+
     # either pass a list or a single alias.
     # function receives them as a tuple of strings
...
@@ -31,7 +32,7 @@ def register_model(*names):
     return decorate


-def get_model(model_name):
+def get_model(model_name: str) -> type["LM"]:
     try:
         return MODEL_REGISTRY[model_name]
     except KeyError:
...
@@ -46,7 +47,7 @@ ALL_TASKS = set()
 func2task_index = {}


-def register_task(name):
+def register_task(name: str):
     def decorate(fn):
         assert name not in TASK_REGISTRY, (
             f"task named '{name}' conflicts with existing registered task!"
...
@@ -120,7 +121,7 @@ def register_metric(**args):
     return decorate


-def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
+def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
     if not hf_evaluate_metric:
         if name in METRIC_REGISTRY:
             return METRIC_REGISTRY[name]
...
@@ -130,6 +131,8 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
         )

     try:
+        import evaluate as hf_evaluate
+
         metric_object = hf_evaluate.load(name)
         return metric_object.compute
     except Exception:
...
@@ -150,21 +153,21 @@ def register_aggregation(name: str):
     return decorate


-def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
     try:
         return AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} not a registered aggregation metric!")


-def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
     try:
         return METRIC_AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} metric is not assigned a default aggregation!")


-def is_higher_better(metric_name) -> bool:
+def is_higher_better(metric_name) -> Optional[bool]:
     try:
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
...
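Note: the getters now advertise that an unknown name produces a warning and None rather than a guaranteed value, so callers should guard the result. A self-contained usage sketch mirroring the new contract (registry contents are illustrative):

    from typing import Callable, Dict, Optional

    AGGREGATION_REGISTRY: Dict[str, Callable] = {"mean": lambda xs: sum(xs) / len(xs)}

    def get_aggregation(name: str) -> Optional[Callable]:
        try:
            return AGGREGATION_REGISTRY[name]
        except KeyError:
            print(f"{name} not a registered aggregation metric!")  # falls through to None

    agg = get_aggregation("mean")
    if agg is not None:  # the None case must now be handled explicitly
        print(agg([1, 2, 3]))  # 2.0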
lm_eval/api/samplers.py

 import logging
 import warnings
 from functools import partial
-from typing import TYPE_CHECKING, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union

 import datasets
...
@@ -181,7 +181,7 @@ class ContextSampler:
         return chat_history

-    def sample(self, n: int):
+    def sample(self, n: int) -> Sequence[dict]:
         """
         Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
         """
...
@@ -190,7 +190,7 @@ class ContextSampler:


 class FirstNSampler(ContextSampler):
-    def sample(self, n: int) -> None:
+    def sample(self, n: int) -> Sequence[dict]:
         """
         Draw the first `n` samples in order from the specified split.
         Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
         """
...
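Note: with both overrides annotated -> Sequence[dict] (the old -> None on FirstNSampler.sample was simply wrong), the contract is that a sampler returns the drawn fewshot documents themselves. A minimal stand-in consistent with the docstring (the class and its docs attribute are assumptions, not the harness's code):

    from typing import Sequence

    class FirstNSamplerSketch:
        def __init__(self, docs: list):
            self.docs = docs

        def sample(self, n: int) -> Sequence[dict]:
            # Deterministic, order-preserving "first n" draw.
            assert n <= len(self.docs), f"requested {n} docs, only have {len(self.docs)}"
            return self.docs[:n]

    print(FirstNSamplerSketch([{"q": "a"}, {"q": "b"}]).sample(1))  # [{'q': 'a'}]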
lm_eval/api/task.py

(diff collapsed in this view: +174, -131. Given the commit message, this is presumably where MetricConfig itself is introduced.)
lm_eval/evaluator.py

@@ -287,7 +287,7 @@ def simple_evaluate(
     # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
     # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
-    def _adjust_config(task_dict):
+    def _adjust_config(task_dict: dict[str, "Task"]) -> dict[str, "Task"]:
         adjusted_task_dict = {}
         for task_name, task_obj in task_dict.items():
             if isinstance(task_obj, dict):
...
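Note: dict[str, "Task"] uses a string forward reference, so Task need not be imported at runtime; paired with a TYPE_CHECKING-guarded import (the same pattern registry.py gains in this commit) the annotation carries no import cost. The pattern in isolation:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by static type checkers, never executed at runtime.
        from lm_eval.api.task import Task

    def _adjust_config(task_dict: dict[str, "Task"]) -> dict[str, "Task"]:
        # Identity passthrough; the real helper rewrites leaf subtask configs.
        return dict(task_dict)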
lm_eval/evaluator_utils.py

@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
     pooled_sample_stderr,
     stderr_for_metric,
 )
-from lm_eval.api.task import Task
+from lm_eval.api.task import ConfigurableTask, Task
 from lm_eval.utils import positional_deprecated
...
@@ -58,7 +58,7 @@ class TaskOutput:
         group_alias=None,
         is_group=None,
     ):
-        self.task = task
+        self.task: Union[Task, ConfigurableTask] = task
         self.task_config = task_config
         self.task_name = task_name
         self.group_name = group_name
...
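Note: self.task: Union[Task, ConfigurableTask] = task is an annotated assignment; the annotation is not stored anywhere at runtime and exists purely for static checkers, so behaviour is unchanged. In isolation (toy classes):

    from typing import Union

    class Task: ...
    class ConfigurableTask(Task): ...

    class TaskOutput:
        def __init__(self, task):
            # Identical to a plain assignment at runtime; the Union only
            # informs type checkers about what this attribute may hold.
            self.task: Union[Task, ConfigurableTask] = task

    print(type(TaskOutput(ConfigurableTask()).task).__name__)  # ConfigurableTask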
lm_eval/filters/__init__.py

 from functools import partial
-from typing import List
+from typing import List, Union

 from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.registry import get_filter
...
@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation


 def build_filter_ensemble(
-    filter_name: str, components: List[List[str]]
+    filter_name: str, components: list[Union[list[dict], list[str]]]
 ) -> FilterEnsemble:
     """
     Create a filtering pipeline.
     """
...
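Note: the widened annotation says each component entry is itself a list holding either strings (e.g. filter names) or dicts (e.g. keyword arguments); the diff does not show how callers build these entries, so the shapes below are illustrative only:

    from typing import Union

    def describe(components: list[Union[list[dict], list[str]]]) -> None:
        # Echo which element types each entry carries under the new annotation.
        for entry in components:
            print(entry, "->", {type(x).__name__ for x in entry})

    describe([["take_first"], [{"regex_pattern": r"(\d+)"}]])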