Commit c6839d72 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

fixed sorting for multi-level printing

parent e5811879
......@@ -227,14 +227,32 @@ def prepare_print_tasks(
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
"""
def _sort_task_dict(task_dict):
"""
Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
Required so that we end up sorting within each sub-header correctly.
"""
return dict(
sorted(
task_dict.items(),
key=lambda item: item[0].group_name
if isinstance(item[0], ConfigurableGroup)
else item[0],
)
)
task_agg = collections.defaultdict(dict)
group_agg = collections.defaultdict(dict)
task_dict = _sort_task_dict(task_dict)
for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup):
# string_name = task_or_group_name.group_name
name = task_or_group_name.group_name
from_configurable_group = True
task_or_group_obj = _sort_task_dict(task_or_group_obj)
elif isinstance(task_or_group_name, str):
name = task_or_group_name
if isinstance(task_or_group_obj, Task):
......
......@@ -156,7 +156,7 @@ class TaskManager:
if isinstance(config["class"], ConfigurableTask)
else config["class"]()
)
# very scuffed: set task name here TODO: fixme?
# very scuffed: set task name here. TODO: fixme?
task_object.config.task = config["task"]
else:
task_object = ConfigurableTask(config=config)
......
......@@ -289,7 +289,9 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
keys = result_dict[column].keys()
if sort_results:
# sort entries alphabetically
# sort entries alphabetically by task or group name.
# NOTE: we default here to false, because order matters for multi-level table printing a la mmlu.
# sorting here would mess that up
keys = sorted(keys)
for k in keys:
dic = result_dict[column][k]
......
......@@ -39,7 +39,6 @@ dependencies = [
"dill",
"word2number",
"more_itertools",
"shortuuid",
]
[tool.setuptools.packages.find]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment