Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
ad9d2525
Unverified
Commit
ad9d2525
authored
Jul 20, 2022
by
Sylvain Gugger
Committed by
GitHub
Jul 20, 2022
Browse files
Add a decorator for register_to_config (#108)
* Add a decorator for register_to_config * All models and test
parent
7e11392d
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
115 additions
and
174 deletions
+115
-174
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+44
-0
src/diffusers/models/unet_conditional.py
src/diffusers/models/unet_conditional.py
+2
-34
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+2
-35
src/diffusers/models/vae.py
src/diffusers/models/vae.py
+3
-42
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+2
-11
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+2
-12
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+2
-8
src/diffusers/schedulers/scheduling_sde_ve.py
src/diffusers/schedulers/scheduling_sde_ve.py
+2
-10
src/diffusers/schedulers/scheduling_sde_vp.py
src/diffusers/schedulers/scheduling_sde_vp.py
+2
-8
src/diffusers/schedulers/scheduling_utils.py
src/diffusers/schedulers/scheduling_utils.py
+1
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+53
-14
No files found.
src/diffusers/configuration_utils.py
View file @
ad9d2525
...
...
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
import
functools
import
inspect
import
json
import
os
...
...
@@ -295,3 +296,46 @@ class FrozenDict(OrderedDict):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
super
().
__setitem__
(
name
,
value
)
def
register_to_config
(
init
):
"""
Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically
sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be
registered in the config, use the `ignore_for_config` class variable
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
"""
@
functools
.
wraps
(
init
)
def
inner_init
(
self
,
*
args
,
**
kwargs
):
# Ignore private kwargs in the init.
init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
not
k
.
startswith
(
"_"
)}
init
(
self
,
*
args
,
**
init_kwargs
)
if
not
isinstance
(
self
,
ConfigMixin
):
raise
RuntimeError
(
f
"`@register_for_config` was applied to
{
self
.
__class__
.
__name__
}
init method, but this class does "
"not inherit from `ConfigMixin`."
)
ignore
=
getattr
(
self
,
"ignore_for_config"
,
[])
# Get positional arguments aligned with kwargs
new_kwargs
=
{}
signature
=
inspect
.
signature
(
init
)
parameters
=
{
name
:
p
.
default
for
i
,
(
name
,
p
)
in
enumerate
(
signature
.
parameters
.
items
())
if
i
>
0
and
name
not
in
ignore
}
for
arg
,
name
in
zip
(
args
,
parameters
.
keys
()):
new_kwargs
[
name
]
=
arg
# Then add all kwargs
new_kwargs
.
update
(
{
k
:
init_kwargs
.
get
(
k
,
default
)
for
k
,
default
in
parameters
.
items
()
if
k
not
in
ignore
and
k
not
in
new_kwargs
}
)
getattr
(
self
,
"register_to_config"
)(
**
new_kwargs
)
return
inner_init
src/diffusers/models/unet_conditional.py
View file @
ad9d2525
...
...
@@ -3,7 +3,7 @@ from typing import Dict, Union
import
torch
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
TimestepEmbedding
,
Timesteps
from
.unet_blocks
import
UNetMidBlock2DCrossAttn
,
get_down_block
,
get_up_block
...
...
@@ -33,6 +33,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
increased efficiency.
"""
@
register_to_config
def
__init__
(
self
,
image_size
=
None
,
...
...
@@ -63,40 +64,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor
=
1
,
center_input_sample
=
False
,
resnet_num_groups
=
30
,
**
kwargs
,
):
super
().
__init__
()
# remove automatically added kwargs
for
arg
in
self
.
_automatically_saved_args
:
kwargs
.
pop
(
arg
,
None
)
if
len
(
kwargs
)
>
0
:
raise
ValueError
(
f
"The following keyword arguments do not exist for
{
self
.
__class__
}
:
{
','
.
join
(
kwargs
.
keys
())
}
"
)
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
self
.
register_to_config
(
image_size
=
image_size
,
in_channels
=
in_channels
,
block_channels
=
block_channels
,
downsample_padding
=
downsample_padding
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
down_blocks
=
down_blocks
,
up_blocks
=
up_blocks
,
dropout
=
dropout
,
resnet_eps
=
resnet_eps
,
conv_resample
=
conv_resample
,
num_head_channels
=
num_head_channels
,
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
downscale_freq_shift
,
mid_block_scale_factor
=
mid_block_scale_factor
,
resnet_num_groups
=
resnet_num_groups
,
center_input_sample
=
center_input_sample
,
)
self
.
image_size
=
image_size
time_embed_dim
=
block_channels
[
0
]
*
4
...
...
src/diffusers/models/unet_unconditional.py
View file @
ad9d2525
...
...
@@ -3,7 +3,7 @@ from typing import Dict, Union
import
torch
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
GaussianFourierProjection
,
TimestepEmbedding
,
Timesteps
from
.unet_blocks
import
UNetMidBlock2D
,
get_down_block
,
get_up_block
...
...
@@ -33,6 +33,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
increased efficiency.
"""
@
register_to_config
def
__init__
(
self
,
image_size
=
None
,
...
...
@@ -59,41 +60,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor
=
1
,
center_input_sample
=
False
,
resnet_num_groups
=
32
,
**
kwargs
,
):
super
().
__init__
()
# remove automatically added kwargs
for
arg
in
self
.
_automatically_saved_args
:
kwargs
.
pop
(
arg
,
None
)
if
len
(
kwargs
)
>
0
:
raise
ValueError
(
f
"The following keyword arguments do not exist for
{
self
.
__class__
}
:
{
','
.
join
(
kwargs
.
keys
())
}
"
)
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
self
.
register_to_config
(
image_size
=
image_size
,
in_channels
=
in_channels
,
block_channels
=
block_channels
,
downsample_padding
=
downsample_padding
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
down_blocks
=
down_blocks
,
up_blocks
=
up_blocks
,
dropout
=
dropout
,
resnet_eps
=
resnet_eps
,
conv_resample
=
conv_resample
,
num_head_channels
=
num_head_channels
,
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
downscale_freq_shift
,
time_embedding_type
=
time_embedding_type
,
mid_block_scale_factor
=
mid_block_scale_factor
,
resnet_num_groups
=
resnet_num_groups
,
center_input_sample
=
center_input_sample
,
)
self
.
image_size
=
image_size
time_embed_dim
=
block_channels
[
0
]
*
4
...
...
src/diffusers/models/vae.py
View file @
ad9d2525
...
...
@@ -2,7 +2,7 @@ import numpy as np
import
torch
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.resnet
import
Downsample2D
,
ResnetBlock2D
,
Upsample2D
...
...
@@ -380,6 +380,7 @@ class DiagonalGaussianDistribution(object):
class
VQModel
(
ModelMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
ch
,
...
...
@@ -399,27 +400,6 @@ class VQModel(ModelMixin, ConfigMixin):
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
...
...
@@ -478,6 +458,7 @@ class VQModel(ModelMixin, ConfigMixin):
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
ch
,
...
...
@@ -496,26 +477,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
ad9d2525
...
...
@@ -21,7 +21,7 @@ from typing import Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -49,6 +49,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class
DDIMScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
1000
,
...
...
@@ -60,16 +61,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample
=
True
,
tensor_format
=
"np"
,
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
trained_betas
=
trained_betas
,
timestep_values
=
timestep_values
,
clip_sample
=
clip_sample
,
)
if
beta_schedule
==
"linear"
:
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
np
.
float32
)
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
ad9d2525
...
...
@@ -20,7 +20,7 @@ from typing import Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class
DDPMScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
1000
,
...
...
@@ -60,17 +61,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
clip_sample
=
True
,
tensor_format
=
"np"
,
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
trained_betas
=
trained_betas
,
timestep_values
=
timestep_values
,
variance_type
=
variance_type
,
clip_sample
=
clip_sample
,
)
if
trained_betas
is
not
None
:
self
.
betas
=
np
.
asarray
(
trained_betas
)
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
ad9d2525
...
...
@@ -20,7 +20,7 @@ from typing import Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class
PNDMScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
1000
,
...
...
@@ -56,13 +57,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule
=
"linear"
,
tensor_format
=
"np"
,
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
)
if
beta_schedule
==
"linear"
:
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
np
.
float32
)
...
...
src/diffusers/schedulers/scheduling_sde_ve.py
View file @
ad9d2525
...
...
@@ -21,7 +21,7 @@ from typing import Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -37,6 +37,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
"np" or "pt" for the expected format of samples passed to the Scheduler.
"""
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
2000
,
...
...
@@ -47,15 +48,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
correct_steps
=
1
,
tensor_format
=
"pt"
,
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
snr
=
snr
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
,
sampling_eps
=
sampling_eps
,
correct_steps
=
correct_steps
,
)
# self.sigmas = None
# self.discrete_sigmas = None
#
...
...
src/diffusers/schedulers/scheduling_sde_vp.py
View file @
ad9d2525
...
...
@@ -19,19 +19,13 @@
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
class
ScoreSdeVpScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
2000
,
beta_min
=
0.1
,
beta_max
=
20
,
sampling_eps
=
1e-3
,
tensor_format
=
"np"
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_min
=
beta_min
,
beta_max
=
beta_max
,
sampling_eps
=
sampling_eps
,
)
self
.
sigmas
=
None
self
.
discrete_sigmas
=
None
...
...
src/diffusers/schedulers/scheduling_utils.py
View file @
ad9d2525
...
...
@@ -23,6 +23,7 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json"
class
SchedulerMixin
:
config_name
=
SCHEDULER_CONFIG_NAME
ignore_for_config
=
[
"tensor_format"
]
def
set_format
(
self
,
tensor_format
=
"pt"
):
self
.
tensor_format
=
tensor_format
...
...
tests/test_modeling_utils.py
View file @
ad9d2525
...
...
@@ -18,6 +18,7 @@ import inspect
import
math
import
tempfile
import
unittest
from
atexit
import
register
import
numpy
as
np
import
torch
...
...
@@ -38,7 +39,7 @@ from diffusers import (
UNetUnconditionalModel
,
VQModel
,
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
from
diffusers.training_utils
import
EMAModel
...
...
@@ -47,25 +48,63 @@ from diffusers.training_utils import EMAModel
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
SampleObject
(
ConfigMixin
):
config_name
=
"config.json"
@
register_to_config
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
):
pass
class
ConfigTester
(
unittest
.
TestCase
):
def
test_load_not_from_mixin
(
self
):
with
self
.
assertRaises
(
ValueError
):
ConfigMixin
.
from_config
(
"dummy_path"
)
def
test_save_load
(
self
):
class
SampleObject
(
ConfigMixin
):
config_name
=
"config.json"
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
):
self
.
register_to_config
(
a
=
a
,
b
=
b
,
c
=
c
,
d
=
d
,
e
=
e
)
def
test_register_to_config
(
self
):
obj
=
SampleObject
()
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
(
2
,
5
)
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
# init ignore private arguments
obj
=
SampleObject
(
_name_or_path
=
"lalala"
)
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
(
2
,
5
)
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
# can override default
obj
=
SampleObject
(
c
=
6
)
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
6
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
# can use positional arguments.
obj
=
SampleObject
(
1
,
c
=
6
)
config
=
obj
.
config
assert
config
[
"a"
]
==
1
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
6
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
def
test_save_load
(
self
):
obj
=
SampleObject
()
config
=
obj
.
config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment