Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
80d0f412
"doc/en/vscode:/vscode.git/clone" did not exist on "c9eb1f76f644ec8a13e0f475889607ba35c3580b"
Commit
80d0f412
authored
Jun 10, 2024
by
lintangsutawika
Browse files
change how aggregate_metric is loaded
parent
0f095f79
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
28 deletions
+62
-28
lm_eval/api/task.py
lm_eval/api/task.py
+25
-5
lm_eval/evaluator.py
lm_eval/evaluator.py
+37
-23
No files found.
lm_eval/api/task.py
View file @
80d0f412
...
@@ -5,7 +5,7 @@ import random
...
@@ -5,7 +5,7 @@ import random
import
re
import
re
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
copy
import
deepcopy
from
copy
import
deepcopy
from
dataclasses
import
asdict
,
dataclass
from
dataclasses
import
asdict
,
dataclass
,
field
from
inspect
import
getsource
from
inspect
import
getsource
from
typing
import
(
from
typing
import
(
Any
,
Any
,
...
@@ -51,6 +51,17 @@ ALL_OUTPUT_TYPES = [
...
@@ -51,6 +51,17 @@ ALL_OUTPUT_TYPES = [
eval_logger
=
logging
.
getLogger
(
"lm-eval"
)
eval_logger
=
logging
.
getLogger
(
"lm-eval"
)
@
dataclass
class
AggMetricConfig
(
dict
):
metric
:
Optional
[
str
]
=
"acc"
metric_alias
:
Optional
[
str
]
=
"acc"
aggregation
:
Optional
[
str
]
=
"mean"
weight_by_size
:
Optional
[
str
]
=
False
filter_list
:
Optional
[
Union
[
str
,
list
]]
=
"none"
def
__post_init__
(
self
):
if
isinstance
(
self
.
filter_list
,
str
):
self
.
filter_list
=
[
self
.
filter_list
]
@
dataclass
@
dataclass
class
GroupConfig
(
dict
):
class
GroupConfig
(
dict
):
...
@@ -58,10 +69,9 @@ class GroupConfig(dict):
...
@@ -58,10 +69,9 @@ class GroupConfig(dict):
group_alias
:
Optional
[
str
]
=
None
group_alias
:
Optional
[
str
]
=
None
task
:
Optional
[
Union
[
str
,
list
]]
=
None
task
:
Optional
[
Union
[
str
,
list
]]
=
None
tag_to_task
:
Optional
[
str
]
=
False
tag_to_task
:
Optional
[
str
]
=
False
aggregate_metric
:
Optional
[
str
]
=
False
aggregate_metric_list
:
Optional
[
aggregate_fn
:
Optional
[
str
]
=
"mean"
Union
[
List
[
AggMetricConfig
],
AggMetricConfig
,
dict
]
weight_by_size
:
Optional
[
str
]
=
False
]
=
None
metric_alias
:
Optional
[
str
]
=
None
# Still a placeholder
metadata
:
Optional
[
metadata
:
Optional
[
dict
dict
]
=
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
]
=
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
...
@@ -72,6 +82,16 @@ class GroupConfig(dict):
...
@@ -72,6 +82,16 @@ class GroupConfig(dict):
def
__setitem__
(
self
,
item
,
value
):
def
__setitem__
(
self
,
item
,
value
):
return
setattr
(
self
,
item
,
value
)
return
setattr
(
self
,
item
,
value
)
def
__post_init__
(
self
):
if
self
.
aggregate_metric_list
is
not
None
:
if
isinstance
(
self
.
aggregate_metric_list
,
dict
):
self
.
aggregate_metric_list
=
[
self
.
aggregate_metric_list
]
self
.
aggregate_metric_list
=
[
AggMetricConfig
(
**
item
)
if
isinstance
(
item
,
dict
)
else
item
for
item
in
self
.
aggregate_metric_list
]
def
to_dict
(
self
,
keep_callable
:
bool
=
False
)
->
dict
:
def
to_dict
(
self
,
keep_callable
:
bool
=
False
)
->
dict
:
"""dumps the current config as a dictionary object, as a printable format.
"""dumps the current config as a dictionary object, as a printable format.
null fields will not be printed.
null fields will not be printed.
...
...
lm_eval/evaluator.py
View file @
80d0f412
...
@@ -616,13 +616,16 @@ def evaluate(
...
@@ -616,13 +616,16 @@ def evaluate(
)
)
if
(
group_config
is
None
)
or
(
if
(
group_config
is
None
)
or
(
group_config
[
"aggregate_metric"
]
is
Fals
e
group_config
[
"aggregate_metric"
]
is
Non
e
):
):
results
[
group_or_task
][
" "
]
=
" "
results
[
group_or_task
][
" "
]
=
" "
continue
continue
show_group_table
=
(
if
"aggregate_metric"
in
group_config
:
show_group_table
|
group_config
[
"aggregate_metric"
]
agg_metric_list
=
group_config
[
"aggregate_metric"
]
show_group_table
=
show_group_table
|
bool
(
group_config
[
"aggregate_metric"
]
)
)
task_list
=
_task_aggregation_list
[
group_or_task
]
task_list
=
_task_aggregation_list
[
group_or_task
]
...
@@ -656,13 +659,23 @@ def evaluate(
...
@@ -656,13 +659,23 @@ def evaluate(
if
metric
in
results
[
task
]
if
metric
in
results
[
task
]
]
]
for
metric_config
in
agg_metric_list
:
for
filter
in
metric_config
[
"filter_list"
]:
if
metric
!=
","
.
join
([
metric_config
[
"metric"
],
filter
]):
continue
# compute group's pooled metric and stderr
# compute group's pooled metric and stderr
results
[
group_or_task
][
if
metric_config
[
"aggregation"
]
==
"mean"
:
metric
aggregate_fn
=
(
]
=
lm_eval
.
api
.
metrics
.
aggregate_subtask_metrics
(
lm_eval
.
api
.
metrics
.
aggregate_subtask_metrics
)
else
:
aggregate_fn
=
metric_config
[
"aggregation"
]
results
[
group_or_task
][
metric
]
=
aggregate_fn
(
metrics
,
metrics
,
sizes
,
sizes
,
group
_config
[
"weight_by_size"
],
metric
_config
[
"weight_by_size"
],
)
)
# TODO: calculate grouped metric using aggregation fn
# TODO: calculate grouped metric using aggregation fn
if
"N/A"
in
stderrs
:
if
"N/A"
in
stderrs
:
...
@@ -683,6 +696,7 @@ def evaluate(
...
@@ -683,6 +696,7 @@ def evaluate(
versions
[
group_or_task
]
=
group_metadata
.
get
(
versions
[
group_or_task
]
=
group_metadata
.
get
(
"version"
,
None
"version"
,
None
)
)
# print(results)
return
results
,
versions
,
show_group_table
,
task_aggregation_list
return
results
,
versions
,
show_group_table
,
task_aggregation_list
results
,
versions
,
show_group_table
,
*
_
=
process_group
(
results
,
versions
,
show_group_table
,
*
_
=
process_group
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment