gaoqiong / lm-evaluation-harness · Commits · ba1d4483

Commit ba1d4483, authored May 25, 2025 by Baber
parent 57b91fdb

refactor: streamline metric calculations and enhance logging
Showing 5 changed files with 37 additions and 35 deletions (+37 -35):
    lm_eval/api/schemas.py        +6   -0
    lm_eval/api/task.py           +16  -7
    lm_eval/evaluator.py          +8   -24
    lm_eval/evaluator_utils.py    +6   -3
    lm_eval/filters/selection.py  +1   -1
lm_eval/api/schemas.py

@@ -94,3 +94,9 @@ class MetricResult:
             for score_dict in self.scores
             if metric_key in score_dict
         ]
+
+    @property
+    def metric_keys(self) -> list[str]:
+        if self.scores is None:
+            return []
+        return list(self.scores[0].keys()) if self.scores else []
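What the new `metric_keys` property buys: callers can recover the metric names for a document without digging through the raw score dicts. A minimal sketch, assuming `MetricResult` is a dataclass whose `scores` field holds one dict per repeat; the class name and reduced field set below are stand-ins, not the full schema:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class MetricResultSketch:
        # Reduced stand-in for lm_eval.api.schemas.MetricResult: only the fields
        # this diff touches (doc_id, scores, filter_key) are modelled.
        doc_id: int
        scores: Optional[list[dict]]
        filter_key: str

        @property
        def metric_keys(self) -> list[str]:
            # Same logic as the added property: metric names come from the first
            # per-repeat score dict; None or empty scores yield an empty list.
            if self.scores is None:
                return []
            return list(self.scores[0].keys()) if self.scores else []


    result = MetricResultSketch(
        doc_id=0,
        scores=[{"exact_match": 0.0}, {"exact_match": 1.0}],
        filter_key="strict-match",
    )
    print(result.metric_keys)  # ['exact_match']
    print(MetricResultSketch(doc_id=1, scores=None, filter_key="none").metric_keys)  # []

Reading the keys off the first score dict implicitly assumes every repeat reports the same metric names.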
lm_eval/api/task.py

@@ -887,7 +887,7 @@ class ConfigurableTask(Task):
             eval_logger.debug(
                 "No custom filters defined. Using default 'take_first' filter for handling repeats."
             )
-            # self._filters = [build_filter_ensemble("none", [["take_first", None]])]
+            self._filters = [build_filter_ensemble("none", [["take_first", None]])]
         if self.config.use_prompt is not None:
             eval_logger.info(f"loading prompt {self.config.use_prompt}")

@@ -1771,13 +1771,13 @@ class ConfigurableTask(Task):
     def calculate_metrics(
         self,
-        instances_by_doc_id,
+        instances_by_doc_id=None,
         filter_keys=None,
         samples=None,
         rank=1,
         limit=None,
         world_size=1,
-    ) -> dict[str, list[dict]]:
+    ) -> Optional[dict[str, list[dict]]]:
         """Calculate metrics for all datapoints in the task.

         Args:

@@ -1791,12 +1791,23 @@ class ConfigurableTask(Task):
         Returns:
             list: A list of metrics calculated for each document.
         """
+        if not self._instances:
+            return
+        from collections import defaultdict
+
         if filter_keys is None:
-            filter_keys = [x.name for x in self._filters]
+            filter_keys = (
+                [x.name for x in self._filters] if hasattr(self, "_filters") else ["none"]
+            )
         if isinstance(filter_keys, str):
             filter_keys = [filter_keys]
+        if not instances_by_doc_id:
+            instances_by_doc_id = defaultdict(list)
+            for instance in self.instances:
+                instances_by_doc_id[instance.doc_id].append(instance)
         all_metrics = collections.defaultdict(list)
         # indices = samples.get(self.config.task, None) if samples is not None else None
         for filter_key in filter_keys:
             doc_iterator = self.doc_iterator(
                 rank=rank,

@@ -1808,7 +1819,6 @@ class ConfigurableTask(Task):
             for doc_id, doc in doc_iterator:
-                # doc_id_true = indices[doc_id] if indices else doc_id
                 requests = instances_by_doc_id[doc_id]
                 metrics = [
                     self.process_results(doc, response)
                     for req in requests

@@ -1818,7 +1828,6 @@ class ConfigurableTask(Task):
                         else [req.filtered_resps[filter_key]]
                     )
                 ]
                 all_metrics[filter_key].append(
                     MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
                 )
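Since `instances_by_doc_id` is now optional, `calculate_metrics` can rebuild the grouping itself when the caller passes nothing. A small sketch of that default path, using a hypothetical `FakeInstance` stand-in for the real request instances:

    from collections import defaultdict
    from typing import NamedTuple


    class FakeInstance(NamedTuple):
        # Hypothetical stand-in for the harness's request instances; only doc_id
        # and filtered_resps matter for the grouping shown here.
        doc_id: int
        filtered_resps: dict


    instances = [
        FakeInstance(doc_id=0, filtered_resps={"none": "A"}),
        FakeInstance(doc_id=0, filtered_resps={"none": "B"}),  # a repeat of doc 0
        FakeInstance(doc_id=1, filtered_resps={"none": "C"}),
    ]

    # The grouping calculate_metrics now performs when no instances_by_doc_id is
    # supplied: bucket every instance under its doc_id.
    instances_by_doc_id = defaultdict(list)
    for instance in instances:
        instances_by_doc_id[instance.doc_id].append(instance)

    assert len(instances_by_doc_id[0]) == 2
    assert len(instances_by_doc_id[1]) == 1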
lm_eval/evaluator.py

@@ -609,16 +609,6 @@ def evaluate(
             )
             for filter_key in task.instances[0].filtered_resps.keys():
                 if hasattr(task, "calculate_metrics"):
                     # Use the new method if it exists (ConfigurableTask)
-                    # metrics = task.calculate_metrics(
-                    #     instances_by_doc_id=instances_by_doc_id,
-                    #     filter_keys=filter_key,
-                    #     samples=samples,
-                    #     rank=RANK,
-                    #     limit=limit,
-                    #     world_size=WORLD_SIZE,
-                    # )
                     # Add sample logging here too - similar to what's done in the else branch
                     if log_samples:
                         indices = (

@@ -636,15 +626,7 @@ def evaluate(
                             doc_id_true = indices[doc_id] if indices else doc_id
                             requests = instances_by_doc_id[doc_id]
                             if requests:  # Make sure there are requests for this doc_id
                                 # Get the metrics for this document
-                                # doc_metrics = [
-                                #     task.process_results(doc, response)
-                                #     for req in requests
-                                #     for response in req.filtered_resps[filter_key]
-                                # ]
-                                # TODO: doc_metrics is flat list with floats and not clear if we have multiple emtircs
-                                doc_metrics = [y for y in metrics[filter_key][0]]
+                                doc_metrics = metrics[filter_key][doc_id_true].metric_keys
                                 target = task.doc_to_target(doc)
                                 example = {
                                     "doc_id": doc_id_true,

@@ -670,16 +652,18 @@ def evaluate(
                                     ),
                                     "target_hash": hash_string(str(target)),
                                 }
+                                example.update(
+                                    {
+                                        metrics[filter_key][doc_id_true].metric_keys[0]: metrics[filter_key][doc_id_true]
+                                    }
+                                )
                                 task_output.logged_samples.append(example)
                 # Process all metrics returned from calculate_metrics
                 for filter_key in metrics:
                     # we get a list of metric results
                     # [MetricResult(doc_id=0, scores=[{'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}], filter_key='strict-match', metric_name=None, metadata=None),
                     #  MetricResult(doc_id=1, scores=[{'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}], filter_key='strict-match', metric_name=None, metadata=None)]
                     for m_samples in metrics[filter_key]:
                         # m_samples is a MetricResult object
                         # m_samples.scores is a list of dicts
                         for metric, value in m_samples:
                             task_output.sample_metrics[(metric, filter_key)].append(value)
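The aggregation loop now consumes a `{filter_key: [MetricResult, ...]}` mapping and flattens it into `task_output.sample_metrics[(metric, filter_key)]`. The iteration protocol of `MetricResult` (`for metric, value in m_samples`) is not shown in this commit, so the sketch below approximates it by walking the per-repeat score dicts directly; the dict literals stand in for real `MetricResult` objects:

    from collections import defaultdict

    # Assumed shape of the metrics object evaluate() receives from
    # calculate_metrics(): one list of MetricResult-like entries per filter key.
    metrics = {
        "strict-match": [
            {"doc_id": 0, "scores": [{"exact_match": 0.0}, {"exact_match": 0.0}]},
            {"doc_id": 1, "scores": [{"exact_match": 1.0}, {"exact_match": 0.0}]},
        ]
    }

    sample_metrics = defaultdict(list)
    for filter_key, results in metrics.items():
        for m_samples in results:
            # The diff iterates `for metric, value in m_samples`; walking the
            # per-repeat score dicts is an approximation of that protocol.
            for score_dict in m_samples["scores"]:
                for metric, value in score_dict.items():
                    sample_metrics[(metric, filter_key)].append(value)

    print(dict(sample_metrics))
    # {('exact_match', 'strict-match'): [0.0, 0.0, 1.0, 0.0]}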
lm_eval/evaluator_utils.py

@@ -128,9 +128,12 @@ class TaskOutput:
                     if metric in ["bleu", "chrf", "ter"]
                     else bootstrap_iters,
                 )
-                self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
-                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
-                )
+                # TODO: what's the best way to calculate repeat stderr
+                # maybe mean/sample then bootstrap?
+                self.agg_metrics[f"{metric}_stderr,{filter_key}"] = [
+                    (stderr_fn(item) if (stderr_fn and len(item) > 1) else "N/A")
+                    for item in zip(*items)
+                ][0]
             else:
                 raise ValueError(
                     f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
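The reworked stderr line appears to treat `items` as one entry per document, each holding that document's per-repeat scores, and regroups them by repeat index with `zip(*items)` before keeping only the first repeat's stderr. A sketch under that assumption, with a plain standard error of the mean standing in for the real `stderr_fn`:

    import statistics


    def stderr_fn(values):
        # Plain standard error of the mean; the real code plugs in a bootstrap or
        # metric-specific stderr function here.
        return statistics.stdev(values) / len(values) ** 0.5


    # Assumed shape: one entry per document, holding that document's repeat scores.
    items = [
        [0.0, 1.0, 1.0],  # doc 0, three repeats
        [1.0, 1.0, 0.0],  # doc 1
        [0.0, 0.0, 1.0],  # doc 2
    ]

    # zip(*items) regroups the scores by repeat index, so each `item` below is
    # "what every document scored on repeat i".
    per_repeat_stderr = [
        (stderr_fn(item) if (stderr_fn and len(item) > 1) else "N/A")
        for item in zip(*items)
    ]

    # The diff keeps only the first repeat's stderr ([0]) while the TODO about a
    # better aggregation across repeats remains open.
    print(per_repeat_stderr[0])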
lm_eval/filters/selection.py

@@ -20,7 +20,7 @@ class TakeFirstFilter(Filter):
         """
         Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
         """
-        return map(lambda r: r[0], resps)
+        return map(lambda r: r, resps)


 @register_filter("take_first_k")
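The one-character change to `TakeFirstFilter` turns the mapping into an identity, so every repeat now flows through to the metric code (note the class docstring above still describes the old discard-all-but-first behaviour). A before/after sketch on toy responses:

    # Three documents, each with two model responses (repeats).
    resps = [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]]

    # Old behaviour: keep only the first response per document.
    print(list(map(lambda r: r[0], resps)))  # ['a1', 'b1', 'c1']

    # New behaviour: the lambda is an identity, so all repeats survive and
    # downstream metric code sees the full response list per document.
    print(list(map(lambda r: r, resps)))  # [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]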