"src/vscode:/vscode.git/clone" did not exist on "ed759f0aee721f8520c5bf94d4b7bd7c0ae3dcbb"
Commit 936cd084 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

improve loading a bit

parent 3a32b8c9
...@@ -208,6 +208,7 @@ class ConfigMixin: ...@@ -208,6 +208,7 @@ class ConfigMixin:
def extract_init_dict(cls, config_dict, **kwargs): def extract_init_dict(cls, config_dict, **kwargs):
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self") expected_keys.remove("self")
expected_keys.remove("kwargs")
init_dict = {} init_dict = {}
for key in expected_keys: for key in expected_keys:
if key in kwargs: if key in kwargs:
......
...@@ -147,6 +147,7 @@ class ModelMixin(torch.nn.Module): ...@@ -147,6 +147,7 @@ class ModelMixin(torch.nn.Module):
models, `pixel_values` for vision models and `input_values` for speech models). models, `pixel_values` for vision models and `input_values` for speech models).
""" """
config_name = CONFIG_NAME config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
......
...@@ -63,8 +63,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin): ...@@ -63,8 +63,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor=1, mid_block_scale_factor=1,
center_input_sample=False, center_input_sample=False,
resnet_num_groups=30, resnet_num_groups=30,
**kwargs,
): ):
super().__init__() super().__init__()
# remove automatically added kwargs
for arg in self._automatically_saved_args:
kwargs.pop(arg, None)
if len(kwargs) > 0:
raise ValueError(
f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
)
# register all __init__ params to be accessible via `self.config.<...>` # register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code # should probably be automated down the road as this is pure boiler plate code
self.register_to_config( self.register_to_config(
......
...@@ -59,8 +59,18 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): ...@@ -59,8 +59,18 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor=1, mid_block_scale_factor=1,
center_input_sample=False, center_input_sample=False,
resnet_num_groups=32, resnet_num_groups=32,
**kwargs,
): ):
super().__init__() super().__init__()
# remove automatically added kwargs
for arg in self._automatically_saved_args:
kwargs.pop(arg, None)
if len(kwargs) > 0:
raise ValueError(
f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
)
# register all __init__ params to be accessible via `self.config.<...>` # register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code # should probably be automated down the road as this is pure boiler plate code
self.register_to_config( self.register_to_config(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment