Commit 2c8a1656 authored by Baber's avatar Baber
Browse files

pass metadata to groups as well

parent 01d89cdc
......@@ -56,15 +56,15 @@ class TaskManager:
self,
include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True,
):
"""Creates a dictionary of tasks index.
) -> dict[str, dict]:
"""Creates a dictionary of tasks indexes.
:param include_path: Union[str, List] = None
An additional path to be searched for tasks recursively.
Can provide more than one such path as a list.
:param include_defaults: bool = True
If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
:return
return
Dictionary of task names as key and task metadata
"""
if include_defaults:
......@@ -169,54 +169,54 @@ class TaskManager:
result += subtask_table.dumps() + "\n\n"
return result
def match_tasks(self, task_list):
def match_tasks(self, task_list: list[str]) -> list[str]:
return utils.pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name) -> bool:
def _name_is_registered(self, name: str) -> bool:
if name in self.all_tasks:
return True
return False
def _name_is_task(self, name) -> bool:
def _name_is_task(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"):
return True
return False
def _name_is_tag(self, name) -> bool:
def _name_is_tag(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
return True
return False
def _name_is_group(self, name) -> bool:
def _name_is_group(self, name: str) -> bool:
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group"
):
return True
return False
def _name_is_python_task(self, name):
def _name_is_python_task(self, name: str) -> bool:
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "python_task"
):
return True
return False
def _config_is_task(self, config) -> bool:
def _config_is_task(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], str):
return True
return False
def _config_is_group(self, config) -> bool:
def _config_is_group(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], list):
return True
return False
def _config_is_python_task(self, config) -> bool:
def _config_is_python_task(self, config: dict) -> bool:
if "class" in config:
return True
return False
def _get_yaml_path(self, name):
def _get_yaml_path(self, name: str):
if name not in self.task_index:
raise ValueError
return self.task_index[name]["yaml_path"]
......@@ -287,7 +287,11 @@ class TaskManager:
return {task: task_object}
def _get_group_and_subtask_from_config(config):
def _get_group_and_subtask_from_config(
config: dict,
) -> tuple[ConfigurableGroup, list[str]]:
if metadata is not None:
config["metadata"] = config.get("metadata", {}) | metadata
group_name = ConfigurableGroup(config=config)
subtask_list = []
for task in group_name.config["task"]:
......@@ -297,7 +301,9 @@ class TaskManager:
subtask_list.append(task)
return group_name, subtask_list
def _process_group_config(config, update_config=None):
def _process_group_config(
config: dict, update_config: dict = None
) -> tuple[dict, dict]:
if update_config is not None:
config = {**config, **update_config}
_update_config = {
......@@ -397,6 +403,7 @@ class TaskManager:
fn = partial(
self._load_individual_task_or_group,
metadata=metadata,
parent_name=group_name,
update_config=update_config,
)
......@@ -561,7 +568,7 @@ def get_task_name_from_object(task_object):
)
def _check_duplicates(task_dict: dict) -> List[str]:
def _check_duplicates(task_dict: dict) -> None:
"""helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
......@@ -591,7 +598,7 @@ def _check_duplicates(task_dict: dict) -> List[str]:
def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None,
metadata=None,
metadata: dict = None,
):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment