gaoqiong / lm-evaluation-harness / Commits / 911cae22

Commit 911cae22, authored May 23, 2025 by Baber
Commit message: TODO!
Parent: 69e95b87
Showing 2 changed files with 66 additions and 34 deletions.

lm_eval/api/task.py   +35 −21
lm_eval/evaluator.py  +31 −13
lm_eval/api/task.py (view file @ 911cae22)
 import abc
 import ast
+import collections
 import logging
 import random
 import re
...
@@ -1768,8 +1769,14 @@ class ConfigurableTask(Task):
         )

     def calculate_metrics(
-        self, instances_by_doc_id, filter_key, samples, rank, limit, world_size
-    ) -> list[list[dict]]:
+        self,
+        instances_by_doc_id,
+        filter_keys=None,
+        samples=None,
+        rank=1,
+        limit=None,
+        world_size=1,
+    ) -> dict[str, list[dict]]:
"""Calculate metrics for all datapoints in the task.
"""Calculate metrics for all datapoints in the task.
Args:
Args:
...
@@ -1783,28 +1790,35 @@ class ConfigurableTask(Task):
...
@@ -1783,28 +1790,35 @@ class ConfigurableTask(Task):
Returns:
Returns:
list: A list of metrics calculated for each document.
list: A list of metrics calculated for each document.
"""
"""
-        all_metrics = []
+        if filter_keys is None:
+            filter_keys = [x.name for x in self._filters]
+        if isinstance(filter_keys, str):
+            filter_keys = [filter_keys]
+        all_metrics = collections.defaultdict(list)
         # indices = samples.get(self.config.task, None) if samples is not None else None
-        doc_iterator = self.doc_iterator(
-            rank=rank,
-            limit=limit,
-            world_size=world_size,
-            # samples=indices,
-        )
-        for doc_id, doc in doc_iterator:
-            # doc_id_true = indices[doc_id] if indices else doc_id
-            requests = instances_by_doc_id[doc_id]
-            metrics = [
-                self.process_results(doc, response)
-                for req in requests
-                for response in (
-                    req.filtered_resps[filter_key]
-                    if isinstance(req.filtered_resps[filter_key], list)
-                    else [req.filtered_resps[filter_key]]
-                )
-            ]
-            all_metrics.append(metrics)
+        for filter_key in filter_keys:
+            doc_iterator = self.doc_iterator(
+                rank=rank,
+                limit=limit,
+                world_size=world_size,
+                # samples=indices,
+            )
+            for doc_id, doc in doc_iterator:
+                # doc_id_true = indices[doc_id] if indices else doc_id
+                requests = instances_by_doc_id[doc_id]
+                metrics: list[list[dict]] = [
+                    self.process_results(doc, response)
+                    for req in requests
+                    for response in req.filtered_resps[filter_key]
+                ]
+                # TODO: This turns metrics into a list of lists of dicts rather than flat list.
+                all_metrics[filter_key].append(metrics)
         return all_metrics
...
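With this change, calculate_metrics returns a mapping from filter name to per-document metric lists instead of a flat list. A minimal shape sketch, not part of the commit, using hypothetical filter names ("none", "strict-match") and a hypothetical "acc" metric from process_results; as the in-diff TODO notes, each per-document entry is itself a list of per-response dicts:

# Shape sketch (hypothetical, not part of this commit): the new
# dict[str, list[dict]] return value for two filters and two documents,
# assuming process_results yields dicts such as {"acc": 1.0}.
all_metrics = {
    "none": [[{"acc": 1.0}], [{"acc": 0.0}]],           # doc 0, doc 1
    "strict-match": [[{"acc": 1.0}], [{"acc": 1.0}]],
}

# Per the in-diff TODO, each per-document entry is a list of per-response
# dicts, so consumers must unwrap two levels below the filter key:
for filter_key, per_doc in all_metrics.items():
    for doc_metrics in per_doc:
        for response_metrics in doc_metrics:
            for name, value in response_metrics.items():
                print(filter_key, name, value)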
lm_eval/evaluator.py (view file @ 911cae22)
...
@@ -38,7 +38,7 @@ from lm_eval.utils import (
 if TYPE_CHECKING:
     from lm_eval.api.model import LM
-    from lm_eval.api.task import Task
+    from lm_eval.api.task import ConfigurableTask, Task

 eval_logger = logging.getLogger(__name__)
...
@@ -585,7 +585,7 @@ def evaluate(
     ### Postprocess outputs ###
     # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
     for task_output, limit in zip(eval_tasks, limits):
-        task = task_output.task
+        task: ConfigurableTask = task_output.task
         task.apply_filters()

     ### Collect values of metrics on all datapoints ###
...
@@ -599,17 +599,25 @@ def evaluate(
         for instances in instances_by_doc_id.values():
             instances.sort(key=lambda x: x.idx)
         # iterate over different filters used
+        if hasattr(task, "calculate_metrics"):
+            metrics = task.calculate_metrics(
+                instances_by_doc_id=instances_by_doc_id,
+                samples=samples,
+                rank=RANK,
+                limit=limit,
+                world_size=WORLD_SIZE,
+            )
         for filter_key in task.instances[0].filtered_resps.keys():
             if hasattr(task, "calculate_metrics"):
                 # Use the new method if it exists (ConfigurableTask)
-                metrics = task.calculate_metrics(
-                    instances_by_doc_id=instances_by_doc_id,
-                    filter_key=filter_key,
-                    samples=samples,
-                    rank=RANK,
-                    limit=limit,
-                    world_size=WORLD_SIZE,
-                )
+                # metrics = task.calculate_metrics(
+                #     instances_by_doc_id=instances_by_doc_id,
+                #     filter_keys=filter_key,
+                #     samples=samples,
+                #     rank=RANK,
+                #     limit=limit,
+                #     world_size=WORLD_SIZE,
+                # )
                 # Add sample logging here too - similar to what's done in the else branch
                 if log_samples:
@@ -663,9 +671,19 @@ def evaluate(
...
@@ -663,9 +671,19 @@ def evaluate(
                     task_output.logged_samples.append(example)

                 # Process all metrics returned from calculate_metrics
-                for x in metrics:
-                    for metric, value in x.items():
-                        task_output.sample_metrics[(metric, filter_key)].append(value)
+                for filter_key in metrics:
+                    for sample_metric in metrics[filter_key]:
+                        for metric_key, value in sample_metric:
+                            task_output.sample_metrics[(metric_key, filter_key)].append(value)
+                # metrics is a list of dictionaries, each containing metric names and their values
+                # e.g., [{"accuracy": 0.9}, {"f1": 0.8}]
+                # We need to iterate through each dictionary and extract the metric names and values
+                # for x in metrics:
+                #     for metric, value in x.items():
+                #         task_output.sample_metrics[(metric, filter_key)].append(value)
             else:
                 # Fall back to the original approach for non-ConfigurableTask instances
                 indices = (
...
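The replacement loop above unpacks sample_metric directly (for metric_key, value in sample_metric), which appears to sit one nesting level short of the list-of-lists-of-dicts shape that the new calculate_metrics produces; together with the commented-out call and the commit message "TODO!", the wiring looks unfinished. A hedged sketch, not from the commit, of one way the nested structure could be flattened into task_output.sample_metrics, using a plain defaultdict as a stand-in for the real TaskOutput field and the hypothetical "acc" metric from the earlier sketch:

import collections

# Stand-in for task_output.sample_metrics (hypothetical; the real field
# lives on lm_eval's TaskOutput and is not reconstructed here).
sample_metrics = collections.defaultdict(list)

# Shape produced by the new calculate_metrics: filter -> docs -> responses.
metrics = {"none": [[{"acc": 1.0}], [{"acc": 0.0}]]}

for filter_key, per_doc in metrics.items():
    for doc_metrics in per_doc:            # one list per document
        for sample_metric in doc_metrics:  # one dict per response
            # Note .items(): bare iteration over a dict yields only keys.
            for metric_key, value in sample_metric.items():
                sample_metrics[(metric_key, filter_key)].append(value)

assert sample_metrics[("acc", "none")] == [1.0, 0.0]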