Commit 6584e6d6 authored by lintangsutawika's avatar lintangsutawika
Browse files

group in a list of group can accept parameter changes like `num_fewshot`

parent db50055e
......@@ -66,7 +66,7 @@ class TaskManager(abc.ABC):
return False
def _name_is_task(self, name):
if self.ALL_TASKS[name]["type"] == "task":
if self._name_is_registered(name) and (self.ALL_TASKS[name]["type"] == "task"):
return True
return False
......@@ -88,7 +88,12 @@ class TaskManager(abc.ABC):
assert self._name_is_task(name) == False
return self.ALL_TASKS[name]["task"]
def _load_individual_task_or_group(self, name_or_config: Union[str, dict] = None, parent_name: str = None) -> ConfigurableTask:
def _load_individual_task_or_group(
self,
name_or_config: Union[str, dict] = None,
parent_name: str = None,
update_config: dict = None
) -> ConfigurableTask:
def load_task(config, task, group=None):
task_object = ConfigurableTask(config=config)
......@@ -97,7 +102,9 @@ class TaskManager(abc.ABC):
return {task: task_object}
if isinstance(name_or_config, str):
if self._name_is_task(name_or_config):
if update_config is not None:
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config)
return load_task(task_config, task=name_or_config, group=parent_name)
else:
......@@ -106,18 +113,32 @@ class TaskManager(abc.ABC):
if subtask_list == -1:
subtask_list = self._get_config(name_or_config)["task"]
elif isinstance(name_or_config, dict):
if isinstance(name_or_config, dict):
if update_config is not None:
name_or_config={
**name_or_config,
**update_config,
}
if self._config_is_task(name_or_config):
task_name = name_or_config["task"]
if self._name_is_registered(task_name):
base_task_config = self._get_config(task_name)
task_config={
**base_task_config,
**name_or_config,
}
name = name_or_config["task"]
if self._name_is_task(name) is False:
group_name = name
update_config = {k:v for k,v in name_or_config.items() if k is not "task"}
subtask_list = self._get_tasklist(name)
if subtask_list == -1:
subtask_list = self._get_config(name)["task"]
else:
task_config = name_or_config
return load_task(task_config, task=task_name, group=parent_name)
if self._name_is_registered(name):
base_task_config = self._get_config(name)
task_config={
**base_task_config,
**name_or_config,
}
else:
task_config = name_or_config
return load_task(task_config, task=name, group=parent_name)
else:
group_name = name_or_config["group"]
subtask_list = name_or_config["task"]
......@@ -127,7 +148,7 @@ class TaskManager(abc.ABC):
else:
all_subtasks = {}
fn = partial(self._load_individual_task_or_group, parent_name=group_name)
fn = partial(self._load_individual_task_or_group, parent_name=group_name, update_config=update_config)
all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))}
return all_subtasks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment