Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
ad9d2525
Unverified
Commit
ad9d2525
authored
Jul 20, 2022
by
Sylvain Gugger
Committed by
GitHub
Jul 20, 2022
Browse files
Add a decorator for register_to_config (#108)
* Add a decorator for register_to_config * All models and test
parent
7e11392d
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
115 additions
and
174 deletions
+115
-174
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+44
-0
src/diffusers/models/unet_conditional.py
src/diffusers/models/unet_conditional.py
+2
-34
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+2
-35
src/diffusers/models/vae.py
src/diffusers/models/vae.py
+3
-42
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+2
-11
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+2
-12
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+2
-8
src/diffusers/schedulers/scheduling_sde_ve.py
src/diffusers/schedulers/scheduling_sde_ve.py
+2
-10
src/diffusers/schedulers/scheduling_sde_vp.py
src/diffusers/schedulers/scheduling_sde_vp.py
+2
-8
src/diffusers/schedulers/scheduling_utils.py
src/diffusers/schedulers/scheduling_utils.py
+1
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+53
-14
No files found.
src/diffusers/configuration_utils.py
View file @
ad9d2525
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
""" ConfigMixinuration base class and utilities."""
import
functools
import
inspect
import
inspect
import
json
import
json
import
os
import
os
...
@@ -295,3 +296,46 @@ class FrozenDict(OrderedDict):
...
@@ -295,3 +296,46 @@ class FrozenDict(OrderedDict):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
raise
Exception
(
f
"You cannot use ``__setattr__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
super
().
__setitem__
(
name
,
value
)
super
().
__setitem__
(
name
,
value
)
def
register_to_config
(
init
):
"""
Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically
sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be
registered in the config, use the `ignore_for_config` class variable
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
"""
@
functools
.
wraps
(
init
)
def
inner_init
(
self
,
*
args
,
**
kwargs
):
# Ignore private kwargs in the init.
init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
not
k
.
startswith
(
"_"
)}
init
(
self
,
*
args
,
**
init_kwargs
)
if
not
isinstance
(
self
,
ConfigMixin
):
raise
RuntimeError
(
f
"`@register_for_config` was applied to
{
self
.
__class__
.
__name__
}
init method, but this class does "
"not inherit from `ConfigMixin`."
)
ignore
=
getattr
(
self
,
"ignore_for_config"
,
[])
# Get positional arguments aligned with kwargs
new_kwargs
=
{}
signature
=
inspect
.
signature
(
init
)
parameters
=
{
name
:
p
.
default
for
i
,
(
name
,
p
)
in
enumerate
(
signature
.
parameters
.
items
())
if
i
>
0
and
name
not
in
ignore
}
for
arg
,
name
in
zip
(
args
,
parameters
.
keys
()):
new_kwargs
[
name
]
=
arg
# Then add all kwargs
new_kwargs
.
update
(
{
k
:
init_kwargs
.
get
(
k
,
default
)
for
k
,
default
in
parameters
.
items
()
if
k
not
in
ignore
and
k
not
in
new_kwargs
}
)
getattr
(
self
,
"register_to_config"
)(
**
new_kwargs
)
return
inner_init
src/diffusers/models/unet_conditional.py
View file @
ad9d2525
...
@@ -3,7 +3,7 @@ from typing import Dict, Union
...
@@ -3,7 +3,7 @@ from typing import Dict, Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
TimestepEmbedding
,
Timesteps
from
.embeddings
import
TimestepEmbedding
,
Timesteps
from
.unet_blocks
import
UNetMidBlock2DCrossAttn
,
get_down_block
,
get_up_block
from
.unet_blocks
import
UNetMidBlock2DCrossAttn
,
get_down_block
,
get_up_block
...
@@ -33,6 +33,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
...
@@ -33,6 +33,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
increased efficiency.
increased efficiency.
"""
"""
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
image_size
=
None
,
image_size
=
None
,
...
@@ -63,40 +64,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
...
@@ -63,40 +64,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor
=
1
,
mid_block_scale_factor
=
1
,
center_input_sample
=
False
,
center_input_sample
=
False
,
resnet_num_groups
=
30
,
resnet_num_groups
=
30
,
**
kwargs
,
):
):
super
().
__init__
()
# remove automatically added kwargs
for
arg
in
self
.
_automatically_saved_args
:
kwargs
.
pop
(
arg
,
None
)
if
len
(
kwargs
)
>
0
:
raise
ValueError
(
f
"The following keyword arguments do not exist for
{
self
.
__class__
}
:
{
','
.
join
(
kwargs
.
keys
())
}
"
)
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
self
.
register_to_config
(
image_size
=
image_size
,
in_channels
=
in_channels
,
block_channels
=
block_channels
,
downsample_padding
=
downsample_padding
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
down_blocks
=
down_blocks
,
up_blocks
=
up_blocks
,
dropout
=
dropout
,
resnet_eps
=
resnet_eps
,
conv_resample
=
conv_resample
,
num_head_channels
=
num_head_channels
,
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
downscale_freq_shift
,
mid_block_scale_factor
=
mid_block_scale_factor
,
resnet_num_groups
=
resnet_num_groups
,
center_input_sample
=
center_input_sample
,
)
self
.
image_size
=
image_size
self
.
image_size
=
image_size
time_embed_dim
=
block_channels
[
0
]
*
4
time_embed_dim
=
block_channels
[
0
]
*
4
...
...
src/diffusers/models/unet_unconditional.py
View file @
ad9d2525
...
@@ -3,7 +3,7 @@ from typing import Dict, Union
...
@@ -3,7 +3,7 @@ from typing import Dict, Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
GaussianFourierProjection
,
TimestepEmbedding
,
Timesteps
from
.embeddings
import
GaussianFourierProjection
,
TimestepEmbedding
,
Timesteps
from
.unet_blocks
import
UNetMidBlock2D
,
get_down_block
,
get_up_block
from
.unet_blocks
import
UNetMidBlock2D
,
get_down_block
,
get_up_block
...
@@ -33,6 +33,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -33,6 +33,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
increased efficiency.
increased efficiency.
"""
"""
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
image_size
=
None
,
image_size
=
None
,
...
@@ -59,41 +60,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -59,41 +60,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor
=
1
,
mid_block_scale_factor
=
1
,
center_input_sample
=
False
,
center_input_sample
=
False
,
resnet_num_groups
=
32
,
resnet_num_groups
=
32
,
**
kwargs
,
):
):
super
().
__init__
()
# remove automatically added kwargs
for
arg
in
self
.
_automatically_saved_args
:
kwargs
.
pop
(
arg
,
None
)
if
len
(
kwargs
)
>
0
:
raise
ValueError
(
f
"The following keyword arguments do not exist for
{
self
.
__class__
}
:
{
','
.
join
(
kwargs
.
keys
())
}
"
)
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
self
.
register_to_config
(
image_size
=
image_size
,
in_channels
=
in_channels
,
block_channels
=
block_channels
,
downsample_padding
=
downsample_padding
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
down_blocks
=
down_blocks
,
up_blocks
=
up_blocks
,
dropout
=
dropout
,
resnet_eps
=
resnet_eps
,
conv_resample
=
conv_resample
,
num_head_channels
=
num_head_channels
,
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
downscale_freq_shift
,
time_embedding_type
=
time_embedding_type
,
mid_block_scale_factor
=
mid_block_scale_factor
,
resnet_num_groups
=
resnet_num_groups
,
center_input_sample
=
center_input_sample
,
)
self
.
image_size
=
image_size
self
.
image_size
=
image_size
time_embed_dim
=
block_channels
[
0
]
*
4
time_embed_dim
=
block_channels
[
0
]
*
4
...
...
src/diffusers/models/vae.py
View file @
ad9d2525
...
@@ -2,7 +2,7 @@ import numpy as np
...
@@ -2,7 +2,7 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.resnet
import
Downsample2D
,
ResnetBlock2D
,
Upsample2D
from
.resnet
import
Downsample2D
,
ResnetBlock2D
,
Upsample2D
...
@@ -380,6 +380,7 @@ class DiagonalGaussianDistribution(object):
...
@@ -380,6 +380,7 @@ class DiagonalGaussianDistribution(object):
class
VQModel
(
ModelMixin
,
ConfigMixin
):
class
VQModel
(
ModelMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
ch
,
ch
,
...
@@ -399,27 +400,6 @@ class VQModel(ModelMixin, ConfigMixin):
...
@@ -399,27 +400,6 @@ class VQModel(ModelMixin, ConfigMixin):
resamp_with_conv
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
give_pre_end
=
False
,
):
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
# pass init params to Encoder
self
.
encoder
=
Encoder
(
self
.
encoder
=
Encoder
(
...
@@ -478,6 +458,7 @@ class VQModel(ModelMixin, ConfigMixin):
...
@@ -478,6 +458,7 @@ class VQModel(ModelMixin, ConfigMixin):
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
ch
,
ch
,
...
@@ -496,26 +477,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
...
@@ -496,26 +477,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
resamp_with_conv
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
give_pre_end
=
False
,
):
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
# pass init params to Encoder
self
.
encoder
=
Encoder
(
self
.
encoder
=
Encoder
(
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
ad9d2525
...
@@ -21,7 +21,7 @@ from typing import Union
...
@@ -21,7 +21,7 @@ from typing import Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -49,6 +49,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
...
@@ -49,6 +49,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class
DDIMScheduler
(
SchedulerMixin
,
ConfigMixin
):
class
DDIMScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
num_train_timesteps
=
1000
,
num_train_timesteps
=
1000
,
...
@@ -60,16 +61,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -60,16 +61,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample
=
True
,
clip_sample
=
True
,
tensor_format
=
"np"
,
tensor_format
=
"np"
,
):
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
trained_betas
=
trained_betas
,
timestep_values
=
timestep_values
,
clip_sample
=
clip_sample
,
)
if
beta_schedule
==
"linear"
:
if
beta_schedule
==
"linear"
:
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
np
.
float32
)
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
np
.
float32
)
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
ad9d2525
...
@@ -20,7 +20,7 @@ from typing import Union
...
@@ -20,7 +20,7 @@ from typing import Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
...
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class
DDPMScheduler
(
SchedulerMixin
,
ConfigMixin
):
class
DDPMScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
num_train_timesteps
=
1000
,
num_train_timesteps
=
1000
,
...
@@ -60,17 +61,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -60,17 +61,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
clip_sample
=
True
,
clip_sample
=
True
,
tensor_format
=
"np"
,
tensor_format
=
"np"
,
):
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
trained_betas
=
trained_betas
,
timestep_values
=
timestep_values
,
variance_type
=
variance_type
,
clip_sample
=
clip_sample
,
)
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
np
.
asarray
(
trained_betas
)
self
.
betas
=
np
.
asarray
(
trained_betas
)
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
ad9d2525
...
@@ -20,7 +20,7 @@ from typing import Union
...
@@ -20,7 +20,7 @@ from typing import Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
...
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class
PNDMScheduler
(
SchedulerMixin
,
ConfigMixin
):
class
PNDMScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
num_train_timesteps
=
1000
,
num_train_timesteps
=
1000
,
...
@@ -56,13 +57,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -56,13 +57,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule
=
"linear"
,
beta_schedule
=
"linear"
,
tensor_format
=
"np"
,
tensor_format
=
"np"
,
):
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
)
if
beta_schedule
==
"linear"
:
if
beta_schedule
==
"linear"
:
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
np
.
float32
)
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
np
.
float32
)
...
...
src/diffusers/schedulers/scheduling_sde_ve.py
View file @
ad9d2525
...
@@ -21,7 +21,7 @@ from typing import Union
...
@@ -21,7 +21,7 @@ from typing import Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -37,6 +37,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -37,6 +37,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
"np" or "pt" for the expected format of samples passed to the Scheduler.
"np" or "pt" for the expected format of samples passed to the Scheduler.
"""
"""
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
num_train_timesteps
=
2000
,
num_train_timesteps
=
2000
,
...
@@ -47,15 +48,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -47,15 +48,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
correct_steps
=
1
,
correct_steps
=
1
,
tensor_format
=
"pt"
,
tensor_format
=
"pt"
,
):
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
snr
=
snr
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
,
sampling_eps
=
sampling_eps
,
correct_steps
=
correct_steps
,
)
# self.sigmas = None
# self.sigmas = None
# self.discrete_sigmas = None
# self.discrete_sigmas = None
#
#
...
...
src/diffusers/schedulers/scheduling_sde_vp.py
View file @
ad9d2525
...
@@ -19,19 +19,13 @@
...
@@ -19,19 +19,13 @@
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
class
ScoreSdeVpScheduler
(
SchedulerMixin
,
ConfigMixin
):
class
ScoreSdeVpScheduler
(
SchedulerMixin
,
ConfigMixin
):
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
2000
,
beta_min
=
0.1
,
beta_max
=
20
,
sampling_eps
=
1e-3
,
tensor_format
=
"np"
):
def
__init__
(
self
,
num_train_timesteps
=
2000
,
beta_min
=
0.1
,
beta_max
=
20
,
sampling_eps
=
1e-3
,
tensor_format
=
"np"
):
super
().
__init__
()
self
.
register_to_config
(
num_train_timesteps
=
num_train_timesteps
,
beta_min
=
beta_min
,
beta_max
=
beta_max
,
sampling_eps
=
sampling_eps
,
)
self
.
sigmas
=
None
self
.
sigmas
=
None
self
.
discrete_sigmas
=
None
self
.
discrete_sigmas
=
None
...
...
src/diffusers/schedulers/scheduling_utils.py
View file @
ad9d2525
...
@@ -23,6 +23,7 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json"
...
@@ -23,6 +23,7 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json"
class
SchedulerMixin
:
class
SchedulerMixin
:
config_name
=
SCHEDULER_CONFIG_NAME
config_name
=
SCHEDULER_CONFIG_NAME
ignore_for_config
=
[
"tensor_format"
]
def
set_format
(
self
,
tensor_format
=
"pt"
):
def
set_format
(
self
,
tensor_format
=
"pt"
):
self
.
tensor_format
=
tensor_format
self
.
tensor_format
=
tensor_format
...
...
tests/test_modeling_utils.py
View file @
ad9d2525
...
@@ -18,6 +18,7 @@ import inspect
...
@@ -18,6 +18,7 @@ import inspect
import
math
import
math
import
tempfile
import
tempfile
import
unittest
import
unittest
from
atexit
import
register
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -38,7 +39,7 @@ from diffusers import (
...
@@ -38,7 +39,7 @@ from diffusers import (
UNetUnconditionalModel
,
UNetUnconditionalModel
,
VQModel
,
VQModel
,
)
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
...
@@ -47,15 +48,10 @@ from diffusers.training_utils import EMAModel
...
@@ -47,15 +48,10 @@ from diffusers.training_utils import EMAModel
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
ConfigTester
(
unittest
.
TestCase
):
class
SampleObject
(
ConfigMixin
):
def
test_load_not_from_mixin
(
self
):
with
self
.
assertRaises
(
ValueError
):
ConfigMixin
.
from_config
(
"dummy_path"
)
def
test_save_load
(
self
):
class
SampleObject
(
ConfigMixin
):
config_name
=
"config.json"
config_name
=
"config.json"
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
a
=
2
,
a
=
2
,
...
@@ -64,8 +60,51 @@ class ConfigTester(unittest.TestCase):
...
@@ -64,8 +60,51 @@ class ConfigTester(unittest.TestCase):
d
=
"for diffusion"
,
d
=
"for diffusion"
,
e
=
[
1
,
3
],
e
=
[
1
,
3
],
):
):
self
.
register_to_config
(
a
=
a
,
b
=
b
,
c
=
c
,
d
=
d
,
e
=
e
)
pass
class
ConfigTester
(
unittest
.
TestCase
):
def
test_load_not_from_mixin
(
self
):
with
self
.
assertRaises
(
ValueError
):
ConfigMixin
.
from_config
(
"dummy_path"
)
def
test_register_to_config
(
self
):
obj
=
SampleObject
()
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
(
2
,
5
)
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
# init ignore private arguments
obj
=
SampleObject
(
_name_or_path
=
"lalala"
)
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
(
2
,
5
)
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
# can override default
obj
=
SampleObject
(
c
=
6
)
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
6
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
# can use positional arguments.
obj
=
SampleObject
(
1
,
c
=
6
)
config
=
obj
.
config
assert
config
[
"a"
]
==
1
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
6
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
def
test_save_load
(
self
):
obj
=
SampleObject
()
obj
=
SampleObject
()
config
=
obj
.
config
config
=
obj
.
config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment