Commit 110e5a28 authored by lintangsutawika's avatar lintangsutawika
Browse files

move prepare_print_tasks to evaluator_utils

parent 75dfac43
...@@ -17,6 +17,7 @@ from lm_eval.evaluator_utils import ( ...@@ -17,6 +17,7 @@ from lm_eval.evaluator_utils import (
consolidate_results, consolidate_results,
get_sample_size, get_sample_size,
get_task_list, get_task_list,
prepare_print_tasks,
print_writeout, print_writeout,
run_task_tests, run_task_tests,
) )
...@@ -599,71 +600,7 @@ def evaluate( ...@@ -599,71 +600,7 @@ def evaluate(
results, task_dict results, task_dict
) )
def print_table(task_dict, results, task_depth=0, group_depth=0): results_agg, group_agg = prepare_print_tasks(task_dict, results)
"""
@param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
value is a list of task names.
@param results: Dictionary containing the results of each task. Each key is a
group name and its value is a dictionary of task results.
@param task_depth: The indentation level for printing the task
hierarchy. Default is 0.
@return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
aggregated results for each task, and groups_agg contains aggregated results for each group.
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
"""
task_agg = defaultdict(dict)
group_agg = defaultdict(dict)
for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup):
name = task_or_group_name.group
from_configurable_group = True
elif isinstance(task_or_group_name, str):
name = task_or_group_name
from_configurable_group = False
task_agg[name] = results[name].copy()
if from_configurable_group:
if task_or_group_name.group_alias is not None:
alias = task_or_group_name.group_alias
else:
alias = name
else:
if "alias" in task_agg[name]:
alias = task_agg[name]["alias"]
else:
alias = name
task_agg[name]["alias"] = tab_string + alias
if "samples" in task_agg[name]:
task_agg[name].pop("samples")
if from_configurable_group and (" " not in results[name]):
group_tab_string = (
" " * group_depth + "- " if group_depth > 0 else ""
)
group_agg[name] = results[name].copy()
group_agg[name]["alias"] = group_tab_string + alias
if "samples" in group_agg[name]:
group_agg[name].pop("samples")
if isinstance(task_or_group_obj, dict):
task_depth += 1
group_depth += 1
_task_agg, _group_agg = print_table(
task_or_group_obj, results, task_depth, group_depth
)
task_agg = {
**task_agg,
**_task_agg,
}
group_agg = {**group_agg, **_group_agg}
task_depth -= 1
group_depth -= 1
return task_agg, group_agg
results_agg, group_agg = print_table(task_dict, results)
results_dict = { results_dict = {
"results": dict(results_agg.items()), "results": dict(results_agg.items()),
**( **(
......
...@@ -6,7 +6,7 @@ from typing import List, Optional, Tuple, Union ...@@ -6,7 +6,7 @@ from typing import List, Optional, Tuple, Union
from lm_eval.api import metrics from lm_eval.api import metrics
from lm_eval.utils import eval_logger, positional_deprecated from lm_eval.utils import eval_logger, positional_deprecated
from lm_eval.tasks import ConfigurableGroup
class TaskOutput: class TaskOutput:
""" """
...@@ -151,76 +151,75 @@ def get_sample_size(task, limit: Optional[int]) -> Union[int, None]: ...@@ -151,76 +151,75 @@ def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
def prepare_print_tasks( def prepare_print_tasks(
task_hierarchy: dict, task_dict: dict,
results: dict, results: dict,
tab=0, task_depth=0,
group_tab=0, group_depth=0,
) -> Tuple[dict, dict]: ) -> Tuple[dict, dict]:
""" """
@param task_hierarchy: Dictionary representing the group hierarchy of tasks. Each key is a group name and its @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
value is a list of task names. value is a list of task names.
@param results: Dictionary containing the results of each task. Each key is a @param results: Dictionary containing the results of each task. Each key is a
group name and its value is a dictionary of task results. group name and its value is a dictionary of task results.
@param tab: The indentation level for printing the task @param task_depth: The indentation level for printing the task
hierarchy. Default is 0.
@param group_depth: The indentation level for printing the group
hierarchy. Default is 0. hierarchy. Default is 0.
@return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
aggregated results for each task, and groups_agg contains aggregated results for each group. aggregated results for each task, and groups_agg contains aggregated results for each group.
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
""" """
results_agg = collections.defaultdict(dict) task_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict) group_agg = collections.defaultdict(dict)
for task_or_group_name, task_or_group_obj in task_dict.items():
(group_name, task_list), *_ = task_hierarchy.items() tab_string = " " * task_depth + "- " if task_depth > 0 else ""
task_list = sorted(task_list) if isinstance(task_or_group_name, ConfigurableGroup):
name = task_or_group_name.group
results_agg[group_name] = results[group_name].copy() from_configurable_group = True
# results_agg[group_name]["tab"] = tab elif isinstance(task_or_group_name, str):
if "samples" in results_agg[group_name]: name = task_or_group_name
results_agg[group_name].pop("samples") from_configurable_group = False
tab_string = " " * tab + "- " if tab > 0 else "" task_agg[name] = results[name].copy()
if from_configurable_group:
if "alias" in results_agg[group_name]: if task_or_group_name.group_alias is not None:
results_agg[group_name]["alias"] = tab_string + results_agg[group_name]["alias"] alias = task_or_group_name.group_alias
else: else:
results_agg[group_name]["alias"] = tab_string + group_name alias = name
if len(task_list) > 0:
if " " not in results[group_name]:
group_tab_string = " " * group_tab + "- " if group_tab > 0 else ""
groups_agg[group_name] = results[group_name].copy()
group_tab += 1
if "samples" in groups_agg[group_name]:
groups_agg[group_name].pop("samples")
if "alias" in groups_agg[group_name]:
groups_agg[group_name]["alias"] = (
group_tab_string + groups_agg[group_name]["alias"]
)
else: else:
groups_agg[group_name]["alias"] = group_tab_string + group_name if "alias" in task_agg[name]:
alias = task_agg[name]["alias"]
for task_name in task_list:
if task_name in task_hierarchy:
_task_hierarchy = {
**{task_name: task_hierarchy[task_name]},
**task_hierarchy,
}
else: else:
_task_hierarchy = { alias = name
**{task_name: []},
**task_hierarchy,
}
_results_agg, _groups_agg = prepare_print_tasks( task_agg[name]["alias"] = tab_string + alias
_task_hierarchy, results, tab + 1, group_tab if "samples" in task_agg[name]:
) task_agg[name].pop("samples")
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
return results_agg, groups_agg if from_configurable_group and (" " not in results[name]):
group_tab_string = (
" " * group_depth + "- " if group_depth > 0 else ""
)
group_agg[name] = results[name].copy()
group_agg[name]["alias"] = group_tab_string + alias
if "samples" in group_agg[name]:
group_agg[name].pop("samples")
if isinstance(task_or_group_obj, dict):
task_depth += 1
group_depth += 1
_task_agg, _group_agg = prepare_print_tasks(
task_or_group_obj, results, task_depth, group_depth
)
task_agg = {
**task_agg,
**_task_agg,
}
group_agg = {**group_agg, **_group_agg}
task_depth -= 1
group_depth -= 1
return task_agg, group_agg
def consolidate_results( def consolidate_results(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment