gaoqiong / lm-evaluation-harness · Commits

Commit 9093b1a6
authored Jul 07, 2025 by Baber
parent 51ab86ff

    move metric calculation to task

Showing 5 changed files with 153 additions and 178 deletions:

  lm_eval/api/schemas.py        +3   -3
  lm_eval/api/task.py           +81  -38
  lm_eval/evaluator.py          +18  -133
  lm_eval/filters/selection.py  +1   -1
  lm_eval/utils.py              +50  -3
lm_eval/api/schemas.py  (+3, -3)

 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union


 # @dataclass
...
@@ -64,11 +64,11 @@ class MetricResult:
     Outputs for the metric function.
     """

-    doc_id: str | int | None
-    scores: list[dict[str, float]] | dict
+    doc_id: Union[str, int]
     filter_key: str = None
     metric_name: str = None
     metadata: Optional[dict] = None
+    scores: Union[list[dict[str, float]], dict] = None

     def __iter__(self):
         if self.scores is None:
...
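The reordering matters because dataclass fields without defaults cannot follow fields with defaults; moving scores below the defaulted members is what forces its new "= None" default. A minimal sketch of just the fields visible in this hunk (other members of MetricResult are elided and assumed unchanged; the class name here is hypothetical):

from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class MetricResultSketch:
    # mirrors only the fields shown in the hunk above
    doc_id: Union[str, int]
    filter_key: str = None
    metric_name: str = None
    metadata: Optional[dict] = None
    scores: Union[list[dict[str, float]], dict] = None  # now needs a default


# constructing one per-document result by hand (toy values)
r = MetricResultSketch(doc_id=0, metric_name="acc", scores=[{"acc": 1.0}])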
lm_eval/api/task.py  (+81, -38)

@@ -8,8 +8,6 @@ from collections.abc import Callable
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from inspect import getsource
-from itertools import groupby
-from operator import attrgetter
 from typing import (
     Any,
     Dict,
...
@@ -30,7 +28,12 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api import samplers
 from lm_eval.api.instance import Instance, OutputType
-from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
+from lm_eval.api.metrics import (
+    bits_per_byte,
+    mean,
+    stderr_for_metric,
+    weighted_perplexity,
+)
 from lm_eval.api.registry import (
     AGGREGATION_REGISTRY,
     DEFAULT_METRIC_REGISTRY,
...
@@ -39,10 +42,10 @@ from lm_eval.api.registry import (
     get_metric_aggregation,
     is_higher_better,
 )
-from lm_eval.api.schemas import MetricResult
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt
+from lm_eval.utils import create_sample_log, pass_at_k


 ALL_OUTPUT_TYPES = [
...
@@ -1774,19 +1777,22 @@ class ConfigurableTask(Task):

     def calculate_metrics(
         self,
-        instances_by_doc_id=None,
-        filter_keys=None,
-        samples=None,
-        rank=1,
-        limit=None,
-        world_size=1,
-    ) -> list[MetricResult]:
+        requests: list[Instance] = None,
+        filter_keys: list[str] = None,
+        indices: list[int] = None,
+        rank: int = 1,
+        limit: int = None,
+        world_size: int = 1,
+        log_samples: bool = False,
+    ) -> tuple[
+        Optional[dict[tuple[str, str], list[list[float]]]], Optional[list[dict]]
+    ]:
         """Calculate metrics for all datapoints in the task.

         Args:
             instances_by_doc_id (dict): Dictionary mapping doc_ids to lists of instances.
             filter_key (str): The filter key to use for filtered responses.
-            samples (dict, optional): Dictionary of sample indices to evaluate.
+            indices (dict, optional): Dictionary of sample indices to evaluate.
             rank (int): The process rank.
             limit (int, optional): Limit on number of examples to evaluate.
             world_size (int): Total number of processes.
...
@@ -1794,8 +1800,20 @@ class ConfigurableTask(Task):

         Returns:
             list: A list of metrics calculated for each document.
         """
-        if not self._instances:
-            return
+        if not requests and not self.instances:
+            print("sent results")
+            return None, None
+
+        ### Collect values of metrics on all datapoints ###
+        # Pre-process task.instances to group by doc_id
+        instances_by_doc_id = defaultdict(list)
+        for instance in self.instances:
+            instances_by_doc_id[instance.doc_id].append(instance)
+        # Sort instances within each group
+        for instances in instances_by_doc_id.values():
+            instances.sort(key=lambda x: x.idx)
+
+        _all_metrics = defaultdict(list)
+        _samples = [] if log_samples else None

         if filter_keys is None:
             filter_keys = (
...
@@ -1805,23 +1823,18 @@ class ConfigurableTask(Task):
             )
         if isinstance(filter_keys, str):
             filter_keys = [filter_keys]
-        if not instances_by_doc_id:
-            instances_by_doc_id = defaultdict(list)
-            for instance in self.instances:
-                instances_by_doc_id[instance.doc_id].append(instance)
-        # all_metrics = collections.defaultdict(list)
-        all_metrics = []
         for filter_key in filter_keys:
             doc_iterator = self.doc_iterator(
                 rank=rank,
                 limit=limit,
                 world_size=world_size,
-                # samples=indices,
+                samples=indices,
             )
             for doc_id, doc in doc_iterator:
-                # doc_id_true = indices[doc_id] if indices else doc_id
-                requests = instances_by_doc_id[doc_id]
+                _doc_id_true = indices[doc_id] if indices else doc_id
+                _sample_metric = defaultdict(list)
+                requests = instances_by_doc_id[_doc_id_true]
                 if len(requests) > 1:
                     # if one doc has multiple instances then calculate metric together
                     metrics = self.process_results(
...
@@ -1837,24 +1850,54 @@ class ConfigurableTask(Task):
                         else [req.filtered_resps[filter_key]]
                     )
                 ]
-                all_metrics.append(
-                    MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
-                )
+                for metric in metrics:
+                    for k, v in metric.items():
+                        _sample_metric[k].append(v)
+                if log_samples:
+                    _samples.append(
+                        create_sample_log(
+                            doc=doc,
+                            doc_id=_doc_id_true,
+                            target=self.doc_to_target(doc),
+                            requests=requests,
+                            metric_names=metrics,
+                            filter_key=filter_key,
+                        )
+                    )
+                for metric_name, _score in _sample_metric.items():
+                    _all_metrics[(metric_name, filter_key)].append(_score)

-        return all_metrics
+        return _all_metrics, _samples

-    @staticmethod
-    def compute_agg_metrics(self, metric_results: list[MetricResult]):
-        y_sorted = sorted(metric_results, key=attrgetter("filter_key", "metric_name"))
-        groups = {
-            key: list(
-                map(
-                    list,
-                    zip(*((d[it.metric_name] for d in it.scores) for it in g)),
-                )
-            )
-            for key, g in groupby(y_sorted, key=attrgetter("filter_key", "metric_name"))
-        }
-        return groups
+    def compute_agg_metrics(
+        self,
+        metric_results: dict[tuple[str, str], list[list[float]]],
+        bootstrap_iters: int = 1000,
+    ):
+        agg_metrics = defaultdict(list)
+        for (metric_name, filter_key), scores in metric_results.items():
+            agg_fn = self.aggregation()[metric_name]
+            metric_key = f"{metric_name},{filter_key}"
+            self.repeat_metric = pass_at_k
+            repeats = [
+                self.repeat_metric(len(x), x.count(1), k=x.count(1) - 1)
+                for x in scores
+            ]
+            repeat_agg = np.mean(repeats)
+            agg_metrics[metric_key] = [agg_fn(items) for items in zip(*scores)]
+            if isinstance(bootstrap_iters, int):
+                stderr_fn = stderr_for_metric(
+                    metric=agg_fn,
+                    bootstrap_iters=min(bootstrap_iters, 100)
+                    if metric_name in ["bleu", "chrf", "ter"]
+                    else bootstrap_iters,
+                )
+                agg_metrics[f"{metric_name}_stderr,{filter_key}"] = [
+                    (stderr_fn(item) if (stderr_fn and len(item) > 1) else "N/A")
+                    for item in zip(*scores)
+                ][0]
+            agg_metrics[f"{metric_key}_repeat"] = [repeat_agg]
+        return agg_metrics


 class MultipleChoiceTask(Task):
...
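For orientation, calculate_metrics now returns a plain dict keyed by (metric_name, filter_key) tuples, each value holding one list per document with that document's per-repeat scores, plus an optional list of sample logs. A minimal, self-contained sketch of how compute_agg_metrics walks that structure, using np.mean as a stand-in for the registered aggregation function (toy values, not harness output):

from collections import defaultdict

import numpy as np

# toy stand-in for the first element returned by calculate_metrics:
# two documents, two repeats each, for the ("acc", "none") metric/filter pair
metric_results = {("acc", "none"): [[1.0, 0.0], [1.0, 1.0]]}

agg_metrics = defaultdict(list)
for (metric_name, filter_key), scores in metric_results.items():
    agg_fn = np.mean  # stand-in for self.aggregation()[metric_name]
    # aggregate position-wise across documents: one value per repeat index
    agg_metrics[f"{metric_name},{filter_key}"] = [
        float(agg_fn(items)) for items in zip(*scores)
    ]

print(dict(agg_metrics))  # {'acc,none': [1.0, 0.5]}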
lm_eval/evaluator.py  (+18, -133)

 import itertools
-import json
 import logging
 import random
 import time
...
@@ -28,8 +27,6 @@ from lm_eval.loggers import EvaluationTracker
 from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
 from lm_eval.tasks import TaskManager, get_task_dict
 from lm_eval.utils import (
-    handle_non_serializable,
-    hash_string,
     positional_deprecated,
     setup_logging,
     simple_parse_args_string,
...
@@ -592,135 +589,23 @@ def evaluate(
         # # unpack results and sort back in order and return control to Task
         # TODO: make it possible to use a different metric per filter
         # Pre-process task.instances to group by doc_id
-        instances_by_doc_id = defaultdict(list)
-        for instance in task.instances:
-            instances_by_doc_id[instance.doc_id].append(instance)
-        # Sort instances within each group
-        for instances in instances_by_doc_id.values():
-            instances.sort(key=lambda x: x.idx)
+        # instances_by_doc_id = defaultdict(list)
+        # for instance in task.instances:
+        #     instances_by_doc_id[instance.doc_id].append(instance)
+        # # Sort instances within each group
+        # for instances in instances_by_doc_id.values():
+        #     instances.sort(key=lambda x: x.idx)
         # iterate over different filters used
-        if hasattr(task, "calculate_metrics"):
-            metrics = task.calculate_metrics(
-                instances_by_doc_id=instances_by_doc_id,
-                samples=samples,
-                rank=RANK,
-                limit=limit,
-                world_size=WORLD_SIZE,
-            )
-        for filter_key in task.instances[0].filtered_resps.keys():
-            if hasattr(task, "calculate_metrics"):
-                # Add sample logging here too - similar to what's done in the else branch
-                if log_samples:
-                    indices = (
-                        samples.get(task_output.task_name, None)
-                        if samples is not None
-                        else None
-                    )
-                    doc_iterator = task.doc_iterator(
-                        rank=RANK,
-                        limit=limit,
-                        world_size=WORLD_SIZE,
-                        samples=indices,
-                    )
-                    for doc_id, doc in doc_iterator:
-                        doc_id_true = indices[doc_id] if indices else doc_id
-                        requests = instances_by_doc_id[doc_id]
-                        if requests:  # Make sure there are requests for this doc_id
-                            doc_metrics = metrics[filter_key][doc_id_true].metric_keys
-                            target = task.doc_to_target(doc)
-                            example = {
-                                "doc_id": doc_id_true,
-                                "doc": doc,
-                                "target": target,
-                                "arguments": [req.args for req in requests],
-                                "resps": [req.resps for req in requests],
-                                "filtered_resps": [
-                                    req.filtered_resps[filter_key] for req in requests
-                                ],
-                                "filter": filter_key,
-                                "metrics": doc_metrics,
-                                "doc_hash": hash_string(
-                                    json.dumps(
-                                        requests[0].doc,
-                                        indent=2,
-                                        default=handle_non_serializable,
-                                        ensure_ascii=False,
-                                    )
-                                ),
-                                "prompt_hash": hash_string(requests[0].arguments[0]),
-                                "target_hash": hash_string(str(target)),
-                            }
-                            example.update(
-                                {
-                                    metrics[filter_key][doc_id_true].metric_keys[0]:
-                                        metrics[filter_key][doc_id_true]
-                                }
-                            )
-                            task_output.logged_samples.append(example)
-                # Process all metrics returned from calculate_metrics
-                for filter_key in metrics:
-                    for m_samples in metrics[filter_key]:
-                        for metric, value in m_samples:
-                            task_output.sample_metrics[(metric, filter_key)].append(value)
-            else:
-                # Fall back to the original approach for non-ConfigurableTask instances
-                indices = (
-                    samples.get(task_output.task_name, None)
-                    if samples is not None
-                    else None
-                )
-                doc_iterator = task.doc_iterator(
-                    rank=RANK,
-                    limit=limit,
-                    world_size=WORLD_SIZE,
-                    samples=indices,
-                )
-                for doc_id, doc in doc_iterator:
-                    if indices:
-                        doc_id_true = indices[doc_id]
-                    else:
-                        doc_id_true = doc_id
-                    requests = instances_by_doc_id[doc_id]
-                    metrics: list[dict] = [
-                        task.process_results(doc, response)
-                        for req in requests
-                        for response in req.filtered_resps[filter_key]
-                    ]
-                    if log_samples:
-                        target = task.doc_to_target(doc)
-                        example = {
-                            "doc_id": doc_id_true,
-                            "doc": doc,
-                            "target": target,
-                            "arguments": [req.args for req in requests],
-                            "resps": [req.resps for req in requests],
-                            "filtered_resps": [
-                                req.filtered_resps[filter_key] for req in requests
-                            ],
-                            "filter": filter_key,
-                            "metrics": metrics,
-                            "doc_hash": hash_string(
-                                json.dumps(
-                                    requests[0].doc,
-                                    indent=2,
-                                    default=handle_non_serializable,
-                                    ensure_ascii=False,
-                                )
-                            ),
-                            "prompt_hash": hash_string(requests[0].arguments[0]),
-                            "target_hash": hash_string(str(target)),
-                        }
-                        example.update({"metrics": metrics})
-                        task_output.logged_samples.append(example)
-                    for x in metrics:
-                        for metric, value in x.items():
-                            task_output.sample_metrics[(metric, filter_key)].append(value)
+        _metrics, samples = task.calculate_metrics(
+            indices=samples,
+            rank=RANK,
+            limit=limit,
+            world_size=WORLD_SIZE,
+        )
+        task_output.agg_metrics = task.compute_agg_metrics(_metrics)
+        task_output.sample_metrics = _metrics
+        if log_samples:
+            task_output.logged_samples = samples

         if WORLD_SIZE > 1:
             # if multigpu, then gather data across all ranks to rank 0
...
@@ -756,8 +641,8 @@ def evaluate(
     if RANK == 0:
         ### Aggregate results over all datapoints ###
         # aggregate results ; run bootstrap CIs
-        for task_output in eval_tasks:
-            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
+        # for task_output in eval_tasks:
+        #     task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
         (
             results,
             samples,
...
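Because compute_agg_metrics already returns finished aggregates, the old per-task calculate_aggregate_metric call is commented out and task_output.agg_metrics is filled directly. Keys follow the f"{metric_name},{filter_key}" pattern from the task.py diff above; a hedged sketch of the resulting shape, with toy numbers for an acc metric under the none filter:

# illustrative only: shapes follow compute_agg_metrics in lm_eval/api/task.py above
task_output_agg_metrics = {
    "acc,none": [0.75],            # one aggregated value per repeat index
    "acc_stderr,none": 0.0433,     # bootstrap stderr taken from the first repeat, or "N/A"
    "acc,none_repeat": [0.75],     # mean pass@k over repeats across documents
}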
lm_eval/filters/selection.py  (+1, -1)

@@ -20,7 +20,7 @@ class TakeFirstFilter(Filter):
         """
         Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
         """
-        return map(lambda r: r[0], resps)
+        return map(lambda r: r, resps)


 @register_filter("take_first_k")
...
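The one-line change means TakeFirstFilter now passes each response list through unchanged rather than keeping only its first element (the docstring above it still describes the old behaviour), presumably to keep every repeat available for the per-repeat metric handling introduced in this commit. A quick sketch of the difference on toy data:

resps = [["a1", "a2"], ["b1", "b2"]]  # two docs, two model responses each

old_behaviour = list(map(lambda r: r[0], resps))  # ['a1', 'b1']
new_behaviour = list(map(lambda r: r, resps))     # [['a1', 'a2'], ['b1', 'b2']]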
lm_eval/utils.py  (+50, -3)

@@ -17,6 +17,8 @@ import numpy as np
 import yaml
 from jinja2 import BaseLoader, Environment, StrictUndefined

+from lm_eval.api.instance import Instance
+

 SPACING = " " * 47
...
@@ -406,9 +408,13 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
             v = "%.4f" % v if isinstance(v, float) else v

             if m + "_stderr" + "," + f in dic:
-                se = dic[m + "_stderr" + "," + f]
-                se = " N/A" if se == "N/A" else "%.4f" % se
-                values.append([k, version, f, n, m, hib, v, "±", se])
+                try:
+                    se = dic[m + "_stderr" + "," + f]
+                    se = " N/A" if se == "N/A" else "%.4f" % se
+                    values.append([k, version, f, n, m, hib, v, "±", se])
+                except:  # noqa: E722
+                    values.append([k, version, f, n, m, hib, v, "", ""])
             else:
                 values.append([k, version, f, n, m, hib, v, "", ""])

             k = ""
...
@@ -550,3 +556,44 @@ def weighted_f1_score(items):
     preds = unzipped_list[1]
     fscore = f1_score(golds, preds, average="weighted")
     return fscore
+
+
+def create_sample_log(
+    doc: dict,
+    doc_id: int,
+    target: Any,
+    requests: list[Instance],
+    metric_names: [dict],
+    filter_key: str,
+) -> dict:
+    return {
+        "doc_id": doc_id,
+        "doc": doc,
+        "target": target,
+        "arguments": [req.args for req in requests],
+        "resps": [req.resps for req in requests],
+        "filtered_resps": [req.filtered_resps[filter_key] for req in requests],
+        "filter": filter_key,
+        "metrics": metric_names,
+        "doc_hash": hash_string(
+            json.dumps(
+                requests[0].doc,
+                indent=2,
+                default=handle_non_serializable,
+                ensure_ascii=False,
+            )
+        ),
+        "prompt_hash": hash_string(requests[0].arguments[0]),
+        "target_hash": hash_string(str(target)),
+    }
+
+
+def pass_at_k(n: int, c: int, k: int) -> float:
+    """
+    :param n: total number of samples
+    :param c: number of correct samples
+    :param k: k in pass@$k$
+    """
+    if n - c < k:
+        return 1.0
+    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
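pass_at_k is the standard unbiased pass@k estimator, 1 - C(n-c, k)/C(n, k), written as a product so no large binomial coefficients are formed. A small numeric check with hand-computed expectations:

import numpy as np


def pass_at_k(n: int, c: int, k: int) -> float:
    # same implementation as added above
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))


assert pass_at_k(4, 0, 2) == 0.0                # no correct samples -> pass@2 is 0
assert pass_at_k(4, 3, 2) == 1.0                # only one failure, any draw of 2 hits a success
assert abs(pass_at_k(4, 2, 2) - 5 / 6) < 1e-12  # 1 - C(2,2)/C(4,2) = 5/6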