gaoqiong / lm-evaluation-harness · Commit 80d0f412

change how aggregate_metric is loaded

Authored Jun 10, 2024 by lintangsutawika
Parent: 0f095f79
Showing 2 changed files with 62 additions and 28 deletions:

- lm_eval/api/task.py (+25 −5)
- lm_eval/evaluator.py (+37 −23)
lm_eval/api/task.py

```diff
@@ -5,7 +5,7 @@ import random
 import re
 from collections.abc import Callable
 from copy import deepcopy
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from inspect import getsource
 from typing import (
     Any,
```
```diff
@@ -51,6 +51,17 @@ ALL_OUTPUT_TYPES = [
 eval_logger = logging.getLogger("lm-eval")
 
 
+@dataclass
+class AggMetricConfig(dict):
+    metric: Optional[str] = "acc"
+    metric_alias: Optional[str] = "acc"
+    aggregation: Optional[str] = "mean"
+    weight_by_size: Optional[str] = False
+    filter_list: Optional[Union[str, list]] = "none"
+
+    def __post_init__(self):
+        if isinstance(self.filter_list, str):
+            self.filter_list = [self.filter_list]
+
+
 @dataclass
 class GroupConfig(dict):
```
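For orientation (not part of the commit), a minimal sketch of what the new `AggMetricConfig` does at construction time, assuming this commit is checked out so that `lm_eval.api.task.AggMetricConfig` exists:

```python
# Minimal sketch, not part of the commit: AggMetricConfig.__post_init__
# normalizes a scalar filter_list into a list, so downstream code can
# always iterate over it. Assumes this commit is installed/checked out.
from lm_eval.api.task import AggMetricConfig

cfg = AggMetricConfig(metric="acc", filter_list="none")
print(cfg.filter_list)  # ['none'] -- the bare string was wrapped in a list
print(cfg.aggregation)  # 'mean' (the default)
```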
```diff
@@ -58,10 +69,9 @@ class GroupConfig(dict):
     group_alias: Optional[str] = None
     task: Optional[Union[str, list]] = None
     tag_to_task: Optional[str] = False
-    aggregate_metric: Optional[str] = False
-    aggregate_fn: Optional[str] = "mean"
-    weight_by_size: Optional[str] = False
-    metric_alias: Optional[str] = None  # Still a placeholder
+    aggregate_metric_list: Optional[
+        Union[List[AggMetricConfig], AggMetricConfig, dict]
+    ] = None
     metadata: Optional[dict] = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
```
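To make the schema change concrete, a hedged before/after comparison (illustrative values only, not from the commit): the flat group-level fields are folded into one list of per-metric configs.

```python
# Illustrative only: the old flat fields removed by this commit vs. the
# new aggregate_metric_list entries that replace them. Values are made up.
old_style = {
    "aggregate_metric": True,   # removed
    "aggregate_fn": "mean",     # removed
    "weight_by_size": False,    # removed (now set per metric)
}
new_style = {
    "aggregate_metric_list": [  # one entry per metric to aggregate
        {"metric": "acc", "aggregation": "mean", "weight_by_size": False},
    ],
}
```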
```diff
@@ -72,6 +82,16 @@ class GroupConfig(dict):
     def __setitem__(self, item, value):
         return setattr(self, item, value)
 
+    def __post_init__(self):
+        if self.aggregate_metric_list is not None:
+            if isinstance(self.aggregate_metric_list, dict):
+                self.aggregate_metric_list = [self.aggregate_metric_list]
+
+            self.aggregate_metric_list = [
+                AggMetricConfig(**item) if isinstance(item, dict) else item
+                for item in self.aggregate_metric_list
+            ]
+
     def to_dict(self, keep_callable: bool = False) -> dict:
         """dumps the current config as a dictionary object, as a printable format.
         null fields will not be printed.
```
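A minimal sketch (not part of the commit) of the normalization the new `GroupConfig.__post_init__` performs, assuming this commit is checked out; the task names are hypothetical:

```python
# Minimal sketch, not part of the commit: a bare dict passed as
# aggregate_metric_list is wrapped in a list and promoted to AggMetricConfig.
from lm_eval.api.task import GroupConfig

group = GroupConfig(
    task=["task_a", "task_b"],  # hypothetical subtask names
    aggregate_metric_list={"metric": "acc", "weight_by_size": True},
)
print(type(group.aggregate_metric_list))     # <class 'list'>
print(type(group.aggregate_metric_list[0]))  # AggMetricConfig
```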
lm_eval/evaluator.py
```diff
@@ -616,13 +616,16 @@ def evaluate(
             )
 
             if (group_config is None) or (
-                group_config["aggregate_metric"] is False
+                group_config["aggregate_metric_list"] is None
             ):
                 results[group_or_task][" "] = " "
                 continue
 
-            show_group_table = (
-                show_group_table | group_config["aggregate_metric"]
-            )
+            if "aggregate_metric_list" in group_config:
+                agg_metric_list = group_config["aggregate_metric_list"]
+
+            show_group_table = show_group_table | bool(
+                group_config["aggregate_metric_list"]
+            )
 
             task_list = _task_aggregation_list[group_or_task]
```
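As an aside, the gate above now keys off the truthiness of the new list rather than the old boolean field; a standalone sketch (not from the commit, values made up):

```python
# Minimal sketch, not part of the commit: None skips group aggregation,
# an empty list keeps the group table hidden, a non-empty list shows it.
for agg_list in (None, [], [{"metric": "acc"}]):
    show_group_table = False
    if agg_list is None:
        continue  # mirrors the `continue` in the hunk above
    show_group_table = show_group_table | bool(agg_list)
    print(agg_list, "->", show_group_table)  # [] -> False, [{...}] -> True
```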
```diff
@@ -656,26 +659,36 @@ def evaluate(
                     if metric in results[task]
                 ]
 
-                # compute group's pooled metric and stderr
-                results[group_or_task][metric] = (
-                    lm_eval.api.metrics.aggregate_subtask_metrics(
-                        metrics,
-                        sizes,
-                        group_config["weight_by_size"],
-                    )
-                )
-                # TODO: calculate grouped metric using aggregation fn
-                if "N/A" in stderrs:
-                    results[group_or_task][stderr] = "N/A"
-                else:
-                    results[group_or_task][stderr] = (
-                        lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
-                    )
-                    # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
-                    # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
-                    # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
+                for metric_config in agg_metric_list:
+                    for filter in metric_config["filter_list"]:
+                        if metric != ",".join([metric_config["metric"], filter]):
+                            continue
+
+                        # compute group's pooled metric and stderr
+                        if metric_config["aggregation"] == "mean":
+                            aggregate_fn = (
+                                lm_eval.api.metrics.aggregate_subtask_metrics
+                            )
+                        else:
+                            aggregate_fn = metric_config["aggregation"]
+
+                        results[group_or_task][metric] = aggregate_fn(
+                            metrics,
+                            sizes,
+                            metric_config["weight_by_size"],
+                        )
+                        # TODO: calculate grouped metric using aggregation fn
+                        if "N/A" in stderrs:
+                            results[group_or_task][stderr] = "N/A"
+                        else:
+                            results[group_or_task][stderr] = (
+                                lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
+                            )
+                            # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
+                            # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
+                            # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
 
                 results[group_or_task]["samples"] = sum(sizes)
 
                 group_metadata = group_config.get("metadata", None)
```
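To see how the new inner loop behaves end to end, a self-contained sketch (not from the commit): result keys follow the harness's `"<metric>,<filter>"` convention, and `aggregate_subtask_metrics` below is a hypothetical stand-in for `lm_eval.api.metrics.aggregate_subtask_metrics`.

```python
# Minimal sketch, not part of the commit. aggregate_subtask_metrics is a
# stand-in for lm_eval.api.metrics.aggregate_subtask_metrics; all values
# are made up.
def aggregate_subtask_metrics(metrics, sizes, weight_by_size):
    # Size-weighted mean when requested, plain mean otherwise.
    if not weight_by_size:
        sizes = [1] * len(sizes)
    return sum(m * s for m, s in zip(metrics, sizes)) / sum(sizes)

agg_metric_list = [
    {"metric": "acc", "filter_list": ["none"],
     "aggregation": "mean", "weight_by_size": True},
]
results = {"acc,none": None, "acc_norm,none": None}  # hypothetical group slots
metrics, sizes = [0.5, 0.7], [100, 300]  # per-subtask scores and sample counts

for metric in list(results):
    for metric_config in agg_metric_list:
        for flt in metric_config["filter_list"]:
            if metric != ",".join([metric_config["metric"], flt]):
                continue  # this config entry does not cover the key
            if metric_config["aggregation"] == "mean":
                aggregate_fn = aggregate_subtask_metrics
            else:  # a callable may be supplied instead of "mean"
                aggregate_fn = metric_config["aggregation"]
            results[metric] = aggregate_fn(
                metrics, sizes, metric_config["weight_by_size"]
            )

print(results)  # {'acc,none': 0.65, 'acc_norm,none': None}
```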
```diff
@@ -683,6 +696,7 @@ def evaluate(
                 versions[group_or_task] = group_metadata.get("version", None)
         # print(results)
 
+        return results, versions, show_group_table, task_aggregation_list
 
     results, versions, show_group_table, *_ = process_group(
```
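Finally, the call site tolerates the widened return tuple via starred unpacking; a tiny sketch with a hypothetical stand-in for the inner `process_group`:

```python
# Minimal sketch, not part of the commit: starred unpacking lets
# process_group grow its return tuple without breaking this call site.
def process_group():  # hypothetical stand-in
    return {"group": {}}, {"group": 1}, True, ["task_aggregation_list"]

results, versions, show_group_table, *_ = process_group()
print(_)  # [['task_aggregation_list']] -- extra return values land here
```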