"examples/t2i_adapter/requirements.txt" did not exist on "f1b726e46e7a66771e6321d00a28048effce1389"
Commit f36fb47f authored by lintangsutawika's avatar lintangsutawika
Browse files

update loading of task group and newly added tags

parent 1fae7283
......@@ -67,7 +67,9 @@ class TaskManager:
return False
def _name_is_task(self, name) -> bool:
if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]):
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "task"
):
return True
return False
......@@ -140,9 +142,10 @@ class TaskManager:
name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None,
update_config: Optional[dict] = None,
yaml_path: Optional[str] = None,
) -> Mapping:
def load_task(config, task):
print(f"loading {task}", self._name_is_registered(task))
print(config)
if "include" in config:
config = {
**utils.load_yaml_config(
......@@ -185,7 +188,7 @@ class TaskManager:
name = name_or_config["task"]
# If the name is registered as a group
# if self._name_is_task(name) is False:
if self._name_is_group(name):
if self._name_is_group(name) or self._name_is_tag(name):
group_name = name
update_config = {
k: v for k, v in name_or_config.items() if k != "task"
......@@ -194,7 +197,9 @@ class TaskManager:
if subtask_list == -1:
group_config = self._get_config(name)
subtask_list = group_config["task"]
group_name = ConfigurableGroup(config=group_config)
group_name = ConfigurableGroup(config=group_config)
else:
group_name = name
else:
if self._name_is_registered(name):
base_task_config = self._get_config(name)
......@@ -230,15 +235,25 @@ class TaskManager:
for k, v in name_or_config.items()
if k not in GROUP_ONLY_KEYS
}
group_name = ConfigurableGroup(config=name_or_config)
else:
update_config = None
group_config = {
k: v
for k, v in name_or_config.items()
if k in GROUP_ONLY_KEYS + ["task", "group"]
}
if bool(group_config):
group_name = ConfigurableGroup(config=group_config)
fn = partial(
self._load_individual_task_or_group,
parent_name=group_name,
update_config=update_config,
yaml_path=yaml_path,
)
return {group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))}
return {
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
}
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
"""Loads a dictionary of task objects from a list
......@@ -282,6 +297,8 @@ class TaskManager:
:return
Dictionary of task names as key and task metadata
"""
# TODO: remove group in next release
print_info = True
tasks_and_groups = collections.defaultdict()
for root, _, file_list in os.walk(task_dir):
for f in file_list:
......@@ -323,20 +340,31 @@ class TaskManager:
"yaml_path": yaml_path,
}
if "tag" in config:
tag = config["tag"]
if isinstance(config["tag"], str):
tag = [tag]
for tag in tag:
if tag not in tasks_and_groups:
tasks_and_groups[tag] = {
"type": "tag",
"task": [task],
"yaml_path": -1,
}
else:
tasks_and_groups[tag]["task"].append(task)
# TODO: remove group in next release
for attr in ["tag", "group"]:
if attr in config:
if attr == "group" and print_info:
self.logger.info(
"`group` and `group_alias` will no longer be used in the next release of lm-eval. "
"`tags` will be used to allow to call a collection of tasks just like `group`. "
"`group` will be removed in order to not cause confusion with the new ConfigurableGroup "
"which will be the offical way to create groups with addition of group-wide configuations."
)
print_info = False
attr_list = config[attr]
if isinstance(attr_list, str):
attr_list = [attr_list]
for tag in attr_list:
if tag not in tasks_and_groups:
tasks_and_groups[tag] = {
"type": "tag",
"task": [task],
"yaml_path": -1,
}
else:
tasks_and_groups[tag]["task"].append(task)
else:
self.logger.debug(f"File {f} in {root} could not be loaded")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment