Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
88fea8ad
Commit
88fea8ad
authored
May 16, 2024
by
lintangsutawika
Browse files
add get_subtask_list function to get proper subtask list
parent
c90655d5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
15 deletions
+67
-15
lm_eval/evaluator.py
lm_eval/evaluator.py
+16
-13
lm_eval/evaluator_utils.py
lm_eval/evaluator_utils.py
+51
-2
No files found.
lm_eval/evaluator.py
View file @
88fea8ad
...
...
@@ -17,6 +17,7 @@ from lm_eval.caching.cache import delete_cache
from
lm_eval.evaluator_utils
import
(
consolidate_results
,
get_sample_size
,
get_subtask_list
,
get_task_list
,
prepare_print_tasks
,
print_writeout
,
...
...
@@ -531,14 +532,14 @@ def evaluate(
versions
,
task_dict
,
task_root
=
None
,
task_hierarchy
=
None
,
show_group_table
=
False
,
task_aggregation_list
=
None
,
):
if
task_root
is
None
:
task_root
=
{}
if
task_
hierarchy
is
None
:
task_
hierarchy
=
{}
if
task_
aggregation_list
is
None
:
task_
aggregation_list
=
{}
for
group_or_task
,
group_or_task_info
in
task_dict
.
items
():
# Convert to string
...
...
@@ -550,26 +551,26 @@ def evaluate(
if
isinstance
(
group_or_task_info
,
ConfigurableTask
):
if
task_root
:
task_
hierarchy
.
setdefault
(
task_root
,
[]).
append
(
task_
aggregation_list
.
setdefault
(
task_root
,
[]).
append
(
group_or_task_info
.
task_id
)
else
:
(
results
,
versions
,
_task_hierarchy
,
show_group_table
,
_task_aggregation_list
,
)
=
process_group
(
results
,
versions
,
group_or_task_info
,
group_or_task
,
task_hierarchy
,
show_group_table
,
task_aggregation_list
,
)
if
task_root
:
task_
hierarchy
.
setdefault
(
task_root
,
[]).
extend
(
task_
hierarchy
.
get
(
group_or_task
,
[])
task_
aggregation_list
.
setdefault
(
task_root
,
[]).
extend
(
task_
aggregation_list
.
get
(
group_or_task
,
[])
)
if
(
group_config
is
None
)
or
(
...
...
@@ -582,14 +583,14 @@ def evaluate(
show_group_table
|
group_config
[
"aggregate_metric"
]
)
task_list
=
_task_
hierarchy
[
group_or_task
]
task_list
=
_task_
aggregation_list
[
group_or_task
]
metric_list
=
list
(
{
key
for
task
in
task_list
for
key
in
results
[
task
].
keys
()
if
"_stderr"
not
in
key
and
key
not
in
[
"alias"
,
"samples"
]
and
key
not
in
[
"task"
,
"alias"
,
"samples"
]
}
)
for
metric
in
metric_list
:
...
...
@@ -635,13 +636,15 @@ def evaluate(
results
[
group_or_task
][
"samples"
]
=
sum
(
sizes
)
versions
[
group_or_task
]
=
group_config
[
"version"
]
return
results
,
versions
,
task_hierarchy
,
show_group_table
return
results
,
versions
,
show_group_table
,
task_aggregation_list
results
,
versions
,
task_hierarchy
,
show_group_table
=
process_group
(
results
,
versions
,
show_group_table
,
*
_
=
process_group
(
results
,
versions
,
task_dict
)
results_agg
,
group_agg
=
prepare_print_tasks
(
task_dict
,
results
)
subtask_list
=
get_subtask_list
(
task_dict
)
results_dict
=
{
"results"
:
dict
(
results_agg
.
items
()),
**
(
...
...
@@ -649,7 +652,7 @@ def evaluate(
if
(
bool
(
group_agg
)
&
show_group_table
)
else
{}
),
"group_subtasks"
:
dict
(
reversed
(
task_
hierarchy
.
items
())),
"group_subtasks"
:
dict
(
reversed
(
sub
task_
list
.
items
())),
"configs"
:
dict
(
sorted
(
configs
.
items
())),
"versions"
:
dict
(
sorted
(
versions
.
items
())),
"n-shot"
:
dict
(
sorted
(
num_fewshot
.
items
())),
...
...
lm_eval/evaluator_utils.py
View file @
88fea8ad
...
...
@@ -137,6 +137,50 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
return
outputs
def get_subtask_list(task_dict, task_root=None, depth=0):
    """Map every group in a nested task dictionary to its direct children.

    The function walks ``task_dict`` recursively. While recursing, entries are
    keyed by ``(name, depth)`` tuples so that each level can pick out the
    groups that sit exactly one level below it; the outermost call
    (``depth == 0``) strips the depth and returns plain names as keys.

    Args:
        task_dict: Mapping whose keys are either ``ConfigurableGroup`` objects
            or plain names, and whose values are either nested dicts of the
            same shape (a group's members) or leaf objects
            (``ConfigurableGroup`` / ``ConfigurableTask``).
        task_root: Name of the group that owns ``task_dict``; ``None`` for the
            outermost call.
        depth: Nesting level of ``task_dict``; ``0`` for the outermost call.

    Returns:
        For ``depth == 0``: a dict ``{name: [child names]}``. Leaves found
        directly at the top level (no ``task_root``) map to an empty list.
        For ``depth > 0``: the same data keyed by ``(name, depth)`` tuples,
        intended for consumption by the parent call.
    """
    subtask_list = {}
    for group_obj, task_obj in task_dict.items():
        if isinstance(group_obj, ConfigurableGroup):
            group_name = group_obj.group_name
        else:
            group_name = group_obj

        if isinstance(task_obj, dict):
            # Nested group: collect its subtree first, one level deeper.
            nested = get_subtask_list(
                task_obj, task_root=group_name, depth=depth + 1
            )
            if task_root:
                # Register only the groups exactly one level below this one
                # as children of the current root.
                subtask_list.setdefault((task_root, depth), []).extend(
                    child_name
                    for (child_name, child_depth) in nested
                    if (child_depth - 1) == depth
                )
            # Same merge semantics as {**subtask_list, **nested}.
            subtask_list.update(nested)
            continue

        if isinstance(task_obj, ConfigurableGroup):
            group_or_task_name = task_obj.group_name
        elif isinstance(task_obj, ConfigurableTask):
            group_or_task_name = task_obj.task_name
        else:
            # Previously this name was left unbound here (UnboundLocalError)
            # or silently kept the value from the prior iteration. Fall back
            # to the dictionary key's name instead.
            # NOTE(review): assumes the key identifies the leaf -- confirm.
            group_or_task_name = group_name

        if task_root is None:
            subtask_list.setdefault((group_or_task_name, depth), [])
        else:
            subtask_list.setdefault((task_root, depth), []).append(
                group_or_task_name
            )

    if depth == 0:
        # Outermost call: drop the depth component of each key. A name that
        # appears at several depths keeps the entry encountered last.
        subtask_list = {
            name: children for (name, _level), children in subtask_list.items()
        }

    return subtask_list
def
print_writeout
(
task
)
->
None
:
for
inst
in
task
.
instances
:
# print the prompt for the first few documents
...
...
@@ -181,16 +225,20 @@ def prepare_print_tasks(
for
task_or_group_name
,
task_or_group_obj
in
task_dict
.
items
():
tab_string
=
" "
*
task_depth
+
"- "
if
task_depth
>
0
else
""
if
isinstance
(
task_or_group_name
,
ConfigurableGroup
):
#
name = task_or_group_name.group
string_
name
=
task_or_group_name
.
group
_name
name
=
task_or_group_name
.
task_id
from_configurable_group
=
True
elif
isinstance
(
task_or_group_name
,
str
):
name
=
task_or_group_name
if
isinstance
(
task_or_group_obj
,
ConfigurableTask
):
string_name
=
task_or_group_obj
.
task_name
name
=
task_or_group_obj
.
task_id
from_configurable_group
=
False
task_agg
[
name
]
=
results
[
name
].
copy
()
task_agg
[
name
]
=
{
**
{
"task_or_group_name"
:
string_name
},
**
results
[
name
].
copy
(),
}
if
from_configurable_group
:
if
task_or_group_name
.
group_alias
is
not
None
:
alias
=
task_or_group_name
.
group_alias
...
...
@@ -262,6 +310,7 @@ def consolidate_results(
# Tracks each task's version.
versions
=
collections
.
defaultdict
(
dict
)
for
task_output
in
eval_tasks
:
# results[task_output.task_id]["task"] = task_output.task_name
if
"task_alias"
in
(
task_config
:
=
task_output
.
task_config
):
results
[
task_output
.
task_id
][
"alias"
]
=
task_config
[
"task_alias"
]
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment