Commit acc634fa authored by Baber's avatar Baber
Browse files

refactor: simplify task and config validation methods

parent fcddf195
......@@ -175,48 +175,38 @@ class TaskManager:
return utils.pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name: str) -> bool:
if name in self.all_tasks:
return True
return False
return name in self.all_tasks
def _name_is_task(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"):
return True
return False
return (
self._name_is_registered(name) and self.task_index[name]["type"] == "task"
)
def _name_is_tag(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
return True
return False
return self._name_is_registered(name) and self.task_index[name]["type"] == "tag"
def _name_is_group(self, name: str) -> bool:
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group"
):
return True
return False
return (
self._name_is_registered(name) and self.task_index[name]["type"] == "group"
)
def _name_is_python_task(self, name: str) -> bool:
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "python_task"
):
return True
return False
return (
self._name_is_registered(name)
and self.task_index[name]["type"] == "python_task"
)
def _config_is_task(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], str):
return True
return False
return "task" in config and isinstance(config["task"], str)
def _config_is_group(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], list):
return True
return False
return "task" in config and isinstance(config["task"], list)
def _config_is_python_task(self, config: dict) -> bool:
if "class" in config:
return True
return False
return "class" in config
def _config_is_task_list(self, config: dict) -> bool:
return "task_list" in config and isinstance(config["task_list"], list)
def _get_yaml_path(self, name: str):
if name not in self.task_index:
......@@ -237,6 +227,43 @@ class TaskManager:
raise ValueError
return self.task_index[name]["task"]
def _register_task(
self,
task_name: str,
task_type: str,
yaml_path: str,
tasks_and_groups: dict,
config: dict = None,
populate_tags_fn=None,
):
"""Helper method to register a task in the tasks_and_groups dict"""
tasks_and_groups[task_name] = {
"type": task_type,
"yaml_path": yaml_path,
}
# Only populate tags for configs that support it (not groups)
if config and task_type != "group" and populate_tags_fn:
populate_tags_fn(config, task_name, tasks_and_groups, True)
def _merge_task_configs(
self, base_config: dict, task_specific_config: dict, task_name: str
) -> dict:
"""Merge base config with task-specific overrides for task_list configs"""
if task_specific_config:
task_specific_config = task_specific_config.copy()
task_specific_config.pop("task", None)
return {**base_config, **task_specific_config, "task": task_name}
return {**base_config, "task": task_name}
def _process_tag_subtasks(self, tag_name: str, update_config: dict = None):
"""Process subtasks for a tag and return loaded tasks"""
subtask_list = self._get_tasklist(tag_name)
fn = partial(
self._load_individual_task_or_group,
update_config=update_config,
)
return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
def _process_alias(self, config, group=None):
# If the group is not the same as the original
# group which the group alias was intended for,
......@@ -323,6 +350,35 @@ class TaskManager:
name_or_config
):
task_config = self._get_config(name_or_config)
# Handle task_list configs
if "task_list" in task_config:
# Find the specific task entry
task_specific_config = None
for task_entry in task_config["task_list"]:
if (
isinstance(task_entry, dict)
and task_entry.get("task") == name_or_config
):
task_specific_config = task_entry
break
if task_specific_config:
# Create base config without task_list
base_config = {
k: v for k, v in task_config.items() if k != "task_list"
}
# Merge using helper method
task_config = self._merge_task_configs(
base_config, task_specific_config, name_or_config
)
else:
# Task not found in task_list, shouldn't happen if indexing worked correctly
eval_logger.warning(
f"Task {name_or_config} not found in task_list"
)
task_config = {"task": name_or_config}
return _load_task(task_config, task=name_or_config)
else:
subtask_list = self._get_tasklist(name_or_config)
......@@ -334,15 +390,12 @@ class TaskManager:
)
else:
if self._name_is_tag(name_or_config):
fn = partial(
self._load_individual_task_or_group,
update_config=name_or_config
return self._process_tag_subtasks(
name_or_config,
name_or_config
if isinstance(name_or_config, dict)
else None,
)
return dict(
collections.ChainMap(*map(fn, reversed(subtask_list)))
)
else:
group_name = ConfigurableGroup(
config={"group": name_or_config, "task": subtask_list}
......@@ -364,12 +417,7 @@ class TaskManager:
group_config
)
elif self._name_is_tag(name):
subtask_list = self._get_tasklist(name)
fn = partial(
self._load_individual_task_or_group,
update_config=name_or_config,
)
return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
return self._process_tag_subtasks(name, name_or_config)
else:
if self._name_is_registered(name):
base_task_config = self._get_config(name)
......@@ -458,7 +506,7 @@ class TaskManager:
Dictionary of task names as key and task metadata
"""
def _populate_tags_and_groups(config, task, tasks_and_groups, print_info):
def _populate_tags_and_groups(config, task, tasks_and_groups):
# TODO: remove group in next release
if "tag" in config:
attr_list = config["tag"]
......@@ -482,7 +530,6 @@ class TaskManager:
tasks_and_groups[tag]["task"].append(task)
# TODO: remove group in next release
print_info = True
ignore_dirs = [
"__pycache__",
".ipynb_checkpoints",
......@@ -497,12 +544,13 @@ class TaskManager:
if self._config_is_python_task(config):
# This is a python class config
task = config["task"]
tasks_and_groups[task] = {
"type": "python_task",
"yaml_path": yaml_path,
}
_populate_tags_and_groups(
config, task, tasks_and_groups, print_info
self._register_task(
task,
"python_task",
yaml_path,
tasks_and_groups,
config,
_populate_tags_and_groups,
)
elif self._config_is_group(config):
# This is a group config
......@@ -528,13 +576,27 @@ class TaskManager:
elif self._config_is_task(config):
# This is a task config
task = config["task"]
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
_populate_tags_and_groups(
config, task, tasks_and_groups, print_info
self._register_task(
task,
"task",
yaml_path,
tasks_and_groups,
config,
_populate_tags_and_groups,
)
elif self._config_is_task_list(config):
# This is a task_list config
for task_entry in config["task_list"]:
if isinstance(task_entry, dict) and "task" in task_entry:
task_name = task_entry["task"]
self._register_task(
task_name,
"task",
yaml_path,
tasks_and_groups,
config,
_populate_tags_and_groups,
)
else:
eval_logger.debug(f"File {f} in {root} could not be loaded")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment