Commit 88fea8ad authored by lintangsutawika
Browse files

add get_subtask_list function to get proper subtask list

parent c90655d5
......@@ -17,6 +17,7 @@ from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
consolidate_results,
get_sample_size,
get_subtask_list,
get_task_list,
prepare_print_tasks,
print_writeout,
......@@ -531,14 +532,14 @@ def evaluate(
versions,
task_dict,
task_root=None,
task_hierarchy=None,
show_group_table=False,
task_aggregation_list=None,
):
if task_root is None:
task_root = {}
if task_hierarchy is None:
task_hierarchy = {}
if task_aggregation_list is None:
task_aggregation_list = {}
for group_or_task, group_or_task_info in task_dict.items():
# Convert to string
......@@ -550,26 +551,26 @@ def evaluate(
if isinstance(group_or_task_info, ConfigurableTask):
if task_root:
task_hierarchy.setdefault(task_root, []).append(
task_aggregation_list.setdefault(task_root, []).append(
group_or_task_info.task_id
)
else:
(
results,
versions,
_task_hierarchy,
show_group_table,
_task_aggregation_list,
) = process_group(
results,
versions,
group_or_task_info,
group_or_task,
task_hierarchy,
show_group_table,
task_aggregation_list,
)
if task_root:
task_hierarchy.setdefault(task_root, []).extend(
task_hierarchy.get(group_or_task, [])
task_aggregation_list.setdefault(task_root, []).extend(
task_aggregation_list.get(group_or_task, [])
)
if (group_config is None) or (
......@@ -582,14 +583,14 @@ def evaluate(
show_group_table | group_config["aggregate_metric"]
)
task_list = _task_hierarchy[group_or_task]
task_list = _task_aggregation_list[group_or_task]
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key
and key not in ["alias", "samples"]
and key not in ["task", "alias", "samples"]
}
)
for metric in metric_list:
......@@ -635,13 +636,15 @@ def evaluate(
results[group_or_task]["samples"] = sum(sizes)
versions[group_or_task] = group_config["version"]
return results, versions, task_hierarchy, show_group_table
return results, versions, show_group_table, task_aggregation_list
results, versions, task_hierarchy, show_group_table = process_group(
results, versions, show_group_table, *_ = process_group(
results, versions, task_dict
)
results_agg, group_agg = prepare_print_tasks(task_dict, results)
subtask_list = get_subtask_list(task_dict)
results_dict = {
"results": dict(results_agg.items()),
**(
......@@ -649,7 +652,7 @@ def evaluate(
if (bool(group_agg) & show_group_table)
else {}
),
"group_subtasks": dict(reversed(task_hierarchy.items())),
"group_subtasks": dict(reversed(subtask_list.items())),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())),
......
......@@ -137,6 +137,50 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
return outputs
def get_subtask_list(task_dict, task_root=None, depth=0):
    """Map each group found in ``task_dict`` to the names of its direct children.

    ``task_dict`` is walked recursively. A key may be a ``ConfigurableGroup``
    (its ``group_name`` is used) or any other object (used as-is). A value
    that is a ``dict`` is treated as a nested group and recursed into; any
    other value is treated as a leaf.

    Args:
        task_dict: Possibly nested mapping of group/task keys to either a
            nested mapping or a leaf object.
        task_root: Name of the group that owns ``task_dict``. ``None`` for
            the outermost call.
        depth: Nesting level of ``task_dict``; ``0`` for the outermost call.

    Returns:
        For ``depth == 0``: a dict mapping each name to the list of names of
        its direct children (leaves at the top level map to ``[]``).
        For ``depth > 0``: the intermediate form used by the recursion, keyed
        by ``(name, depth)`` tuples instead of plain names.
    """
    subtask_list = {}
    for group_obj, task_obj in task_dict.items():
        if isinstance(group_obj, ConfigurableGroup):
            group_name = group_obj.group_name
        else:
            group_name = group_obj

        if isinstance(task_obj, dict):
            _subtask_list = get_subtask_list(
                task_obj, task_root=group_name, depth=depth + 1
            )
            if task_root:
                # Only entries exactly one level below are direct children
                # of ``task_root``; deeper entries belong to sub-groups.
                subtask_list.setdefault((task_root, depth), []).extend(
                    [
                        _task
                        for (_task, _depth) in _subtask_list
                        if (_depth - 1) == depth
                    ]
                )
            subtask_list = {**subtask_list, **_subtask_list}
        else:
            if isinstance(task_obj, ConfigurableGroup):
                group_or_task_name = task_obj.group_name
            elif isinstance(task_obj, ConfigurableTask):
                group_or_task_name = task_obj.task_name
            else:
                # Unknown leaf type: use the key it is registered under
                # instead of leaving the name unbound (UnboundLocalError) or
                # silently reusing the name from a previous iteration.
                group_or_task_name = group_name

            if task_root is None:
                subtask_list.setdefault((group_or_task_name, depth), [])
            else:
                subtask_list.setdefault((task_root, depth), []).append(
                    group_or_task_name
                )

    if depth == 0:
        # Outermost call: drop the depth component of every key. The
        # comprehension keeps its loop variables local, so the ``depth``
        # parameter is no longer rebound while flattening.
        subtask_list = {
            group_name: task_list
            for (group_name, _depth), task_list in subtask_list.items()
        }

    return subtask_list
def print_writeout(task) -> None:
for inst in task.instances:
# print the prompt for the first few documents
......@@ -181,16 +225,20 @@ def prepare_print_tasks(
for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup):
# name = task_or_group_name.group
string_name = task_or_group_name.group_name
name = task_or_group_name.task_id
from_configurable_group = True
elif isinstance(task_or_group_name, str):
name = task_or_group_name
if isinstance(task_or_group_obj, ConfigurableTask):
string_name = task_or_group_obj.task_name
name = task_or_group_obj.task_id
from_configurable_group = False
task_agg[name] = results[name].copy()
task_agg[name] = {
**{"task_or_group_name": string_name},
**results[name].copy(),
}
if from_configurable_group:
if task_or_group_name.group_alias is not None:
alias = task_or_group_name.group_alias
......@@ -262,6 +310,7 @@ def consolidate_results(
# Tracks each task's version.
versions = collections.defaultdict(dict)
for task_output in eval_tasks:
# results[task_output.task_id]["task"] = task_output.task_name
if "task_alias" in (task_config := task_output.task_config):
results[task_output.task_id]["alias"] = task_config["task_alias"]
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment