Unverified Commit 72ea626e authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Fix group register (#1315)

* tuple should be considered as well

* set option to keep callable as callable
parent b4c6bdb7
...@@ -118,7 +118,7 @@ class TaskConfig(dict): ...@@ -118,7 +118,7 @@ class TaskConfig(dict):
def __setitem__(self, item, value): def __setitem__(self, item, value):
return setattr(self, item, value) return setattr(self, item, value)
def to_dict(self): def to_dict(self, keep_callable=False):
"""dumps the current config as a dictionary object, as a printable format. """dumps the current config as a dictionary object, as a printable format.
null fields will not be printed. null fields will not be printed.
Used for dumping results alongside full task configuration Used for dumping results alongside full task configuration
...@@ -134,8 +134,11 @@ class TaskConfig(dict): ...@@ -134,8 +134,11 @@ class TaskConfig(dict):
if v is None: if v is None:
cfg_dict.pop(k) cfg_dict.pop(k)
elif isinstance(v, Callable): elif isinstance(v, Callable):
# TODO: this should handle Promptsource template objects as a separate case? if keep_callable:
cfg_dict[k] = str(v) cfg_dict[k] = v
else:
# TODO: this should handle Promptsource template objects as a separate case?
cfg_dict[k] = str(v)
return cfg_dict return cfg_dict
......
...@@ -399,7 +399,7 @@ def evaluate( ...@@ -399,7 +399,7 @@ def evaluate(
if type(items[0]) == tuple: if type(items[0]) == tuple:
numitem = len(items[0]) numitem = len(items[0])
if isinstance(items[0], (str, list)): if isinstance(items[0], (str, list, tuple)):
# handle the string case # handle the string case
gathered_items = [None] * lm.accelerator.num_processes gathered_items = [None] * lm.accelerator.num_processes
torch.distributed.all_gather_object(gathered_items, items) torch.distributed.all_gather_object(gathered_items, items)
......
...@@ -72,7 +72,7 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) - ...@@ -72,7 +72,7 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
_, task_obj = task_obj _, task_obj = task_obj
if task_obj is not None: if task_obj is not None:
base_config = task_obj._config.to_dict() base_config = task_obj._config.to_dict(keep_callable=True)
task_name_config["task"] = f"{group}_{task_name}" task_name_config["task"] = f"{group}_{task_name}"
task_config = utils.load_yaml_config(yaml_path, task_config) task_config = utils.load_yaml_config(yaml_path, task_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment