Commit f36fb47f authored by lintangsutawika's avatar lintangsutawika
Browse files

update loading of task group and newly added tags

parent 1fae7283
...@@ -67,7 +67,9 @@ class TaskManager: ...@@ -67,7 +67,9 @@ class TaskManager:
return False return False
def _name_is_task(self, name) -> bool: def _name_is_task(self, name) -> bool:
if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]): if self._name_is_registered(name) and (
self.task_index[name]["type"] == "task"
):
return True return True
return False return False
...@@ -140,9 +142,10 @@ class TaskManager: ...@@ -140,9 +142,10 @@ class TaskManager:
name_or_config: Optional[Union[str, dict]] = None, name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None, parent_name: Optional[str] = None,
update_config: Optional[dict] = None, update_config: Optional[dict] = None,
yaml_path: Optional[str] = None,
) -> Mapping: ) -> Mapping:
def load_task(config, task): def load_task(config, task):
print(f"loading {task}", self._name_is_registered(task))
print(config)
if "include" in config: if "include" in config:
config = { config = {
**utils.load_yaml_config( **utils.load_yaml_config(
...@@ -185,7 +188,7 @@ class TaskManager: ...@@ -185,7 +188,7 @@ class TaskManager:
name = name_or_config["task"] name = name_or_config["task"]
# If the name is registered as a group # If the name is registered as a group
# if self._name_is_task(name) is False: # if self._name_is_task(name) is False:
if self._name_is_group(name): if self._name_is_group(name) or self._name_is_tag(name):
group_name = name group_name = name
update_config = { update_config = {
k: v for k, v in name_or_config.items() if k != "task" k: v for k, v in name_or_config.items() if k != "task"
...@@ -195,6 +198,8 @@ class TaskManager: ...@@ -195,6 +198,8 @@ class TaskManager:
group_config = self._get_config(name) group_config = self._get_config(name)
subtask_list = group_config["task"] subtask_list = group_config["task"]
group_name = ConfigurableGroup(config=group_config) group_name = ConfigurableGroup(config=group_config)
else:
group_name = name
else: else:
if self._name_is_registered(name): if self._name_is_registered(name):
base_task_config = self._get_config(name) base_task_config = self._get_config(name)
...@@ -230,15 +235,25 @@ class TaskManager: ...@@ -230,15 +235,25 @@ class TaskManager:
for k, v in name_or_config.items() for k, v in name_or_config.items()
if k not in GROUP_ONLY_KEYS if k not in GROUP_ONLY_KEYS
} }
group_name = ConfigurableGroup(config=name_or_config) else:
update_config = None
group_config = {
k: v
for k, v in name_or_config.items()
if k in GROUP_ONLY_KEYS + ["task", "group"]
}
if bool(group_config):
group_name = ConfigurableGroup(config=group_config)
fn = partial( fn = partial(
self._load_individual_task_or_group, self._load_individual_task_or_group,
parent_name=group_name, parent_name=group_name,
update_config=update_config, update_config=update_config,
yaml_path=yaml_path,
) )
return {group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))} return {
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
}
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict: def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
"""Loads a dictionary of task objects from a list """Loads a dictionary of task objects from a list
...@@ -282,6 +297,8 @@ class TaskManager: ...@@ -282,6 +297,8 @@ class TaskManager:
:return :return
Dictionary of task names as key and task metadata Dictionary of task names as key and task metadata
""" """
# TODO: remove group in next release
print_info = True
tasks_and_groups = collections.defaultdict() tasks_and_groups = collections.defaultdict()
for root, _, file_list in os.walk(task_dir): for root, _, file_list in os.walk(task_dir):
for f in file_list: for f in file_list:
...@@ -323,12 +340,23 @@ class TaskManager: ...@@ -323,12 +340,23 @@ class TaskManager:
"yaml_path": yaml_path, "yaml_path": yaml_path,
} }
if "tag" in config: # TODO: remove group in next release
tag = config["tag"] for attr in ["tag", "group"]:
if isinstance(config["tag"], str): if attr in config:
tag = [tag] if attr == "group" and print_info:
self.logger.info(
"`group` and `group_alias` will no longer be used in the next release of lm-eval. "
"`tags` will be used to allow to call a collection of tasks just like `group`. "
"`group` will be removed in order to not cause confusion with the new ConfigurableGroup "
"which will be the offical way to create groups with addition of group-wide configuations."
)
print_info = False
attr_list = config[attr]
if isinstance(attr_list, str):
attr_list = [attr_list]
for tag in tag: for tag in attr_list:
if tag not in tasks_and_groups: if tag not in tasks_and_groups:
tasks_and_groups[tag] = { tasks_and_groups[tag] = {
"type": "tag", "type": "tag",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment