Commit fe2a1472 authored by lintangsutawika's avatar lintangsutawika
Browse files

simplify get_task_list

parent 86039e85
......@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Tuple, Union
from lm_eval.api import metrics
from lm_eval.utils import eval_logger, positional_deprecated
from lm_eval.api.task import ConfigurableTask, ConfigurableGroup
class TaskOutput:
"""
......@@ -118,30 +118,18 @@ class TaskOutput:
)
def get_task_list(task_dict: dict) -> Tuple[Dict[str, list], List[TaskOutput]]:
# task_hierarchy = collections.defaultdict(list)
task_hierarchy = collections.defaultdict(lambda: collections.defaultdict(list))
def get_task_list(task_dict: dict) -> List[TaskOutput]:
outputs = []
for x, y in task_dict.items():
group, task_obj = y
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
task_output = TaskOutput.from_taskdict(x, (group, None))
task_config = task_obj
else:
task_output = TaskOutput.from_taskdict(x, y)
task_config = task_obj.config.to_dict()
outputs.append(task_output)
if group_name := task_output.group_name:
task_hierarchy[group_name]["tasks"].append(task_output.task_name)
if "group_config" in task_config:
task_hierarchy[group_name]["config"] = task_config["group_config"]
_outputs = get_task_list(task_obj)
outputs.extend(_outputs)
else:
task_hierarchy[task_output.task_name]["tasks"] = []
# returns task_hierarchy tracking which groups contain which subtasks,
# and a list of TaskOutput classes for each non-group subtask
return task_hierarchy, [x for x in outputs if x.task]
task_output = TaskOutput.from_taskdict(task_name, task_obj)
outputs.append(task_output)
return outputs
def print_writeout(task) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment