Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
ad9d2525
Unverified
Commit
ad9d2525
authored
Jul 20, 2022
by
Sylvain Gugger
Committed by
GitHub
Jul 20, 2022
Browse files
Add a decorator for register_to_config (#108)
* Add a decorator for register_to_config * All models and test
parent
7e11392d
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
115 additions
and
174 deletions
+115
-174
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+44
-0
src/diffusers/models/unet_conditional.py
src/diffusers/models/unet_conditional.py
+2
-34
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+2
-35
src/diffusers/models/vae.py
src/diffusers/models/vae.py
+3
-42
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+2
-11
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+2
-12
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+2
-8
src/diffusers/schedulers/scheduling_sde_ve.py
src/diffusers/schedulers/scheduling_sde_ve.py
+2
-10
src/diffusers/schedulers/scheduling_sde_vp.py
src/diffusers/schedulers/scheduling_sde_vp.py
+2
-8
src/diffusers/schedulers/scheduling_utils.py
src/diffusers/schedulers/scheduling_utils.py
+1
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+53
-14
No files found.
src/diffusers/configuration_utils.py
View file @
ad9d2525
...
...
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
import
functools
import
inspect
import
json
import
os
...
...
@@ -295,3 +296,46 @@ class FrozenDict(OrderedDict):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
super
().
__setitem__
(
name
,
value
)
def
register_to_config
(
init
):
"""
Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically
sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be
registered in the config, use the `ignore_for_config` class variable
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
"""
@
functools
.
wraps
(
init
)
def
inner_init
(
self
,
*
args
,
**
kwargs
):
# Ignore private kwargs in the init.
init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
not
k
.
startswith
(
"_"
)}
init
(
self
,
*
args
,
**
init_kwargs
)
if
not
isinstance
(
self
,
ConfigMixin
):
raise
RuntimeError
(
f
"`@register_for_config` was applied to
{
self
.
__class__
.
__name__
}
init method, but this class does "
"not inherit from `ConfigMixin`."
)
ignore
=
getattr
(
self
,
"ignore_for_config"
,
[])
# Get positional arguments aligned with kwargs
new_kwargs
=
{}
signature
=
inspect
.
signature
(
init
)
parameters
=
{
name
:
p
.
default
for
i
,
(
name
,
p
)
in
enumerate
(
signature
.
parameters
.
items
())
if
i
>
0
and
name
not
in
ignore
}
for
arg
,
name
in
zip
(
args
,
parameters
.
keys
()):
new_kwargs
[
name
]
=
arg
# Then add all kwargs
new_kwargs
.
update
(
{
k
:
init_kwargs
.
get
(
k
,
default
)
for
k
,
default
in
parameters
.
items
()
if
k
not
in
ignore
and
k
not
in
new_kwargs
}
)
getattr
(
self
,
"register_to_config"
)(
**
new_kwargs
)
return
inner_init
src/diffusers/models/unet_conditional.py
View file @
ad9d2525
...
...
@@ -3,7 +3,7 @@ from typing import Dict, Union
import
torch
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
TimestepEmbedding
,
Timesteps
from
.unet_blocks
import
UNetMidBlock2DCrossAttn
,
get_down_block
,
get_up_block
...
...
@@ -33,6 +33,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
increased efficiency.
"""
@
register_to_config
def
__init__
(
self
,
image_size
=
None
,
...
...
@@ -63,40 +64,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor
=
1
,
center_input_sample
=
False
,
resnet_num_groups
=
30
,
**
kwargs
,
):
super
().
__init__
()
# remove automatically added kwargs
for
arg
in
self
.
_automatically_saved_args
:
kwargs
.
pop
(
arg
,
None
)
if
len
(
kwargs
)
>
0
:
raise
ValueError
(
f
"The following keyword arguments do not exist for
{
self
.
__class__
}
:
{
','
.
join
(
kwargs
.
keys
())
}
"
)
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
self
.
register_to_config
(
image_size
=
image_size
,
in_channels
=
in_channels
,
block_channels
=
block_channels
,
downsample_padding
=
downsample_padding
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
down_blocks
=
down_blocks
,
up_blocks
=
up_blocks
,
dropout
=
dropout
,
resnet_eps
=
resnet_eps
,
conv_resample
=
conv_resample
,
num_head_channels
=
num_head_channels
,
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
downscale_freq_shift
,
mid_block_scale_factor
=
mid_block_scale_factor
,
resnet_num_groups
=
resnet_num_groups
,
center_input_sample
=
center_input_sample
,
)
self
.
image_size
=
image_size
time_embed_dim
=
block_channels
[
0
]
*
4
...
...
src/diffusers/models/unet_unconditional.py
View file @
ad9d2525
...
...
@@ -3,7 +3,7 @@ from typing import Dict, Union
import
torch
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
GaussianFourierProjection
,
TimestepEmbedding
,
Timesteps
from
.unet_blocks
import
UNetMidBlock2D
,
get_down_block
,
get_up_block
...
...
@@ -33,6 +33,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
increased efficiency.
"""
@
register_to_config
def
__init__
(
self
,
image_size
=
None
,
...
...
@@ -59,41 +60,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor
=
1
,
center_input_sample
=
False
,
resnet_num_groups
=
32
,
**
kwargs
,
):
super
().
__init__
()
# remove automatically added kwargs
for
arg
in
self
.
_automatically_saved_args
:
kwargs
.
pop
(
arg
,
None
)
if
len
(
kwargs
)
>
0
:
raise
ValueError
(
f
"The following keyword arguments do not exist for
{
self
.
__class__
}
:
{
','
.
join
(
kwargs
.
keys
())
}
"
)
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
self
.
register_to_config
(
image_size
=
image_size
,
in_channels
=
in_channels
,
block_channels
=
block_channels
,
downsample_padding
=
downsample_padding
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
down_blocks
=
down_blocks
,
up_blocks
=
up_blocks
,
dropout
=
dropout
,
resnet_eps
=
resnet_eps
,
conv_resample
=
conv_resample
,
num_head_channels
=
num_head_channels
,
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
downscale_freq_shift
,
time_embedding_type
=
time_embedding_type
,
mid_block_scale_factor
=
mid_block_scale_factor
,
resnet_num_groups
=
resnet_num_groups
,
center_input_sample
=
center_input_sample
,
)
self
.
image_size
=
image_size
time_embed_dim
=
block_channels
[
0
]
*
4
...
...
src/diffusers/models/vae.py
View file @
ad9d2525
...
...
@@ -2,7 +2,7 @@ import numpy as np
import
torch
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.resnet
import
Downsample2D
,
ResnetBlock2D
,
Upsample2D
...
...
@@ -380,6 +380,7 @@ class DiagonalGaussianDistribution(object):
class
VQModel
(
ModelMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
ch
,
...
...
@@ -399,27 +400,6 @@ class VQModel(ModelMixin, ConfigMixin):
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
...
...
@@ -478,6 +458,7 @@ class VQModel(ModelMixin, ConfigMixin):
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
ch
,
...
...
@@ -496,26 +477,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
ad9d2525
...
...
@@ -21,7 +21,7 @@ from typing import Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -49,6 +49,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class
DDIMScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
1000
,
...
...
@@ -60,16 +61,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample
=
True
,
tensor_format
=
"np"
,
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
trained_betas
=
trained_betas
,
timestep_values
=
timestep_values
,
clip_sample
=
clip_sample
,
)
if
beta_schedule
==
"linear"
:
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
np
.
float32
)
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
ad9d2525
...
...
@@ -20,7 +20,7 @@ from typing import Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class
DDPMScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
1000
,
...
...
@@ -60,17 +61,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
clip_sample
=
True
,
tensor_format
=
"np"
,
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
trained_betas
=
trained_betas
,
timestep_values
=
timestep_values
,
variance_type
=
variance_type
,
clip_sample
=
clip_sample
,
)
if
trained_betas
is
not
None
:
self
.
betas
=
np
.
asarray
(
trained_betas
)
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
ad9d2525
...
...
@@ -20,7 +20,7 @@ from typing import Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class
PNDMScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
1000
,
...
...
@@ -56,13 +57,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule
=
"linear"
,
tensor_format
=
"np"
,
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
)
if
beta_schedule
==
"linear"
:
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
np
.
float32
)
...
...
src/diffusers/schedulers/scheduling_sde_ve.py
View file @
ad9d2525
...
...
@@ -21,7 +21,7 @@ from typing import Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -37,6 +37,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
"np" or "pt" for the expected format of samples passed to the Scheduler.
"""
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
2000
,
...
...
@@ -47,15 +48,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
correct_steps
=
1
,
tensor_format
=
"pt"
,
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
snr
=
snr
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
,
sampling_eps
=
sampling_eps
,
correct_steps
=
correct_steps
,
)
# self.sigmas = None
# self.discrete_sigmas = None
#
...
...
src/diffusers/schedulers/scheduling_sde_vp.py
View file @
ad9d2525
...
...
@@ -19,19 +19,13 @@
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
class
ScoreSdeVpScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
2000
,
beta_min
=
0.1
,
beta_max
=
20
,
sampling_eps
=
1e-3
,
tensor_format
=
"np"
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_min
=
beta_min
,
beta_max
=
beta_max
,
sampling_eps
=
sampling_eps
,
)
self
.
sigmas
=
None
self
.
discrete_sigmas
=
None
...
...
src/diffusers/schedulers/scheduling_utils.py
View file @
ad9d2525
...
...
@@ -23,6 +23,7 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json"
class
SchedulerMixin
:
config_name
=
SCHEDULER_CONFIG_NAME
ignore_for_config
=
[
"tensor_format"
]
def
set_format
(
self
,
tensor_format
=
"pt"
):
self
.
tensor_format
=
tensor_format
...
...
tests/test_modeling_utils.py
View file @
ad9d2525
...
...
@@ -18,6 +18,7 @@ import inspect
import
math
import
tempfile
import
unittest
from
atexit
import
register
import
numpy
as
np
import
torch
...
...
@@ -38,7 +39,7 @@ from diffusers import (
UNetUnconditionalModel
,
VQModel
,
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
from
diffusers.training_utils
import
EMAModel
...
...
@@ -47,25 +48,63 @@ from diffusers.training_utils import EMAModel
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
SampleObject
(
ConfigMixin
):
config_name
=
"config.json"
@
register_to_config
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
):
pass
class
ConfigTester
(
unittest
.
TestCase
):
def
test_load_not_from_mixin
(
self
):
with
self
.
assertRaises
(
ValueError
):
ConfigMixin
.
from_config
(
"dummy_path"
)
def
test_save_load
(
self
):
class
SampleObject
(
ConfigMixin
):
config_name
=
"config.json"
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
):
self
.
register_to_config
(
a
=
a
,
b
=
b
,
c
=
c
,
d
=
d
,
e
=
e
)
def
test_register_to_config
(
self
):
obj
=
SampleObject
()
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
(
2
,
5
)
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
# init ignore private arguments
obj
=
SampleObject
(
_name_or_path
=
"lalala"
)
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
(
2
,
5
)
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
# can override default
obj
=
SampleObject
(
c
=
6
)
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
6
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
# can use positional arguments.
obj
=
SampleObject
(
1
,
c
=
6
)
config
=
obj
.
config
assert
config
[
"a"
]
==
1
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
6
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
def
test_save_load
(
self
):
obj
=
SampleObject
()
config
=
obj
.
config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment