Commit acc634fa authored by Baber
Browse files

refactor: simplify task and config validation methods

parent fcddf195
...@@ -175,48 +175,38 @@ class TaskManager: ...@@ -175,48 +175,38 @@ class TaskManager:
return utils.pattern_match(task_list, self.all_tasks) return utils.pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name: str) -> bool: def _name_is_registered(self, name: str) -> bool:
if name in self.all_tasks: return name in self.all_tasks
return True
return False
def _name_is_task(self, name: str) -> bool: def _name_is_task(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"): return (
return True self._name_is_registered(name) and self.task_index[name]["type"] == "task"
return False )
def _name_is_tag(self, name: str) -> bool: def _name_is_tag(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"): return self._name_is_registered(name) and self.task_index[name]["type"] == "tag"
return True
return False
def _name_is_group(self, name: str) -> bool: def _name_is_group(self, name: str) -> bool:
if self._name_is_registered(name) and ( return (
self.task_index[name]["type"] == "group" self._name_is_registered(name) and self.task_index[name]["type"] == "group"
): )
return True
return False
def _name_is_python_task(self, name: str) -> bool: def _name_is_python_task(self, name: str) -> bool:
if self._name_is_registered(name) and ( return (
self.task_index[name]["type"] == "python_task" self._name_is_registered(name)
): and self.task_index[name]["type"] == "python_task"
return True )
return False
def _config_is_task(self, config: dict) -> bool: def _config_is_task(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], str): return "task" in config and isinstance(config["task"], str)
return True
return False
def _config_is_group(self, config: dict) -> bool: def _config_is_group(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], list): return "task" in config and isinstance(config["task"], list)
return True
return False
def _config_is_python_task(self, config: dict) -> bool: def _config_is_python_task(self, config: dict) -> bool:
if "class" in config: return "class" in config
return True
return False def _config_is_task_list(self, config: dict) -> bool:
return "task_list" in config and isinstance(config["task_list"], list)
def _get_yaml_path(self, name: str): def _get_yaml_path(self, name: str):
if name not in self.task_index: if name not in self.task_index:
...@@ -237,6 +227,43 @@ class TaskManager: ...@@ -237,6 +227,43 @@ class TaskManager:
raise ValueError raise ValueError
return self.task_index[name]["task"] return self.task_index[name]["task"]
def _register_task(
self,
task_name: str,
task_type: str,
yaml_path: str,
tasks_and_groups: dict,
config: dict = None,
populate_tags_fn=None,
):
"""Helper method to register a task in the tasks_and_groups dict"""
tasks_and_groups[task_name] = {
"type": task_type,
"yaml_path": yaml_path,
}
# Only populate tags for configs that support it (not groups)
if config and task_type != "group" and populate_tags_fn:
populate_tags_fn(config, task_name, tasks_and_groups, True)
def _merge_task_configs(
self, base_config: dict, task_specific_config: dict, task_name: str
) -> dict:
"""Merge base config with task-specific overrides for task_list configs"""
if task_specific_config:
task_specific_config = task_specific_config.copy()
task_specific_config.pop("task", None)
return {**base_config, **task_specific_config, "task": task_name}
return {**base_config, "task": task_name}
def _process_tag_subtasks(self, tag_name: str, update_config: dict = None):
"""Process subtasks for a tag and return loaded tasks"""
subtask_list = self._get_tasklist(tag_name)
fn = partial(
self._load_individual_task_or_group,
update_config=update_config,
)
return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
def _process_alias(self, config, group=None): def _process_alias(self, config, group=None):
# If the group is not the same as the original # If the group is not the same as the original
# group which the group alias was intended for, # group which the group alias was intended for,
...@@ -323,6 +350,35 @@ class TaskManager: ...@@ -323,6 +350,35 @@ class TaskManager:
name_or_config name_or_config
): ):
task_config = self._get_config(name_or_config) task_config = self._get_config(name_or_config)
# Handle task_list configs
if "task_list" in task_config:
# Find the specific task entry
task_specific_config = None
for task_entry in task_config["task_list"]:
if (
isinstance(task_entry, dict)
and task_entry.get("task") == name_or_config
):
task_specific_config = task_entry
break
if task_specific_config:
# Create base config without task_list
base_config = {
k: v for k, v in task_config.items() if k != "task_list"
}
# Merge using helper method
task_config = self._merge_task_configs(
base_config, task_specific_config, name_or_config
)
else:
# Task not found in task_list, shouldn't happen if indexing worked correctly
eval_logger.warning(
f"Task {name_or_config} not found in task_list"
)
task_config = {"task": name_or_config}
return _load_task(task_config, task=name_or_config) return _load_task(task_config, task=name_or_config)
else: else:
subtask_list = self._get_tasklist(name_or_config) subtask_list = self._get_tasklist(name_or_config)
...@@ -334,15 +390,12 @@ class TaskManager: ...@@ -334,15 +390,12 @@ class TaskManager:
) )
else: else:
if self._name_is_tag(name_or_config): if self._name_is_tag(name_or_config):
fn = partial( return self._process_tag_subtasks(
self._load_individual_task_or_group, name_or_config,
update_config=name_or_config name_or_config
if isinstance(name_or_config, dict) if isinstance(name_or_config, dict)
else None, else None,
) )
return dict(
collections.ChainMap(*map(fn, reversed(subtask_list)))
)
else: else:
group_name = ConfigurableGroup( group_name = ConfigurableGroup(
config={"group": name_or_config, "task": subtask_list} config={"group": name_or_config, "task": subtask_list}
...@@ -364,12 +417,7 @@ class TaskManager: ...@@ -364,12 +417,7 @@ class TaskManager:
group_config group_config
) )
elif self._name_is_tag(name): elif self._name_is_tag(name):
subtask_list = self._get_tasklist(name) return self._process_tag_subtasks(name, name_or_config)
fn = partial(
self._load_individual_task_or_group,
update_config=name_or_config,
)
return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
else: else:
if self._name_is_registered(name): if self._name_is_registered(name):
base_task_config = self._get_config(name) base_task_config = self._get_config(name)
...@@ -458,7 +506,7 @@ class TaskManager: ...@@ -458,7 +506,7 @@ class TaskManager:
Dictionary of task names as key and task metadata Dictionary of task names as key and task metadata
""" """
def _populate_tags_and_groups(config, task, tasks_and_groups, print_info): def _populate_tags_and_groups(config, task, tasks_and_groups):
# TODO: remove group in next release # TODO: remove group in next release
if "tag" in config: if "tag" in config:
attr_list = config["tag"] attr_list = config["tag"]
...@@ -482,7 +530,6 @@ class TaskManager: ...@@ -482,7 +530,6 @@ class TaskManager:
tasks_and_groups[tag]["task"].append(task) tasks_and_groups[tag]["task"].append(task)
# TODO: remove group in next release # TODO: remove group in next release
print_info = True
ignore_dirs = [ ignore_dirs = [
"__pycache__", "__pycache__",
".ipynb_checkpoints", ".ipynb_checkpoints",
...@@ -497,12 +544,13 @@ class TaskManager: ...@@ -497,12 +544,13 @@ class TaskManager:
if self._config_is_python_task(config): if self._config_is_python_task(config):
# This is a python class config # This is a python class config
task = config["task"] task = config["task"]
tasks_and_groups[task] = { self._register_task(
"type": "python_task", task,
"yaml_path": yaml_path, "python_task",
} yaml_path,
_populate_tags_and_groups( tasks_and_groups,
config, task, tasks_and_groups, print_info config,
_populate_tags_and_groups,
) )
elif self._config_is_group(config): elif self._config_is_group(config):
# This is a group config # This is a group config
...@@ -528,12 +576,26 @@ class TaskManager: ...@@ -528,12 +576,26 @@ class TaskManager:
elif self._config_is_task(config): elif self._config_is_task(config):
# This is a task config # This is a task config
task = config["task"] task = config["task"]
tasks_and_groups[task] = { self._register_task(
"type": "task", task,
"yaml_path": yaml_path, "task",
} yaml_path,
_populate_tags_and_groups( tasks_and_groups,
config, task, tasks_and_groups, print_info config,
_populate_tags_and_groups,
)
elif self._config_is_task_list(config):
# This is a task_list config
for task_entry in config["task_list"]:
if isinstance(task_entry, dict) and "task" in task_entry:
task_name = task_entry["task"]
self._register_task(
task_name,
"task",
yaml_path,
tasks_and_groups,
config,
_populate_tags_and_groups,
) )
else: else:
eval_logger.debug(f"File {f} in {root} could not be loaded") eval_logger.debug(f"File {f} in {root} could not be loaded")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment