Project: gaoqiong / lm-evaluation-harness

Commit ad70d206, authored May 07, 2024 by lintangsutawika
Parent: c23c9305

    update to work with new group and task configuration

Showing 10 changed files with 123 additions and 78 deletions (+123 / -78)
Changed files:
  lm_eval/api/task.py                                    +9   -4
  lm_eval/evaluator.py                                   +99  -67
  lm_eval/evaluator_utils.py                             +3   -4
  lm_eval/tasks/__init__.py                              +2   -1
  lm_eval/tasks/mmlu/default/_mmlu.yaml                  +1   -0
  lm_eval/tasks/mmlu/default/_mmlu_humanities.yaml       +1   -0
  lm_eval/tasks/mmlu/default/_mmlu_other.yaml            +1   -0
  lm_eval/tasks/mmlu/default/_mmlu_social_sciences.yaml  +1   -0
  lm_eval/tasks/mmlu/default/_mmlu_stem.yaml             +1   -0
  lm_eval/utils.py                                       +5   -2
lm_eval/api/task.py

@@ -60,6 +60,7 @@ class GroupConfig(dict):
     aggregate_fn: Optional[str] = "mean"
     weight_by_size: Optional[str] = False
     metric_alias: Optional[str] = None
+    version: Optional[str] = 0

     def __getitem__(self, item):
         return getattr(self, item)
@@ -118,16 +119,20 @@ class ConfigurableGroup(abc.ABC):
     def group_alias(self):
         return self._config.group_alias

+    @property
+    def version(self):
+        return self._config.version
+
     @property
     def config(self):
         return self._config.to_dict()

     def __repr__(self):
         return (
             f"ConfigurableGroup(group={self.group},"
             f"group_alias={self.group_alias})"
         )


 @dataclass
 class TaskConfig(dict):
     # task naming/registry
...
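
For reference, the new version field sits alongside the existing GroupConfig fields and is surfaced on ConfigurableGroup as a read-only property. A minimal standalone sketch of the shape these hunks imply (the class name and usage below are illustrative, not the harness's actual API surface):

from dataclasses import dataclass
from typing import Optional


@dataclass
class GroupConfigSketch(dict):
    # Field names and defaults mirror the GroupConfig hunk above.
    aggregate_fn: Optional[str] = "mean"
    weight_by_size: Optional[str] = False
    metric_alias: Optional[str] = None
    version: Optional[str] = 0  # new in this commit

    def __getitem__(self, item):
        # dict-style access falls through to the dataclass attributes
        return getattr(self, item)


cfg = GroupConfigSketch(version=1)
print(cfg["version"])       # 1
print(cfg["aggregate_fn"])  # mean
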
lm_eval/evaluator.py

@@ -17,12 +17,16 @@ from lm_eval.evaluator_utils import (
     consolidate_results,
     get_sample_size,
     get_task_list,
-    prepare_print_tasks,
     print_writeout,
     run_task_tests,
 )
 from lm_eval.logging_utils import add_env_info, get_git_commit_hash
-from lm_eval.tasks import ConfigurableGroup, ConfigurableTask, TaskManager, get_task_dict
+from lm_eval.tasks import (
+    ConfigurableGroup,
+    ConfigurableTask,
+    TaskManager,
+    get_task_dict,
+)
 from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
@@ -211,14 +215,14 @@ def simple_evaluate(
         task_manager = TaskManager(verbosity)

     task_dict = get_task_dict(tasks, task_manager)

     def _adjust_config(task_dict):
         adjusted_task_dict = {}
         for task_name, task_obj in task_dict.items():
             if isinstance(task_obj, dict):
                 adjusted_task_dict = {
                     **adjusted_task_dict,
-                    **{task_name: _adjust_config(task_obj)}
+                    **{task_name: _adjust_config(task_obj)},
                 }
             else:
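
_adjust_config above recurses because, with the new group configuration, a task dict can nest: group keys map to sub-dicts and leaves are task objects. A rough standalone illustration of that traversal (plain dicts and a toy transform stand in for real task objects; nothing here is lm_eval API):

def adjust_config(task_dict, fn):
    """Recursively apply fn to every leaf of a nested task dict."""
    adjusted = {}
    for name, obj in task_dict.items():
        if isinstance(obj, dict):
            adjusted[name] = adjust_config(obj, fn)  # group: recurse into sub-tasks
        else:
            adjusted[name] = fn(obj)  # leaf task: transform it
    return adjusted


nested = {"mmlu": {"mmlu_stem": "task-a", "mmlu_other": "task-b"}, "arc_easy": "task-c"}
print(adjust_config(nested, str.upper))
# {'mmlu': {'mmlu_stem': 'TASK-A', 'mmlu_other': 'TASK-B'}, 'arc_easy': 'TASK-C'}
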
@@ -229,7 +233,6 @@ def simple_evaluate(
                 )

             if predict_only:
-                log_samples = True
                 eval_logger.info(
                     f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                 )
@@ -250,7 +253,9 @@ def simple_evaluate(
                     task_obj.set_config(key="num_fewshot", value=num_fewshot)
                 else:
                     # if num_fewshot not provided, and the task does not define a default one, default to 0
-                    if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
+                    if (
+                        default_num_fewshot := task_obj.get_config("num_fewshot")
+                    ) is None:
                         task_obj.set_config(key="num_fewshot", value=0)

             adjusted_task_dict[task_name] = task_obj
@@ -266,7 +271,7 @@ def simple_evaluate(
         rewrite_requests_cache=rewrite_requests_cache,
         bootstrap_iters=bootstrap_iters,
         write_out=write_out,
-        log_samples=log_samples,
+        log_samples=True if predict_only else log_samples,
         verbosity=verbosity,
     )
@@ -340,9 +345,6 @@ def evaluate(
     # get lists of group hierarchy and each type of request
     eval_tasks = get_task_list(task_dict)
-    # print("task_hierarchy")
-    # print(task_hierarchy)
-    # import sys; sys.exit()
     if not log_samples:
         if not all(
             "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
@@ -496,8 +498,14 @@ def evaluate(
     ### Calculate group metrics ###
     if bool(results):

-        def process_group(results, task_dict, task_root=None, task_hierarchy=None, show_group_table=False):
+        def process_group(
+            results,
+            task_dict,
+            task_root=None,
+            task_hierarchy=None,
+            show_group_table=False,
+        ):
             if task_root is None:
                 task_root = {}
@@ -505,33 +513,42 @@ def evaluate(
                 task_hierarchy = {}

             for group_or_task, group_or_task_info in task_dict.items():
-                if isinstance(group_or_task_info, ConfigurableTask):
-                    if task_root:
-                        task_hierarchy.setdefault(task_root, []).append(group_or_task)
-                else:
-                    results, _task_hierarchy, show_group_table = process_group(
-                        results, group_or_task_info, group_or_task, task_hierarchy, show_group_table
-                    )
-                    if task_root:
-                        task_hierarchy.setdefault(task_root, []).extend(task_hierarchy.get(group_or_task, []))
-
                 if isinstance(group_or_task, ConfigurableGroup):
                     group_config = group_or_task.config
-                    group = group_or_task.group
-                    show_group_table = show_group_table | group_config["aggregate_metric"]
+                    group_or_task = group_or_task.group
+                    show_group_table = (
+                        show_group_table | group_config["aggregate_metric"]
+                    )
                     if group_config["aggregate_metric"] is False:
-                        results[group][" "] = " "
+                        results[group_or_task][" "] = " "
                         continue
+                elif isinstance(group_or_task, str):
+                    results[group_or_task][" "] = " "
+                    continue
+
+                if isinstance(group_or_task_info, ConfigurableTask):
+                    if task_root:
+                        task_hierarchy.setdefault(task_root, []).append(group_or_task)
+                else:
+                    results, _task_hierarchy, show_group_table = process_group(
+                        results,
+                        group_or_task_info,
+                        group_or_task,
+                        task_hierarchy,
+                        show_group_table,
+                    )
+                    if task_root:
+                        task_hierarchy.setdefault(task_root, []).extend(
+                            task_hierarchy.get(group_or_task, [])
+                        )

                     task_list = _task_hierarchy[group_or_task]
                     metric_list = list(
                         {
                             key
                             for task in task_list
                             for key in results[task].keys()
                             if "_stderr" not in key and key not in ["alias", "samples"]
                         }
                     )
                     for metric in metric_list:
@@ -555,7 +572,7 @@ def evaluate(
                     ]

                     # compute group's pooled metric and stderr
-                    results[group][
+                    results[group_or_task][
                         metric
                     ] = lm_eval.api.metrics.aggregate_subtask_metrics(
                         metrics,
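
The pooled group score in the hunk above comes from lm_eval.api.metrics.aggregate_subtask_metrics, fed with per-subtask metrics and sample sizes. As a rough illustration of what a size-weighted aggregation of subtask scores computes (a toy function under that assumption, not the library's implementation):

def weighted_mean(metrics, sizes, weight_by_size=True):
    # metrics: per-subtask scores; sizes: per-subtask sample counts
    if not weight_by_size:
        sizes = [1] * len(sizes)
    return sum(m * s for m, s in zip(metrics, sizes)) / sum(sizes)


acc = [0.70, 0.55, 0.80]   # e.g. three subtasks of a group
n = [100, 50, 150]
print(weighted_mean(acc, n))         # 0.725 (size-weighted)
print(weighted_mean(acc, n, False))  # ~0.683 (unweighted mean)
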
@@ -564,54 +581,69 @@ def evaluate(
                     )
                     # TODO: calculate grouped metric using aggregation fn
                     if "N/A" in stderrs:
-                        results[group][stderr] = "N/A"
+                        results[group_or_task][stderr] = "N/A"
                     else:
-                        results[group][
+                        results[group_or_task][
                             stderr
                         ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
                         # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                         # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                         # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)

-                    results[group]["samples"] = sum(sizes)
+                    results[group_or_task]["samples"] = sum(sizes)

             return results, task_hierarchy, show_group_table

         results, task_hierarchy, show_group_table = process_group(results, task_dict)
+        print(task_hierarchy)
+        import sys; sys.exit()

-        results_agg = defaultdict(dict)
-        groups_agg = defaultdict(dict)
-        all_tasks_list = list(task_hierarchy.keys())
-        while True:
-            add_tasks_list = list(k for k in results_agg.keys())
-            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
-            if len(left_tasks_list) == 0:
-                break
-            _task_hierarchy = {
-                k: v["tasks"] for k, v in task_hierarchy.items() if k in left_tasks_list
-            }
-            _results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)
-            results_agg = {**results_agg, **_results_agg}
-            groups_agg = {**groups_agg, **_groups_agg}
-
-        for group_name, group_info in task_hierarchy.items():
-            task_list = group_info["tasks"]
-            if task_list:
-                num_fewshot[group_name] = num_fewshot[
-                    task_list[0]
-                ]
+        def print_table(task_dict, results, task_depth=0):
+            task_agg = defaultdict(dict)
+            for task_or_group_name, task_or_group_obj in task_dict.items():
+                tab_string = " " * task_depth + "- " if task_depth > 0 else ""
+                if isinstance(task_or_group_name, ConfigurableGroup):
+                    name = task_or_group_name.group
+                    from_configurable_group = True
+                elif isinstance(task_or_group_name, str):
+                    name = task_or_group_name
+                    from_configurable_group = False
+                task_agg[name] = results[name].copy()
+                if from_configurable_group:
+                    if task_or_group_name.group_alias is not None:
+                        alias = task_or_group_name.group_alias
+                    else:
+                        alias = name
+                else:
+                    if "alias" in task_agg[name]:
+                        alias = task_agg[name]["alias"]
+                    else:
+                        alias = name
+                task_agg[name]["alias"] = tab_string + alias
+                if "samples" in task_agg[name]:
+                    task_agg[name].pop("samples")
+                if isinstance(task_or_group_obj, dict):
+                    # TODO: validate this
+                    task_depth += 1
+                    task_agg = {
+                        **task_agg,
+                        **print_table(task_or_group_obj, results, task_depth),
+                    }
+                    task_depth -= 1
+            return task_agg
+
+        results_agg = print_table(task_dict, results)
+        import sys; sys.exit()

     results_dict = {
         "results": dict(results_agg.items()),
-        **(
-            {"groups": dict(groups_agg.items())}
-            if (bool(groups_agg) & show_group_table)
-            else {}
-        ),
+        # **(
+        #     {"groups": dict(groups_agg.items())}
+        #     if (bool(groups_agg) & show_group_table)
+        #     else {}
+        # ),
         "group_subtasks": dict(reversed(task_hierarchy.items())),
         "configs": dict(sorted(configs.items())),
         "versions": dict(sorted(versions.items())),
...
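
The new print_table helper replaces the old "while True" / prepare_print_tasks loop: it walks the possibly nested task dict depth-first and prefixes each nested entry's alias with " " * task_depth + "- ". A standalone sketch of that flattening pattern on plain dicts (names and data below are illustrative only):

def flatten_with_aliases(task_dict, results, depth=0):
    """Depth-first walk that indents nested entries with '- ' prefixes."""
    rows = {}
    for name, obj in task_dict.items():
        tab = " " * depth + "- " if depth > 0 else ""
        rows[name] = dict(results.get(name, {}), alias=tab + name)
        if isinstance(obj, dict):
            rows.update(flatten_with_aliases(obj, results, depth + 1))
    return rows


tasks = {"mmlu": {"mmlu_stem": {}, "mmlu_humanities": {}}}
scores = {
    "mmlu": {"acc": 0.61},
    "mmlu_stem": {"acc": 0.52},
    "mmlu_humanities": {"acc": 0.66},
}
for row in flatten_with_aliases(tasks, scores).values():
    print(row["alias"], row["acc"])
# mmlu 0.61
#  - mmlu_stem 0.52
#  - mmlu_humanities 0.66
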
lm_eval/evaluator_utils.py

@@ -2,11 +2,11 @@ import collections
 import math
 import pathlib
 import sys
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 from lm_eval.api import metrics
 from lm_eval.utils import eval_logger, positional_deprecated
+from lm_eval.api.task import ConfigurableTask, ConfigurableGroup


 class TaskOutput:
     """
@@ -121,7 +121,6 @@ class TaskOutput:
 def get_task_list(task_dict: dict) -> List[TaskOutput]:
     outputs = []
     for task_name, task_obj in task_dict.items():
         if isinstance(task_obj, dict):
             _outputs = get_task_list(task_obj)
             outputs.extend(_outputs)
...
lm_eval/tasks/__init__.py

@@ -5,7 +5,8 @@ from functools import partial
 from typing import Dict, List, Mapping, Optional, Union

 from lm_eval import utils
-from lm_eval.api.task import ConfigurableTask, ConfigurableGroup, GroupConfig, Task
+from lm_eval.api.task import ConfigurableGroup, ConfigurableTask, GroupConfig, Task


 GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
...
lm_eval/tasks/mmlu/default/_mmlu.yaml

@@ -6,3 +6,4 @@ task:
   - mmlu_humanities
 aggregate_metric: True
 weight_by_size: True
+version: 1
lm_eval/tasks/mmlu/default/_mmlu_humanities.yaml

@@ -16,3 +16,4 @@ task:
   # - mmlu_world_religions
 aggregate_metric: True
 weight_by_size: True
+version: 1
lm_eval/tasks/mmlu/default/_mmlu_other.yaml

@@ -16,3 +16,4 @@ task:
   # - mmlu_virology
 aggregate_metric: True
 weight_by_size: True
+version: 1
lm_eval/tasks/mmlu/default/_mmlu_social_sciences.yaml

@@ -15,3 +15,4 @@ task:
   # - mmlu_us_foreign_policy
 aggregate_metric: True
 weight_by_size: True
+version: 1
lm_eval/tasks/mmlu/default/_mmlu_stem.yaml

@@ -22,3 +22,4 @@ task:
   # - mmlu_machine_learning
 aggregate_metric: True
 weight_by_size: True
+version: 1
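
Each MMLU group file gains a "version: 1" line next to the existing aggregate_metric and weight_by_size keys. A quick way to sanity-check which group-level keys a file now carries (the path and the plain-YAML load are assumptions; files that use custom tags would need the harness's own loader):

import yaml  # pip install pyyaml

with open("lm_eval/tasks/mmlu/default/_mmlu.yaml") as f:
    group_cfg = yaml.safe_load(f)

# Expect the group-level keys touched by this commit.
for key in ("aggregate_metric", "weight_by_size", "version"):
    print(key, "=", group_cfg.get(key))
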
lm_eval/utils.py

@@ -242,8 +242,11 @@ def make_table(result_dict, column: str = "results"):
     values = []

     for k, dic in result_dict[column].items():
-        version = result_dict["versions"].get(k, "N/A")
-        n = str(result_dict["n-shot"][k])
+        version = result_dict["versions"].get(k, " N/A")
+        if k in result_dict["n-shot"]:
+            n = str(result_dict["n-shot"][k])
+        else:
+            n = " "
         if "alias" in dic:
             k = dic.pop("alias")
...
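
The make_table change guards the n-shot lookup, so a row whose key has no entry in result_dict["n-shot"] (for example, a group row) gets a blank cell instead of raising a KeyError. Roughly, with toy data rather than the harness's real result structure:

result_dict = {
    "results": {"mmlu": {"acc": 0.61}, "mmlu_stem": {"acc": 0.52}},
    "versions": {"mmlu": 1},
    "n-shot": {"mmlu_stem": 5},  # no entry for the "mmlu" group row
}

for k in result_dict["results"]:
    version = result_dict["versions"].get(k, " N/A")
    n = str(result_dict["n-shot"][k]) if k in result_dict["n-shot"] else " "
    print(f"{k}: version={version!r}, n-shot={n!r}")
# mmlu: version=1, n-shot=' '
# mmlu_stem: version=' N/A', n-shot='5'
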