gaoqiong / lm-evaluation-harness · Commits

Commit e30978c7
Authored May 19, 2025 by Baber
Parent: e9eb451e

add metric calculation method to configurable task
Showing 2 changed files with 157 additions and 48 deletions (+157 −48):

lm_eval/api/task.py    +40  −0
lm_eval/evaluator.py   +117 −48
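In brief: the commit adds a calculate_metrics method to ConfigurableTask in lm_eval/api/task.py, and changes evaluate() in lm_eval/evaluator.py to prefer that method via a hasattr check, keeping the previous inline per-document loop as a fallback for tasks that lack it. A minimal sketch of that dispatch pattern follows; the class names and the collect_metrics helper are illustrative, not part of the commit:

# Sketch of the hasattr-based dispatch this commit adds to evaluate().
# LegacyTask / ConfigurableLikeTask / collect_metrics are hypothetical.
class LegacyTask:
    def process_results(self, doc, response):
        # One metric dict per (doc, response) pair.
        return {"acc": float(response == doc.get("answer"))}


class ConfigurableLikeTask(LegacyTask):
    def calculate_metrics(self, docs_and_responses):
        # Centralized path: score everything in one call, return a list.
        return [self.process_results(doc, resp) for doc, resp in docs_and_responses]


def collect_metrics(task, docs_and_responses):
    if hasattr(task, "calculate_metrics"):
        # New path (ConfigurableTask in the real code)
        return task.calculate_metrics(docs_and_responses)
    # Fallback path for tasks without the new method
    return [task.process_results(doc, resp) for doc, resp in docs_and_responses]


data = [({"answer": "yes"}, "yes"), ({"answer": "no"}, "yes")]
print(collect_metrics(ConfigurableLikeTask(), data))  # [{'acc': 1.0}, {'acc': 0.0}]
print(collect_metrics(LegacyTask(), data))            # same result via fallback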
lm_eval/api/task.py
@@ -1757,6 +1757,46 @@ class ConfigurableTask(Task):
                 f"num_samples={len(self.eval_docs)})"
             )
 
+    def calculate_metrics(
+        self,
+        instances_by_doc_id,
+        filter_key,
+        samples,
+        rank,
+        limit,
+        world_size,
+    ):
+        """Calculate metrics for all datapoints in the task.
+
+        Args:
+            instances_by_doc_id (dict): Dictionary mapping doc_ids to lists of instances.
+            filter_key (str): The filter key to use for filtered responses.
+            samples (dict, optional): Dictionary of sample indices to evaluate.
+            rank (int): The process rank.
+            limit (int, optional): Limit on number of examples to evaluate.
+            world_size (int): Total number of processes.
+
+        Returns:
+            list: A list of metrics calculated for each document.
+        """
+        all_metrics = []
+        # indices = samples.get(self.config.task, None) if samples is not None else None
+        doc_iterator = self.doc_iterator(
+            rank=rank,
+            limit=limit,
+            world_size=world_size,
+            # samples=indices,
+        )
+        for doc_id, doc in doc_iterator:
+            # doc_id_true = indices[doc_id] if indices else doc_id
+            requests = instances_by_doc_id[doc_id]
+            metrics = [
+                self.process_results(doc, response)
+                for req in requests
+                for response in req.filtered_resps[filter_key]
+            ]
+            all_metrics.extend(metrics)
+        return all_metrics
 
 
 class MultipleChoiceTask(Task):
     OUTPUT_TYPE = "loglikelihood"
lm_eval/evaluator.py

@@ -596,56 +596,125 @@ def evaluate(
             instances.sort(key=lambda x: x.idx)
         # iterate over different filters used
         for filter_key in task.instances[0].filtered_resps.keys():
-            indices = (
-                samples.get(task_output.task_name, None)
-                if samples is not None
-                else None
-            )
-            doc_iterator = task.doc_iterator(
-                rank=RANK,
-                limit=limit,
-                world_size=WORLD_SIZE,
-                samples=indices,
-            )
-            for doc_id, doc in doc_iterator:
-                if indices:
-                    doc_id_true = indices[doc_id]
-                else:
-                    doc_id_true = doc_id
-                requests = instances_by_doc_id[doc_id]
-                metrics: list[dict] = [
-                    task.process_results(doc, response)
-                    for req in requests
-                    for response in req.filtered_resps[filter_key]
-                ]
-                if log_samples:
-                    target = task.doc_to_target(doc)
-                    example = {
-                        "doc_id": doc_id_true,
-                        "doc": doc,
-                        "target": target,
-                        "arguments": [req.args for req in requests],
-                        "resps": [req.resps for req in requests],
-                        "filtered_resps": [
-                            req.filtered_resps[filter_key] for req in requests
-                        ],
-                        "filter": filter_key,
-                        "metrics": list({k for m in metrics for k in m.keys()}),
-                        "doc_hash": hash_string(
-                            json.dumps(
-                                requests[0].doc,
-                                indent=2,
-                                default=handle_non_serializable,
-                                ensure_ascii=False,
-                            )
-                        ),
-                        "prompt_hash": hash_string(requests[0].arguments[0]),
-                        "target_hash": hash_string(str(target)),
-                    }
-                    example.update(metrics)
-                    task_output.logged_samples.append(example)
-                for metric, value in metrics.items():
-                    task_output.sample_metrics[(metric, filter_key)].append(value)
+            if hasattr(task, "calculate_metrics"):
+                # Use the new method if it exists (ConfigurableTask)
+                metrics = task.calculate_metrics(
+                    instances_by_doc_id=instances_by_doc_id,
+                    filter_key=filter_key,
+                    samples=samples,
+                    rank=RANK,
+                    limit=limit,
+                    world_size=WORLD_SIZE,
+                )
+                # Add sample logging here too - similar to what's done in the else branch
+                if log_samples:
+                    indices = (
+                        samples.get(task_output.task_name, None)
+                        if samples is not None
+                        else None
+                    )
+                    doc_iterator = task.doc_iterator(
+                        rank=RANK,
+                        limit=limit,
+                        world_size=WORLD_SIZE,
+                        samples=indices,
+                    )
+                    for doc_id, doc in doc_iterator:
+                        doc_id_true = indices[doc_id] if indices else doc_id
+                        requests = instances_by_doc_id[doc_id]
+                        if requests:  # Make sure there are requests for this doc_id
+                            # Get the metrics for this document
+                            doc_metrics = [
+                                task.process_results(doc, response)
+                                for req in requests
+                                for response in req.filtered_resps[filter_key]
+                            ]
+                            target = task.doc_to_target(doc)
+                            example = {
+                                "doc_id": doc_id_true,
+                                "doc": doc,
+                                "target": target,
+                                "arguments": [req.args for req in requests],
+                                "resps": [req.resps for req in requests],
+                                "filtered_resps": [
+                                    req.filtered_resps[filter_key] for req in requests
+                                ],
+                                "filter": filter_key,
+                                "metrics": doc_metrics,
+                                "doc_hash": hash_string(
+                                    json.dumps(
+                                        requests[0].doc,
+                                        indent=2,
+                                        default=handle_non_serializable,
+                                        ensure_ascii=False,
+                                    )
+                                ),
+                                "prompt_hash": hash_string(requests[0].arguments[0]),
+                                "target_hash": hash_string(str(target)),
+                            }
+                            task_output.logged_samples.append(example)
+                # Process all metrics returned from calculate_metrics
+                for x in metrics:
+                    for metric, value in x.items():
+                        task_output.sample_metrics[(metric, filter_key)].append(value)
+            else:
+                # Fall back to the original approach for non-ConfigurableTask instances
+                indices = (
+                    samples.get(task_output.task_name, None)
+                    if samples is not None
+                    else None
+                )
+                doc_iterator = task.doc_iterator(
+                    rank=RANK,
+                    limit=limit,
+                    world_size=WORLD_SIZE,
+                    samples=indices,
+                )
+                for doc_id, doc in doc_iterator:
+                    if indices:
+                        doc_id_true = indices[doc_id]
+                    else:
+                        doc_id_true = doc_id
+                    requests = instances_by_doc_id[doc_id]
+                    metrics: list[dict] = [
+                        task.process_results(doc, response)
+                        for req in requests
+                        for response in req.filtered_resps[filter_key]
+                    ]
+                    if log_samples:
+                        target = task.doc_to_target(doc)
+                        example = {
+                            "doc_id": doc_id_true,
+                            "doc": doc,
+                            "target": target,
+                            "arguments": [req.args for req in requests],
+                            "resps": [req.resps for req in requests],
+                            "filtered_resps": [
+                                req.filtered_resps[filter_key] for req in requests
+                            ],
+                            "filter": filter_key,
+                            "metrics": metrics,
+                            "doc_hash": hash_string(
+                                json.dumps(
+                                    requests[0].doc,
+                                    indent=2,
+                                    default=handle_non_serializable,
+                                    ensure_ascii=False,
+                                )
+                            ),
+                            "prompt_hash": hash_string(requests[0].arguments[0]),
+                            "target_hash": hash_string(str(target)),
+                        }
+                        example.update({"metrics": metrics})
+                        task_output.logged_samples.append(example)
+                    for x in metrics:
+                        for metric, value in x.items():
+                            task_output.sample_metrics[(metric, filter_key)].append(value)
 
     if WORLD_SIZE > 1:
         # if multigpu, then gather data across all ranks to rank 0
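Both branches above funnel per-response metric dicts into task_output.sample_metrics, keyed by (metric, filter_key). A minimal sketch of that accumulation with toy values; the mean shown at the end is only one possible reduction, since the actual aggregation step lives outside this hunk and depends on the metric's configuration:

from collections import defaultdict

# Sketch of the dict-of-lists accumulation used in both branches above.
sample_metrics = defaultdict(list)
filter_key = "none"
metrics = [{"exact_match": 1.0}, {"exact_match": 0.0}, {"exact_match": 1.0}]

for x in metrics:
    for metric, value in x.items():
        sample_metrics[(metric, filter_key)].append(value)

values = sample_metrics[("exact_match", "none")]
print(sum(values) / len(values))  # 0.666..., if the metric aggregates by mean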