"...resnet50_tensorflow.git" did not exist on "37c473238646961a4229e931a2ae51ecc721dced"
Commit 0a184f46 authored by Baber's avatar Baber
Browse files

refactor: enhance task loading by including yaml_path parameter

parent 15c01f4d
...@@ -287,16 +287,16 @@ class TaskManager: ...@@ -287,16 +287,16 @@ class TaskManager:
parent_name: Optional[str] = None, parent_name: Optional[str] = None,
update_config: Optional[dict] = None, update_config: Optional[dict] = None,
) -> Mapping: ) -> Mapping:
def _load_task(config, task): def _load_task(config, task, yaml_path=None):
if "include" in config: if "include" in config:
# Store the task name to preserve it after include processing # Store the task name to preserve it after include processing
original_task_name = config.get("task", task) original_task_name = config.get("task", task)
config = { config = {
**utils.load_yaml_config( **utils.load_yaml_config(
yaml_path=None, yaml_path=yaml_path,
yaml_config={"include": config.pop("include")}, yaml_config={"include": config.pop("include")},
mode="full", mode="full" if yaml_path else "simple",
), ),
**config, **config,
"task": original_task_name, "task": original_task_name,
...@@ -357,6 +357,8 @@ class TaskManager: ...@@ -357,6 +357,8 @@ class TaskManager:
elif self._name_is_task(name_or_config) or self._name_is_python_task( elif self._name_is_task(name_or_config) or self._name_is_python_task(
name_or_config name_or_config
): ):
# Get the yaml_path for this task
yaml_path = self._get_yaml_path(name_or_config)
task_config = self._get_config(name_or_config) task_config = self._get_config(name_or_config)
# Handle task_list configs # Handle task_list configs
...@@ -387,7 +389,7 @@ class TaskManager: ...@@ -387,7 +389,7 @@ class TaskManager:
) )
task_config = {"task": name_or_config} task_config = {"task": name_or_config}
return _load_task(task_config, task=name_or_config) return _load_task(task_config, task=name_or_config, yaml_path=yaml_path)
else: else:
subtask_list = self._get_tasklist(name_or_config) subtask_list = self._get_tasklist(name_or_config)
if subtask_list == -1: if subtask_list == -1:
...@@ -427,7 +429,9 @@ class TaskManager: ...@@ -427,7 +429,9 @@ class TaskManager:
elif self._name_is_tag(name): elif self._name_is_tag(name):
return self._process_tag_subtasks(name, name_or_config) return self._process_tag_subtasks(name, name_or_config)
else: else:
yaml_path = None
if self._name_is_registered(name): if self._name_is_registered(name):
yaml_path = self._get_yaml_path(name)
base_task_config = self._get_config(name) base_task_config = self._get_config(name)
# Check if this is a duplicate. # Check if this is a duplicate.
...@@ -450,7 +454,7 @@ class TaskManager: ...@@ -450,7 +454,7 @@ class TaskManager:
} }
else: else:
task_config = name_or_config task_config = name_or_config
return _load_task(task_config, task=name) return _load_task(task_config, task=name, yaml_path=yaml_path)
else: else:
group_config, update_config = _process_group_config(name_or_config) group_config, update_config = _process_group_config(name_or_config)
group_name, subtask_list = _get_group_and_subtask_from_config( group_name, subtask_list = _get_group_and_subtask_from_config(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment