gaoqiong / lm-evaluation-harness

Commit 57b91fdb, authored May 24, 2025 by Baber (parent 911cae22)

refactor: enhance metric handling and aggregation logic
Showing 5 changed files with 70 additions and 21 deletions (+70 −21):
- lm_eval/api/schemas.py (+40 −0)
- lm_eval/api/task.py (+5 −2)
- lm_eval/evaluator.py (+15 −15)
- lm_eval/evaluator_utils.py (+9 −1)
- lm_eval/tasks/gsm8k/gsm8k.yaml (+1 −3)
lm_eval/api/schemas.py

@@ -19,6 +19,9 @@ class GenerateInput:
             else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
         )
 
+    def __getitem__(self, item: int):
+        return [self.prompt, self.gen_kwargs][item]
+
 
 @dataclass
 class GenerateOutput:
@@ -54,3 +57,40 @@ class LoglikelihoodOutput:
 
     def __iter__(self):
         return iter((self.loglikelihood, self.is_greedy))
+
+
+@dataclass
+class MetricResult:
+    """
+    Outputs for the metric function.
+    """
+
+    doc_id: str | int | None
+    scores: list[dict[str, float]] | None
+    filter_key: str = None
+    metric_name: str = None
+    metadata: Optional[dict] = None
+
+    def __iter__(self):
+        if self.scores is None:
+            return iter([])
+        # Group values by metric key
+        grouped = {}
+        for score_dict in self.scores:
+            for key, value in score_dict.items():
+                if key not in grouped:
+                    grouped[key] = []
+                grouped[key].append(value)
+        # Return iterator of (key, list[values]) pairs
+        return iter(grouped.items())
+
+    def get_metric_results(self, metric_key) -> list[float]:
+        if self.scores is None:
+            return []
+        return [
+            score_dict[metric_key]
+            for score_dict in self.scores
+            if metric_key in score_dict
+        ]
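For context, a minimal sketch (not part of the commit) of how the new `MetricResult` behaves: iterating one yields `(metric_name, values)` pairs, with per-repeat scores grouped by metric key.

```python
# Illustrative only: exercising the MetricResult dataclass added above.
from lm_eval.api.schemas import MetricResult

result = MetricResult(
    doc_id=0,
    scores=[{"exact_match": 0.0}, {"exact_match": 1.0}, {"exact_match": 0.0}],
    filter_key="strict-match",
)

# __iter__ groups values by metric key across repeats.
for metric, values in result:
    print(metric, values)  # exact_match [0.0, 1.0, 0.0]

# get_metric_results extracts one metric's values directly.
print(result.get_metric_results("exact_match"))  # [0.0, 1.0, 0.0]
```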
lm_eval/api/task.py

@@ -37,7 +37,7 @@ from lm_eval.api.registry import (
     get_metric_aggregation,
     is_higher_better,
 )
-from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput
+from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput, MetricResult
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt
@@ -98,6 +98,7 @@ class TaskConfig(dict):
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
     gen_prefix: Optional[str] = None
+    repeat_agg: Optional[str] = None
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )
@@ -1818,7 +1819,9 @@ class ConfigurableTask(Task):
                 )
             ]
-            all_metrics[filter_key].append(metrics)
+            all_metrics[filter_key].append(
+                MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
+            )
         return all_metrics
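To illustrate the `ConfigurableTask` change: `all_metrics` now collects `MetricResult` objects per filter key rather than bare lists of score dicts. A hypothetical shape (values invented) might look like this:

```python
# Hypothetical data, for illustration only: the shape of all_metrics
# after this change, with one MetricResult per document and one score
# dict per repeat.
from lm_eval.api.schemas import MetricResult

all_metrics = {
    "strict-match": [
        MetricResult(
            doc_id=0,
            scores=[{"exact_match": 0.0}, {"exact_match": 1.0}, {"exact_match": 0.0}],
            filter_key="strict-match",
        ),
        MetricResult(
            doc_id=1,
            scores=[{"exact_match": 1.0}, {"exact_match": 1.0}, {"exact_match": 1.0}],
            filter_key="strict-match",
        ),
    ]
}
```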
lm_eval/evaluator.py

@@ -637,11 +637,13 @@ def evaluate(
                 requests = instances_by_doc_id[doc_id]
                 if requests:  # Make sure there are requests for this doc_id
                     # Get the metrics for this document
-                    doc_metrics = [
-                        task.process_results(doc, response)
-                        for req in requests
-                        for response in req.filtered_resps[filter_key]
-                    ]
+                    # doc_metrics = [
+                    #     task.process_results(doc, response)
+                    #     for req in requests
+                    #     for response in req.filtered_resps[filter_key]
+                    # ]
+                    # TODO: doc_metrics is a flat list of floats; unclear whether we have multiple metrics
+                    doc_metrics = [y for y in metrics[filter_key][0]]
                     target = task.doc_to_target(doc)
                     example = {
@@ -672,18 +674,16 @@ def evaluate(
             # Process all metrics returned from calculate_metrics
             for filter_key in metrics:
-                for sample_metric in metrics[filter_key]:
-                    for metric_key, value in sample_metric:
-                        task_output.sample_metrics[(metric_key, filter_key)].append(
+                # we get a list of metric results, e.g.
+                # [MetricResult(doc_id=0, scores=[{'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}], filter_key='strict-match', metric_name=None, metadata=None),
+                #  MetricResult(doc_id=1, scores=[{'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}], filter_key='strict-match', metric_name=None, metadata=None)]
+                for m_samples in metrics[filter_key]:
+                    # m_samples is a MetricResult object
+                    # m_samples.scores is a list of dicts
+                    for metric, value in m_samples:
+                        task_output.sample_metrics[(metric, filter_key)].append(
                             value
                         )
-            # metrics is a list of dictionaries, each containing metric names and their values
-            # e.g., [{"accuracy": 0.9}, {"f1": 0.8}]
-            # We need to iterate through each dictionary and extract the metric names and values
-            # for x in metrics:
-            #     for metric, value in x.items():
-            #         task_output.sample_metrics[(metric, filter_key)].append(value)
         else:
             # Fall back to the original approach for non-ConfigurableTask instances
             indices = (
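A minimal sketch (names and values invented) of what the new loop does: each `MetricResult` yields `(metric_name, per_repeat_values)` pairs, so `sample_metrics` now accumulates one list of repeat values per document.

```python
# Illustrative only: flattening MetricResult objects into per-sample
# metric lists, mirroring the loop in evaluate() above.
from collections import defaultdict

from lm_eval.api.schemas import MetricResult

metrics = {
    "strict-match": [
        MetricResult(doc_id=0, scores=[{"exact_match": 0.0}, {"exact_match": 1.0}], filter_key="strict-match"),
        MetricResult(doc_id=1, scores=[{"exact_match": 1.0}, {"exact_match": 1.0}], filter_key="strict-match"),
    ]
}

sample_metrics = defaultdict(list)
for filter_key, results in metrics.items():
    for m_samples in results:
        # Iterating a MetricResult yields (metric_name, list_of_repeat_values).
        for metric, value in m_samples:
            sample_metrics[(metric, filter_key)].append(value)

print(dict(sample_metrics))
# {('exact_match', 'strict-match'): [[0.0, 1.0], [1.0, 1.0]]}
```

This is also why `items` in `evaluator_utils.py` below can now be a `list[list[float]]` rather than a flat `list[float]`.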
lm_eval/evaluator_utils.py

@@ -111,7 +111,15 @@ class TaskOutput:
                 # TODO: Handle this better and allow other aggregate functions other than mean.
                 agg_fn = mean
             metric_key = f"{metric},{filter_key}"
-            self.agg_metrics[metric_key] = agg_fn(items)
+            # Handle multiple repeats: items is now list[list[float]]
+            if items and isinstance(items[0], list):
+                # Apply aggregation function to each repeat
+                self.agg_metrics[metric_key] = [
+                    agg_fn(repeat) for repeat in zip(*items)
+                ]
+            else:
+                # Backward compatibility: items is list[float]
+                self.agg_metrics[metric_key] = agg_fn(items)
             self.sample_len = len(items)  # TODO: same sample size for each metric?
             if isinstance(bootstrap_iters, int):
                 stderr_fn = stderr_for_metric(
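A small sketch of the per-repeat aggregation above: `zip(*items)` transposes the per-document score lists into per-repeat columns, so the aggregation function is applied once per repeat.

```python
# Illustrative only: per-repeat aggregation via zip(*items).
from statistics import mean

# Two documents, three repeats each.
items = [[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]

# Column i collects repeat i across all documents.
per_repeat = [mean(repeat) for repeat in zip(*items)]
print(per_repeat)  # [0.5, 1.0, 0.0]
```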
lm_eval/tasks/gsm8k/gsm8k.yaml

@@ -27,19 +27,17 @@ generation_kwargs:
     - "<|im_end|>"
-  do_sample: false
-  temperature: 0.0
-repeats: 1
+repeats: 3
 num_fewshot: 5
 filter_list:
   - name: "strict-match"
     filter:
       - function: "regex"
         regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
       - function: "take_first"
   - name: "flexible-extract"
     filter:
       - function: "regex"
         group_select: -1
         regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
       - function: "take_first"
 metadata:
   version: 3.0
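As a quick sanity check (illustrative only; assumes the file parses with no custom YAML tags), the updated config can be read back to confirm the repeat count that drives the per-repeat aggregation above:

```python
# Illustrative only: confirm the new repeats value in the task config.
import yaml

with open("lm_eval/tasks/gsm8k/gsm8k.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["repeats"])      # 3
print(cfg["num_fewshot"])  # 5
```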