Commit 936cd084 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

improve loading a bit

parent 3a32b8c9
......@@ -208,6 +208,7 @@ class ConfigMixin:
def extract_init_dict(cls, config_dict, **kwargs):
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self")
expected_keys.remove("kwargs")
init_dict = {}
for key in expected_keys:
if key in kwargs:
......
......@@ -147,6 +147,7 @@ class ModelMixin(torch.nn.Module):
models, `pixel_values` for vision models and `input_values` for speech models).
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
def __init__(self):
super().__init__()
......
......@@ -63,8 +63,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor=1,
center_input_sample=False,
resnet_num_groups=30,
**kwargs,
):
super().__init__()
# remove automatically added kwargs
for arg in self._automatically_saved_args:
kwargs.pop(arg, None)
if len(kwargs) > 0:
raise ValueError(
f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
)
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
self.register_to_config(
......
......@@ -59,8 +59,18 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor=1,
center_input_sample=False,
resnet_num_groups=32,
**kwargs,
):
super().__init__()
# remove automatically added kwargs
for arg in self._automatically_saved_args:
kwargs.pop(arg, None)
if len(kwargs) > 0:
raise ValueError(
f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
)
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
self.register_to_config(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment