Commit 2c8a1656 authored by Baber's avatar Baber
Browse files

pass metadata to groups as well

parent 01d89cdc
...@@ -56,15 +56,15 @@ class TaskManager: ...@@ -56,15 +56,15 @@ class TaskManager:
self, self,
include_path: Optional[Union[str, List]] = None, include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True, include_defaults: bool = True,
): ) -> dict[str, dict]:
"""Creates a dictionary of tasks index. """Creates a dictionary of tasks indexes.
:param include_path: Union[str, List] = None :param include_path: Union[str, List] = None
An additional path to be searched for tasks recursively. An additional path to be searched for tasks recursively.
Can provide more than one such path as a list. Can provide more than one such path as a list.
:param include_defaults: bool = True :param include_defaults: bool = True
If set to false, default tasks (those in lm_eval/tasks/) are not indexed. If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
:return return
Dictionary of task names as key and task metadata Dictionary of task names as key and task metadata
""" """
if include_defaults: if include_defaults:
...@@ -169,54 +169,54 @@ class TaskManager: ...@@ -169,54 +169,54 @@ class TaskManager:
result += subtask_table.dumps() + "\n\n" result += subtask_table.dumps() + "\n\n"
return result return result
def match_tasks(self, task_list): def match_tasks(self, task_list: list[str]) -> list[str]:
return utils.pattern_match(task_list, self.all_tasks) return utils.pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name) -> bool: def _name_is_registered(self, name: str) -> bool:
if name in self.all_tasks: if name in self.all_tasks:
return True return True
return False return False
def _name_is_task(self, name) -> bool: def _name_is_task(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"): if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"):
return True return True
return False return False
def _name_is_tag(self, name) -> bool: def _name_is_tag(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"): if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
return True return True
return False return False
def _name_is_group(self, name) -> bool: def _name_is_group(self, name: str) -> bool:
if self._name_is_registered(name) and ( if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group" self.task_index[name]["type"] == "group"
): ):
return True return True
return False return False
def _name_is_python_task(self, name): def _name_is_python_task(self, name: str) -> bool:
if self._name_is_registered(name) and ( if self._name_is_registered(name) and (
self.task_index[name]["type"] == "python_task" self.task_index[name]["type"] == "python_task"
): ):
return True return True
return False return False
def _config_is_task(self, config) -> bool: def _config_is_task(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], str): if ("task" in config) and isinstance(config["task"], str):
return True return True
return False return False
def _config_is_group(self, config) -> bool: def _config_is_group(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], list): if ("task" in config) and isinstance(config["task"], list):
return True return True
return False return False
def _config_is_python_task(self, config) -> bool: def _config_is_python_task(self, config: dict) -> bool:
if "class" in config: if "class" in config:
return True return True
return False return False
def _get_yaml_path(self, name): def _get_yaml_path(self, name: str):
if name not in self.task_index: if name not in self.task_index:
raise ValueError raise ValueError
return self.task_index[name]["yaml_path"] return self.task_index[name]["yaml_path"]
...@@ -287,7 +287,11 @@ class TaskManager: ...@@ -287,7 +287,11 @@ class TaskManager:
return {task: task_object} return {task: task_object}
def _get_group_and_subtask_from_config(config): def _get_group_and_subtask_from_config(
config: dict,
) -> tuple[ConfigurableGroup, list[str]]:
if metadata is not None:
config["metadata"] = config.get("metadata", {}) | metadata
group_name = ConfigurableGroup(config=config) group_name = ConfigurableGroup(config=config)
subtask_list = [] subtask_list = []
for task in group_name.config["task"]: for task in group_name.config["task"]:
...@@ -297,7 +301,9 @@ class TaskManager: ...@@ -297,7 +301,9 @@ class TaskManager:
subtask_list.append(task) subtask_list.append(task)
return group_name, subtask_list return group_name, subtask_list
def _process_group_config(config, update_config=None): def _process_group_config(
config: dict, update_config: dict = None
) -> tuple[dict, dict]:
if update_config is not None: if update_config is not None:
config = {**config, **update_config} config = {**config, **update_config}
_update_config = { _update_config = {
...@@ -397,6 +403,7 @@ class TaskManager: ...@@ -397,6 +403,7 @@ class TaskManager:
fn = partial( fn = partial(
self._load_individual_task_or_group, self._load_individual_task_or_group,
metadata=metadata,
parent_name=group_name, parent_name=group_name,
update_config=update_config, update_config=update_config,
) )
...@@ -561,7 +568,7 @@ def get_task_name_from_object(task_object): ...@@ -561,7 +568,7 @@ def get_task_name_from_object(task_object):
) )
def _check_duplicates(task_dict: dict) -> List[str]: def _check_duplicates(task_dict: dict) -> None:
"""helper function solely used in validating get_task_dict output. """helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
...@@ -591,7 +598,7 @@ def _check_duplicates(task_dict: dict) -> List[str]: ...@@ -591,7 +598,7 @@ def _check_duplicates(task_dict: dict) -> List[str]:
def get_task_dict( def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]], task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None, task_manager: Optional[TaskManager] = None,
metadata=None, metadata: dict = None,
): ):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object. """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment