Commit 110e5a28 authored by lintangsutawika's avatar lintangsutawika
Browse files

move prepare_print_tasks to evaluator_utils

parent 75dfac43
......@@ -17,6 +17,7 @@ from lm_eval.evaluator_utils import (
consolidate_results,
get_sample_size,
get_task_list,
prepare_print_tasks,
print_writeout,
run_task_tests,
)
......@@ -599,71 +600,7 @@ def evaluate(
results, task_dict
)
def print_table(task_dict, results, task_depth=0, group_depth=0):
"""
@param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
value is a list of task names.
@param results: Dictionary containing the results of each task. Each key is a
group name and its value is a dictionary of task results.
@param task_depth: The indentation level for printing the task
hierarchy. Default is 0.
@return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
aggregated results for each task, and groups_agg contains aggregated results for each group.
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
"""
task_agg = defaultdict(dict)
group_agg = defaultdict(dict)
for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup):
name = task_or_group_name.group
from_configurable_group = True
elif isinstance(task_or_group_name, str):
name = task_or_group_name
from_configurable_group = False
task_agg[name] = results[name].copy()
if from_configurable_group:
if task_or_group_name.group_alias is not None:
alias = task_or_group_name.group_alias
else:
alias = name
else:
if "alias" in task_agg[name]:
alias = task_agg[name]["alias"]
else:
alias = name
task_agg[name]["alias"] = tab_string + alias
if "samples" in task_agg[name]:
task_agg[name].pop("samples")
if from_configurable_group and (" " not in results[name]):
group_tab_string = (
" " * group_depth + "- " if group_depth > 0 else ""
)
group_agg[name] = results[name].copy()
group_agg[name]["alias"] = group_tab_string + alias
if "samples" in group_agg[name]:
group_agg[name].pop("samples")
if isinstance(task_or_group_obj, dict):
task_depth += 1
group_depth += 1
_task_agg, _group_agg = print_table(
task_or_group_obj, results, task_depth, group_depth
)
task_agg = {
**task_agg,
**_task_agg,
}
group_agg = {**group_agg, **_group_agg}
task_depth -= 1
group_depth -= 1
return task_agg, group_agg
results_agg, group_agg = print_table(task_dict, results)
results_agg, group_agg = prepare_print_tasks(task_dict, results)
results_dict = {
"results": dict(results_agg.items()),
**(
......
......@@ -6,7 +6,7 @@ from typing import List, Optional, Tuple, Union
from lm_eval.api import metrics
from lm_eval.utils import eval_logger, positional_deprecated
from lm_eval.tasks import ConfigurableGroup
class TaskOutput:
"""
......@@ -151,76 +151,75 @@ def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
def prepare_print_tasks(
task_hierarchy: dict,
task_dict: dict,
results: dict,
tab=0,
group_tab=0,
task_depth=0,
group_depth=0,
) -> Tuple[dict, dict]:
"""
@param task_hierarchy: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
@param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
value is a list of task names.
@param results: Dictionary containing the results of each task. Each key is a
group name and its value is a dictionary of task results.
@param tab: The indentation level for printing the task
@param task_depth: The indentation level for printing the task
hierarchy. Default is 0.
@param group_depth: The indentation level for printing the group
hierarchy. Default is 0.
@return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
aggregated results for each task, and groups_agg contains aggregated results for each group.
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
"""
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
(group_name, task_list), *_ = task_hierarchy.items()
task_list = sorted(task_list)
results_agg[group_name] = results[group_name].copy()
# results_agg[group_name]["tab"] = tab
if "samples" in results_agg[group_name]:
results_agg[group_name].pop("samples")
tab_string = " " * tab + "- " if tab > 0 else ""
if "alias" in results_agg[group_name]:
results_agg[group_name]["alias"] = tab_string + results_agg[group_name]["alias"]
else:
results_agg[group_name]["alias"] = tab_string + group_name
if len(task_list) > 0:
if " " not in results[group_name]:
group_tab_string = " " * group_tab + "- " if group_tab > 0 else ""
groups_agg[group_name] = results[group_name].copy()
group_tab += 1
if "samples" in groups_agg[group_name]:
groups_agg[group_name].pop("samples")
if "alias" in groups_agg[group_name]:
groups_agg[group_name]["alias"] = (
group_tab_string + groups_agg[group_name]["alias"]
)
task_agg = collections.defaultdict(dict)
group_agg = collections.defaultdict(dict)
for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup):
name = task_or_group_name.group
from_configurable_group = True
elif isinstance(task_or_group_name, str):
name = task_or_group_name
from_configurable_group = False
task_agg[name] = results[name].copy()
if from_configurable_group:
if task_or_group_name.group_alias is not None:
alias = task_or_group_name.group_alias
else:
groups_agg[group_name]["alias"] = group_tab_string + group_name
for task_name in task_list:
if task_name in task_hierarchy:
_task_hierarchy = {
**{task_name: task_hierarchy[task_name]},
**task_hierarchy,
}
alias = name
else:
if "alias" in task_agg[name]:
alias = task_agg[name]["alias"]
else:
_task_hierarchy = {
**{task_name: []},
**task_hierarchy,
}
alias = name
_results_agg, _groups_agg = prepare_print_tasks(
_task_hierarchy, results, tab + 1, group_tab
)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
task_agg[name]["alias"] = tab_string + alias
if "samples" in task_agg[name]:
task_agg[name].pop("samples")
return results_agg, groups_agg
if from_configurable_group and (" " not in results[name]):
group_tab_string = (
" " * group_depth + "- " if group_depth > 0 else ""
)
group_agg[name] = results[name].copy()
group_agg[name]["alias"] = group_tab_string + alias
if "samples" in group_agg[name]:
group_agg[name].pop("samples")
if isinstance(task_or_group_obj, dict):
task_depth += 1
group_depth += 1
_task_agg, _group_agg = prepare_print_tasks(
task_or_group_obj, results, task_depth, group_depth
)
task_agg = {
**task_agg,
**_task_agg,
}
group_agg = {**group_agg, **_group_agg}
task_depth -= 1
group_depth -= 1
return task_agg, group_agg
def consolidate_results(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment