Commit c6839d72 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

fixed sorting for multi-level printing

parent e5811879
...@@ -227,14 +227,32 @@ def prepare_print_tasks( ...@@ -227,14 +227,32 @@ def prepare_print_tasks(
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
""" """
def _sort_task_dict(task_dict):
"""
Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
Required so that we end up sorting within each sub-header correctly.
"""
return dict(
sorted(
task_dict.items(),
key=lambda item: item[0].group_name
if isinstance(item[0], ConfigurableGroup)
else item[0],
)
)
task_agg = collections.defaultdict(dict) task_agg = collections.defaultdict(dict)
group_agg = collections.defaultdict(dict) group_agg = collections.defaultdict(dict)
task_dict = _sort_task_dict(task_dict)
for task_or_group_name, task_or_group_obj in task_dict.items(): for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else "" tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup): if isinstance(task_or_group_name, ConfigurableGroup):
# string_name = task_or_group_name.group_name # string_name = task_or_group_name.group_name
name = task_or_group_name.group_name name = task_or_group_name.group_name
from_configurable_group = True from_configurable_group = True
task_or_group_obj = _sort_task_dict(task_or_group_obj)
elif isinstance(task_or_group_name, str): elif isinstance(task_or_group_name, str):
name = task_or_group_name name = task_or_group_name
if isinstance(task_or_group_obj, Task): if isinstance(task_or_group_obj, Task):
......
...@@ -156,7 +156,7 @@ class TaskManager: ...@@ -156,7 +156,7 @@ class TaskManager:
if isinstance(config["class"], ConfigurableTask) if isinstance(config["class"], ConfigurableTask)
else config["class"]() else config["class"]()
) )
# very scuffed: set task name here TODO: fixme? # very scuffed: set task name here. TODO: fixme?
task_object.config.task = config["task"] task_object.config.task = config["task"]
else: else:
task_object = ConfigurableTask(config=config) task_object = ConfigurableTask(config=config)
......
...@@ -289,7 +289,9 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False) ...@@ -289,7 +289,9 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
keys = result_dict[column].keys() keys = result_dict[column].keys()
if sort_results: if sort_results:
# sort entries alphabetically # sort entries alphabetically by task or group name.
# NOTE: we default here to false, because order matters for multi-level table printing a la mmlu.
# sorting here would mess that up
keys = sorted(keys) keys = sorted(keys)
for k in keys: for k in keys:
dic = result_dict[column][k] dic = result_dict[column][k]
......
...@@ -39,7 +39,6 @@ dependencies = [ ...@@ -39,7 +39,6 @@ dependencies = [
"dill", "dill",
"word2number", "word2number",
"more_itertools", "more_itertools",
"shortuuid",
] ]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment