Commit fe7d1363 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

correct dict

parent e660a05f
...@@ -14,13 +14,11 @@ ...@@ -14,13 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" ConfigMixinuration base class and utilities.""" """ ConfigMixinuration base class and utilities."""
import copy
import inspect import inspect
import json import json
import os import os
import re import re
from collections import OrderedDict
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, Tuple, Union
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
...@@ -63,10 +61,14 @@ class ConfigMixin: ...@@ -63,10 +61,14 @@ class ConfigMixin:
logger.error(f"Can't set {key} with value {value} for {self}") logger.error(f"Can't set {key} with value {value} for {self}")
raise err raise err
if not hasattr(self, "_dict_to_save"): if not hasattr(self, "_internal_dict"):
self._dict_to_save = {} internal_dict = kwargs
else:
previous_dict = dict(self._internal_dict)
internal_dict = {**self._internal_dict, **kwargs}
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
self._dict_to_save.update(kwargs) self._internal_dict = FrozenDict(internal_dict)
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
""" """
...@@ -230,8 +232,7 @@ class ConfigMixin: ...@@ -230,8 +232,7 @@ class ConfigMixin:
@property @property
def config(self) -> Dict[str, Any]: def config(self) -> Dict[str, Any]:
output = copy.deepcopy(self._dict_to_save) return self._internal_dict
return output
def to_json_string(self) -> str: def to_json_string(self) -> str:
""" """
...@@ -240,7 +241,7 @@ class ConfigMixin: ...@@ -240,7 +241,7 @@ class ConfigMixin:
Returns: Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format. `str`: String containing all the attributes that make up this configuration instance in JSON format.
""" """
config_dict = self._dict_to_save config_dict = self._internal_dict
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike]): def to_json_file(self, json_file_path: Union[str, os.PathLike]):
...@@ -253,3 +254,39 @@ class ConfigMixin: ...@@ -253,3 +254,39 @@ class ConfigMixin:
""" """
with open(json_file_path, "w", encoding="utf-8") as writer: with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string()) writer.write(self.to_json_string())
class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs):
# remove `None`
args = (a for a in args if a is not None)
kwargs = {k: v for k, v in kwargs if v is not None}
super().__init__(*args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
self.__frozen = True
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __setattr__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setattr__(name, value)
def __setitem__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setitem__(name, value)
...@@ -338,7 +338,7 @@ class ModelMixin(torch.nn.Module): ...@@ -338,7 +338,7 @@ class ModelMixin(torch.nn.Module):
revision=revision, revision=revision,
**kwargs, **kwargs,
) )
model.register(name_or_path=pretrained_model_name_or_path) model.register_to_config(name_or_path=pretrained_model_name_or_path)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model # Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
......
...@@ -88,7 +88,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -88,7 +88,7 @@ class DiffusionPipeline(ConfigMixin):
def save_pretrained(self, save_directory: Union[str, os.PathLike]): def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory) self.save_config(save_directory)
model_index_dict = self.config model_index_dict = dict(self.config)
model_index_dict.pop("_class_name") model_index_dict.pop("_class_name")
model_index_dict.pop("_diffusers_version") model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module") model_index_dict.pop("_module")
......
...@@ -73,6 +73,10 @@ class ConfigTester(unittest.TestCase): ...@@ -73,6 +73,10 @@ class ConfigTester(unittest.TestCase):
new_obj = SampleObject.from_config(tmpdirname) new_obj = SampleObject.from_config(tmpdirname)
new_config = new_obj.config new_config = new_obj.config
# unfreeze configs
config = dict(config)
new_config = dict(new_config)
assert config.pop("c") == (2, 5) # instantiated as tuple assert config.pop("c") == (2, 5) # instantiated as tuple
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
assert config == new_config assert config == new_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment