Commit 88fea8ad authored by lintangsutawika
Browse files

add get_subtask_list function to get proper subtask list

parent c90655d5
...@@ -17,6 +17,7 @@ from lm_eval.caching.cache import delete_cache ...@@ -17,6 +17,7 @@ from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import ( from lm_eval.evaluator_utils import (
consolidate_results, consolidate_results,
get_sample_size, get_sample_size,
get_subtask_list,
get_task_list, get_task_list,
prepare_print_tasks, prepare_print_tasks,
print_writeout, print_writeout,
...@@ -531,14 +532,14 @@ def evaluate( ...@@ -531,14 +532,14 @@ def evaluate(
versions, versions,
task_dict, task_dict,
task_root=None, task_root=None,
task_hierarchy=None,
show_group_table=False, show_group_table=False,
task_aggregation_list=None,
): ):
if task_root is None: if task_root is None:
task_root = {} task_root = {}
if task_hierarchy is None: if task_aggregation_list is None:
task_hierarchy = {} task_aggregation_list = {}
for group_or_task, group_or_task_info in task_dict.items(): for group_or_task, group_or_task_info in task_dict.items():
# Convert to string # Convert to string
...@@ -550,26 +551,26 @@ def evaluate( ...@@ -550,26 +551,26 @@ def evaluate(
if isinstance(group_or_task_info, ConfigurableTask): if isinstance(group_or_task_info, ConfigurableTask):
if task_root: if task_root:
task_hierarchy.setdefault(task_root, []).append( task_aggregation_list.setdefault(task_root, []).append(
group_or_task_info.task_id group_or_task_info.task_id
) )
else: else:
( (
results, results,
versions, versions,
_task_hierarchy,
show_group_table, show_group_table,
_task_aggregation_list,
) = process_group( ) = process_group(
results, results,
versions, versions,
group_or_task_info, group_or_task_info,
group_or_task, group_or_task,
task_hierarchy,
show_group_table, show_group_table,
task_aggregation_list,
) )
if task_root: if task_root:
task_hierarchy.setdefault(task_root, []).extend( task_aggregation_list.setdefault(task_root, []).extend(
task_hierarchy.get(group_or_task, []) task_aggregation_list.get(group_or_task, [])
) )
if (group_config is None) or ( if (group_config is None) or (
...@@ -582,14 +583,14 @@ def evaluate( ...@@ -582,14 +583,14 @@ def evaluate(
show_group_table | group_config["aggregate_metric"] show_group_table | group_config["aggregate_metric"]
) )
task_list = _task_hierarchy[group_or_task] task_list = _task_aggregation_list[group_or_task]
metric_list = list( metric_list = list(
{ {
key key
for task in task_list for task in task_list
for key in results[task].keys() for key in results[task].keys()
if "_stderr" not in key if "_stderr" not in key
and key not in ["alias", "samples"] and key not in ["task", "alias", "samples"]
} }
) )
for metric in metric_list: for metric in metric_list:
...@@ -635,13 +636,15 @@ def evaluate( ...@@ -635,13 +636,15 @@ def evaluate(
results[group_or_task]["samples"] = sum(sizes) results[group_or_task]["samples"] = sum(sizes)
versions[group_or_task] = group_config["version"] versions[group_or_task] = group_config["version"]
return results, versions, task_hierarchy, show_group_table return results, versions, show_group_table, task_aggregation_list
results, versions, task_hierarchy, show_group_table = process_group( results, versions, show_group_table, *_ = process_group(
results, versions, task_dict results, versions, task_dict
) )
results_agg, group_agg = prepare_print_tasks(task_dict, results) results_agg, group_agg = prepare_print_tasks(task_dict, results)
subtask_list = get_subtask_list(task_dict)
results_dict = { results_dict = {
"results": dict(results_agg.items()), "results": dict(results_agg.items()),
**( **(
...@@ -649,7 +652,7 @@ def evaluate( ...@@ -649,7 +652,7 @@ def evaluate(
if (bool(group_agg) & show_group_table) if (bool(group_agg) & show_group_table)
else {} else {}
), ),
"group_subtasks": dict(reversed(task_hierarchy.items())), "group_subtasks": dict(reversed(subtask_list.items())),
"configs": dict(sorted(configs.items())), "configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())), "versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())), "n-shot": dict(sorted(num_fewshot.items())),
......
...@@ -137,6 +137,50 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]: ...@@ -137,6 +137,50 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
return outputs return outputs
def get_subtask_list(task_dict, task_root=None, depth=0):
    """Collect, for each group in a nested task dictionary, its member names.

    The dictionary is walked recursively. While recursing, entries are keyed
    by ``(name, depth)`` tuples so a parent can pick out the groups that sit
    exactly one level below it; the outermost call (``depth == 0``) strips
    the depth component before returning.

    Args:
        task_dict: Mapping whose keys are ``ConfigurableGroup`` objects or
            plain names, and whose values are either nested dicts of the
            same shape or leaf objects.
        task_root: Name of the group that owns ``task_dict``. ``None`` for
            the outermost call.
        depth: Nesting level of ``task_dict``; ``0`` for the outermost call.

    Returns:
        For ``depth == 0``, a dict mapping each name to a list of member
        names. For deeper (recursive) calls, the same data keyed by
        ``(name, depth)`` tuples.
    """
    subtask_list = {}
    for group_obj, task_obj in task_dict.items():
        if isinstance(group_obj, ConfigurableGroup):
            group_name = group_obj.group_name
        else:
            group_name = group_obj

        if isinstance(task_obj, dict):
            nested = get_subtask_list(
                task_obj, task_root=group_name, depth=depth + 1
            )
            if task_root:
                # Only groups exactly one level down are members of
                # ``task_root``; deeper ones belong to intermediate groups.
                subtask_list.setdefault((task_root, depth), []).extend(
                    child
                    for (child, child_depth) in nested
                    if child_depth - 1 == depth
                )
            subtask_list.update(nested)
            continue

        if isinstance(task_obj, ConfigurableGroup):
            group_or_task_name = task_obj.group_name
        elif isinstance(task_obj, ConfigurableTask):
            group_or_task_name = task_obj.task_name
        else:
            # Previously this case left the name unbound (UnboundLocalError)
            # or silently reused the name from the preceding iteration.
            # Use the leaf's own ``task_name`` when it has one, otherwise
            # the name derived from its dictionary key.
            group_or_task_name = getattr(task_obj, "task_name", group_name)

        if task_root is None:
            # A leaf at the top level is listed with no members.
            subtask_list.setdefault((group_or_task_name, depth), [])
        else:
            subtask_list.setdefault((task_root, depth), []).append(
                group_or_task_name
            )

    if depth == 0:
        # Drop the depth component; it is only needed while recursing.
        return {name: members for (name, _level), members in subtask_list.items()}

    return subtask_list
def print_writeout(task) -> None: def print_writeout(task) -> None:
for inst in task.instances: for inst in task.instances:
# print the prompt for the first few documents # print the prompt for the first few documents
...@@ -181,16 +225,20 @@ def prepare_print_tasks( ...@@ -181,16 +225,20 @@ def prepare_print_tasks(
for task_or_group_name, task_or_group_obj in task_dict.items(): for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else "" tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup): if isinstance(task_or_group_name, ConfigurableGroup):
# name = task_or_group_name.group string_name = task_or_group_name.group_name
name = task_or_group_name.task_id name = task_or_group_name.task_id
from_configurable_group = True from_configurable_group = True
elif isinstance(task_or_group_name, str): elif isinstance(task_or_group_name, str):
name = task_or_group_name name = task_or_group_name
if isinstance(task_or_group_obj, ConfigurableTask): if isinstance(task_or_group_obj, ConfigurableTask):
string_name = task_or_group_obj.task_name
name = task_or_group_obj.task_id name = task_or_group_obj.task_id
from_configurable_group = False from_configurable_group = False
task_agg[name] = results[name].copy() task_agg[name] = {
**{"task_or_group_name": string_name},
**results[name].copy(),
}
if from_configurable_group: if from_configurable_group:
if task_or_group_name.group_alias is not None: if task_or_group_name.group_alias is not None:
alias = task_or_group_name.group_alias alias = task_or_group_name.group_alias
...@@ -262,6 +310,7 @@ def consolidate_results( ...@@ -262,6 +310,7 @@ def consolidate_results(
# Tracks each task's version. # Tracks each task's version.
versions = collections.defaultdict(dict) versions = collections.defaultdict(dict)
for task_output in eval_tasks: for task_output in eval_tasks:
# results[task_output.task_id]["task"] = task_output.task_name
if "task_alias" in (task_config := task_output.task_config): if "task_alias" in (task_config := task_output.task_config):
results[task_output.task_id]["alias"] = task_config["task_alias"] results[task_output.task_id]["alias"] = task_config["task_alias"]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment