"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "960c149c777ea1584cd5584eac832ec9810b2632"
Unverified Commit ad9d2525 authored by Sylvain Gugger, committed by GitHub

Add a decorator for register_to_config (#108)

* Add a decorator for register_to_config

* All models and tests
parent 7e11392d
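In short, this commit replaces the hand-written `self.register_to_config(...)` boilerplate in every `__init__` with a `@register_to_config` decorator. As a rough before/after sketch of the calling pattern (the `MyModel` class and its arguments below are made up for illustration, not taken from the diff):

# Before (hypothetical): every argument is forwarded to the config by hand.
class MyModel(ModelMixin, ConfigMixin):
    def __init__(self, image_size=64, in_channels=3):
        super().__init__()
        self.register_to_config(image_size=image_size, in_channels=in_channels)

# After: the decorator reads the __init__ signature and registers the arguments itself.
class MyModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, image_size=64, in_channels=3):
        super().__init__()

Either way, the registered values remain reachable as before, e.g. via `model.config.image_size`.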
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""ConfigMixin base class and configuration utilities."""
import functools
import inspect
import json
import os
@@ -295,3 +296,46 @@ class FrozenDict(OrderedDict):
        if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)

def register_to_config(init):
    """
    Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are
    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable.

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        init(self, *args, **init_kwargs)
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )
        getattr(self, "register_to_config")(**new_kwargs)

    return inner_init
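To make the mechanics above concrete, here is a small usage sketch (the `Example` class is hypothetical, only for illustration): positional arguments are matched to parameter names through the signature, omitted arguments fall back to their defaults, and both names listed in `ignore_for_config` and private, underscore-prefixed keyword arguments stay out of the registered config.

class Example(ConfigMixin):
    config_name = "example_config.json"
    ignore_for_config = ["device"]

    @register_to_config
    def __init__(self, width=64, height=64, device="cpu"):
        pass

obj = Example(128, device="cuda", _extra="silently dropped before __init__")
assert obj.config["width"] == 128   # positional arg mapped to its parameter name
assert obj.config["height"] == 64   # unspecified arg registered with its default
assert "device" not in obj.config   # excluded through ignore_for_config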
@@ -3,7 +3,7 @@ from typing import Dict, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
@@ -33,6 +33,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
    increased efficiency.
    """
    @register_to_config
    def __init__(
        self,
        image_size=None,
@@ -63,40 +64,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
        mid_block_scale_factor=1,
        center_input_sample=False,
        resnet_num_groups=30,
        **kwargs,
    ):
        super().__init__()
        # remove automatically added kwargs
        for arg in self._automatically_saved_args:
            kwargs.pop(arg, None)
        if len(kwargs) > 0:
            raise ValueError(
                f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
            )
        # register all __init__ params to be accessible via `self.config.<...>`
        # should probably be automated down the road as this is pure boiler plate code
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            block_channels=block_channels,
            downsample_padding=downsample_padding,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            down_blocks=down_blocks,
            up_blocks=up_blocks,
            dropout=dropout,
            resnet_eps=resnet_eps,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
            mid_block_scale_factor=mid_block_scale_factor,
            resnet_num_groups=resnet_num_groups,
            center_input_sample=center_input_sample,
        )
        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4
...
@@ -3,7 +3,7 @@ from typing import Dict, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -33,6 +33,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
    increased efficiency.
    """
    @register_to_config
    def __init__(
        self,
        image_size=None,
@@ -59,41 +60,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
        mid_block_scale_factor=1,
        center_input_sample=False,
        resnet_num_groups=32,
        **kwargs,
    ):
        super().__init__()
        # remove automatically added kwargs
        for arg in self._automatically_saved_args:
            kwargs.pop(arg, None)
        if len(kwargs) > 0:
            raise ValueError(
                f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
            )
        # register all __init__ params to be accessible via `self.config.<...>`
        # should probably be automated down the road as this is pure boiler plate code
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            block_channels=block_channels,
            downsample_padding=downsample_padding,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            down_blocks=down_blocks,
            up_blocks=up_blocks,
            dropout=dropout,
            resnet_eps=resnet_eps,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
            time_embedding_type=time_embedding_type,
            mid_block_scale_factor=mid_block_scale_factor,
            resnet_num_groups=resnet_num_groups,
            center_input_sample=center_input_sample,
        )
        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4
...
@@ -2,7 +2,7 @@ import numpy as np
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
@@ -380,6 +380,7 @@ class DiagonalGaussianDistribution(object):
class VQModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        ch,
@@ -399,27 +400,6 @@ class VQModel(ModelMixin, ConfigMixin):
        resamp_with_conv=True,
        give_pre_end=False,
    ):
        super().__init__()
        # register all __init__ params with self.register
        self.register_to_config(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            n_embed=n_embed,
            embed_dim=embed_dim,
            remap=remap,
            sane_index_shape=sane_index_shape,
            ch_mult=ch_mult,
            dropout=dropout,
            double_z=double_z,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )
        # pass init params to Encoder
        self.encoder = Encoder(
@@ -478,6 +458,7 @@ class VQModel(ModelMixin, ConfigMixin):
class AutoencoderKL(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        ch,
@@ -496,26 +477,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
        resamp_with_conv=True,
        give_pre_end=False,
    ):
        super().__init__()
        # register all __init__ params with self.register
        self.register_to_config(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            embed_dim=embed_dim,
            remap=remap,
            sane_index_shape=sane_index_shape,
            ch_mult=ch_mult,
            dropout=dropout,
            double_z=double_z,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )
        # pass init params to Encoder
        self.encoder = Encoder(
...
@@ -21,7 +21,7 @@ from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
@@ -49,6 +49,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class DDIMScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
@@ -60,16 +61,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        clip_sample=True,
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            clip_sample=clip_sample,
        )
        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
...
@@ -20,7 +20,7 @@ from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class DDPMScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
@@ -60,17 +61,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        clip_sample=True,
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            variance_type=variance_type,
            clip_sample=clip_sample,
        )
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
...
@@ -20,7 +20,7 @@ from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class PNDMScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
@@ -56,13 +57,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        beta_schedule="linear",
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
        )
        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
...
@@ -21,7 +21,7 @@ from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
@@ -37,6 +37,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        "np" or "pt" for the expected format of samples passed to the Scheduler.
    """
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=2000,
@@ -47,15 +48,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        correct_steps=1,
        tensor_format="pt",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            snr=snr,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            sampling_eps=sampling_eps,
            correct_steps=correct_steps,
        )
        # self.sigmas = None
        # self.discrete_sigmas = None
        #
...
@@ -19,19 +19,13 @@
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_min=beta_min,
            beta_max=beta_max,
            sampling_eps=sampling_eps,
        )
        self.sigmas = None
        self.discrete_sigmas = None
...
@@ -23,6 +23,7 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json"
class SchedulerMixin:
    config_name = SCHEDULER_CONFIG_NAME
    ignore_for_config = ["tensor_format"]
    def set_format(self, tensor_format="pt"):
        self.tensor_format = tensor_format
...
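The `ignore_for_config` entry added above is what keeps `tensor_format` out of the scheduler configs now that schedulers rely on the decorator instead of listing arguments by hand. A rough behavior sketch (assuming `DDPMScheduler` is imported from the package, as in the tests):

scheduler = DDPMScheduler(num_train_timesteps=500, tensor_format="pt")
assert scheduler.config["num_train_timesteps"] == 500
assert "tensor_format" not in scheduler.config  # filtered out via ignore_for_config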
@@ -18,6 +18,7 @@ import inspect
import math
import tempfile
import unittest
from atexit import register
import numpy as np
import torch
@@ -38,7 +39,7 @@ from diffusers import (
    UNetUnconditionalModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel
@@ -47,25 +48,63 @@ from diffusers.training_utils import EMAModel
torch.backends.cuda.matmul.allow_tf32 = False
class SampleObject(ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        e=[1, 3],
    ):
        pass
class ConfigTester(unittest.TestCase):
    def test_load_not_from_mixin(self):
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")
    def test_save_load(self):
        class SampleObject(ConfigMixin):
            config_name = "config.json"

            def __init__(
                self,
                a=2,
                b=5,
                c=(2, 5),
                d="for diffusion",
                e=[1, 3],
            ):
                self.register_to_config(a=a, b=b, c=c, d=d, e=e)

    def test_register_to_config(self):
        obj = SampleObject()
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # init ignore private arguments
        obj = SampleObject(_name_or_path="lalala")
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can override default
        obj = SampleObject(c=6)
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can use positional arguments.
        obj = SampleObject(1, c=6)
        config = obj.config
        assert config["a"] == 1
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

    def test_save_load(self):
        obj = SampleObject()
        config = obj.config
...