renzhc / diffusers_dcu

Commit 554b374d, authored Nov 15, 2022 by Patrick von Platen

Merge branch 'main' of https://github.com/huggingface/diffusers into main

Parents: d5ab55e4, a0520193
Changes: 76
Showing 20 changed files with 530 additions and 70 deletions (+530 -70).
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py    +4    -11
src/diffusers/schedulers/scheduling_euler_discrete.py              +4    -11
src/diffusers/schedulers/scheduling_ipndm.py                       +2    -2
src/diffusers/schedulers/scheduling_karras_ve.py                   +2    -2
src/diffusers/schedulers/scheduling_karras_ve_flax.py              +2    -2
src/diffusers/schedulers/scheduling_lms_discrete.py                +4    -11
src/diffusers/schedulers/scheduling_lms_discrete_flax.py           +10   -3
src/diffusers/schedulers/scheduling_pndm.py                        +4    -10
src/diffusers/schedulers/scheduling_pndm_flax.py                   +10   -3
src/diffusers/schedulers/scheduling_repaint.py                     +2    -2
src/diffusers/schedulers/scheduling_sde_ve.py                      +2    -2
src/diffusers/schedulers/scheduling_sde_ve_flax.py                 +2    -2
src/diffusers/schedulers/scheduling_sde_vp.py                      +2    -2
src/diffusers/schedulers/scheduling_utils.py                       +111  -0
src/diffusers/schedulers/scheduling_utils_flax.py                  +119  -2
src/diffusers/schedulers/scheduling_vq_diffusion.py                +2    -2
src/diffusers/utils/__init__.py                                    +10   -0
tests/models/test_models_unet_1d.py                                +233  -2
tests/pipelines/dance_diffusion/test_dance_diffusion.py            +4    -0
tests/pipelines/ddim/test_ddim.py                                  +1    -1
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

@@ -19,7 +19,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
 from .scheduling_utils import SchedulerMixin

@@ -52,8 +52,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     Args:
         num_train_timesteps (`int`): number of diffusion steps used to train the model.

@@ -67,14 +67,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """

-    _compatible_classes = [
-        "DDIMScheduler",
-        "DDPMScheduler",
-        "LMSDiscreteScheduler",
-        "PNDMScheduler",
-        "EulerDiscreteScheduler",
-        "DPMSolverMultistepScheduler",
-    ]
+    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

     @register_to_config
     def __init__(
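The `_compatibles` list above is what allows one scheduler to be instantiated from another compatible scheduler's saved configuration; the same refactor repeats in several of the files below. A minimal sketch of the user-facing pattern, assuming a Stable Diffusion checkpoint such as `runwayml/stable-diffusion-v1-5` is reachable (the checkpoint name and pipeline class are illustrative, not part of this commit):

```python
# Sketch: swap a pipeline's scheduler for any class listed in its `_compatibles`.
from diffusers import EulerAncestralDiscreteScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Compatible schedulers share the same config keys, so the saved config of one
# scheduler can be used to build another.
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
```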
src/diffusers/schedulers/scheduling_euler_discrete.py

@@ -19,7 +19,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
 from .scheduling_utils import SchedulerMixin

@@ -53,8 +53,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     Args:
         num_train_timesteps (`int`): number of diffusion steps used to train the model.

@@ -68,14 +68,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """

-    _compatible_classes = [
-        "DDIMScheduler",
-        "DDPMScheduler",
-        "LMSDiscreteScheduler",
-        "PNDMScheduler",
-        "EulerAncestralDiscreteScheduler",
-        "DPMSolverMultistepScheduler",
-    ]
+    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

     @register_to_config
     def __init__(
src/diffusers/schedulers/scheduling_ipndm.py

@@ -28,8 +28,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     For more details, see the original paper: https://arxiv.org/abs/2202.09778
src/diffusers/schedulers/scheduling_karras_ve.py

@@ -56,8 +56,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
     Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
src/diffusers/schedulers/scheduling_karras_ve_flax.py

@@ -67,8 +67,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
     Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
src/diffusers/schedulers/scheduling_lms_discrete.py

@@ -21,7 +21,7 @@ import torch
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
 from .scheduling_utils import SchedulerMixin

@@ -52,8 +52,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     Args:
         num_train_timesteps (`int`): number of diffusion steps used to train the model.

@@ -67,14 +67,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """

-    _compatible_classes = [
-        "DDIMScheduler",
-        "DDPMScheduler",
-        "PNDMScheduler",
-        "EulerDiscreteScheduler",
-        "EulerAncestralDiscreteScheduler",
-        "DPMSolverMultistepScheduler",
-    ]
+    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

     @register_to_config
     def __init__(
src/diffusers/schedulers/scheduling_lms_discrete_flax.py

@@ -20,7 +20,12 @@ import jax.numpy as jnp
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    FlaxSchedulerMixin,
+    FlaxSchedulerOutput,
+    broadcast_to_shape_from_left,
+)


 @flax.struct.dataclass

@@ -49,8 +54,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     Args:
         num_train_timesteps (`int`): number of diffusion steps used to train the model.

@@ -63,6 +68,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
     """

+    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
     @property
     def has_state(self):
         return True
src/diffusers/schedulers/scheduling_pndm.py

@@ -21,6 +21,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
 from .scheduling_utils import SchedulerMixin, SchedulerOutput

@@ -60,8 +61,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     For more details, see the original paper: https://arxiv.org/abs/2202.09778

@@ -88,14 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     """

-    _compatible_classes = [
-        "DDIMScheduler",
-        "DDPMScheduler",
-        "LMSDiscreteScheduler",
-        "EulerDiscreteScheduler",
-        "EulerAncestralDiscreteScheduler",
-        "DPMSolverMultistepScheduler",
-    ]
+    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

     @register_to_config
     def __init__(
src/diffusers/schedulers/scheduling_pndm_flax.py

@@ -23,7 +23,12 @@ import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    FlaxSchedulerMixin,
+    FlaxSchedulerOutput,
+    broadcast_to_shape_from_left,
+)


 def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:

@@ -87,8 +92,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     For more details, see the original paper: https://arxiv.org/abs/2202.09778

@@ -114,6 +119,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
         stable diffusion.
     """

+    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
     @property
     def has_state(self):
         return True
src/diffusers/schedulers/scheduling_repaint.py

@@ -77,8 +77,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf
src/diffusers/schedulers/scheduling_sde_ve.py

@@ -50,8 +50,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     Args:
         num_train_timesteps (`int`): number of diffusion steps used to train the model.
src/diffusers/schedulers/scheduling_sde_ve_flax.py

@@ -64,8 +64,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     Args:
         num_train_timesteps (`int`): number of diffusion steps used to train the model.
src/diffusers/schedulers/scheduling_sde_vp.py

@@ -29,8 +29,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     For more information, see the original paper: https://arxiv.org/abs/2011.13456
src/diffusers/schedulers/scheduling_utils.py

@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
+import os
 from dataclasses import dataclass
+from typing import Any, Dict, Optional, Union

 import torch

@@ -38,6 +41,114 @@ class SchedulerOutput(BaseOutput):
 class SchedulerMixin:
     """
     Mixin containing common functions for the schedulers.
+
+    Class attributes:
+        - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+          `from_config` can be used from a class different than the one used to save the config (should be overridden
+          by parent class).
     """

     config_name = SCHEDULER_CONFIG_NAME
+    _compatibles = []
+    has_compatibles = True
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Dict[str, Any] = None,
+        subfolder: Optional[str] = None,
+        return_unused_kwargs=False,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+                  organization name, like `google/ddpm-celebahq-256`.
+                - A path to a *directory* containing the schedluer configurations saved using
+                  [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
+            subfolder (`str`, *optional*):
+                In case the relevant files are located inside a subfolder of the model repo (either remote in
+                huggingface.co or downloaded locally), you can specify the folder name here.
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                Whether kwargs that are not consumed by the Python class should be returned or not.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `transformers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+
+        <Tip>
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
+        </Tip>
+
+        """
+        config, kwargs = cls.load_config(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder,
+            return_unused_kwargs=True,
+            **kwargs,
+        )
+        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
+
+    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+        """
+        Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
+        [`~SchedulerMixin.from_pretrained`] class method.
+
+        Args:
+            save_directory (`str` or `os.PathLike`):
+                Directory where the configuration JSON file will be saved (will be created if it does not exist).
+        """
+        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+    @property
+    def compatibles(self):
+        """
+        Returns all schedulers that are compatible with this scheduler
+
+        Returns:
+            `List[SchedulerMixin]`: List of compatible schedulers
+        """
+        return self._get_compatibles()
+
+    @classmethod
+    def _get_compatibles(cls):
+        compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
+        diffusers_library = importlib.import_module(__name__.split(".")[0])
+        compatible_classes = [
+            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
+        ]
+        return compatible_classes
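To illustrate the new `SchedulerMixin` entry points end to end, here is a small round-trip sketch that stays on the local filesystem, so no Hub access is needed; `DDPMScheduler` stands in for any scheduler class and is an assumption of the example, not part of this diff:

```python
# Sketch: save a scheduler config with the new save_pretrained() and reload it with
# from_pretrained(); both ultimately go through ConfigMixin's save_config/load_config.
import tempfile

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

with tempfile.TemporaryDirectory() as tmpdir:
    scheduler.save_pretrained(tmpdir)             # writes scheduler_config.json into tmpdir
    reloaded = DDPMScheduler.from_pretrained(tmpdir)

assert reloaded.config.num_train_timesteps == 1000
```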
src/diffusers/schedulers/scheduling_utils_flax.py

@@ -11,15 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
+import os
 from dataclasses import dataclass
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple, Union

 import jax.numpy as jnp

-from ..utils import BaseOutput
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput


 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS]


 @dataclass

@@ -39,9 +42,123 @@ class FlaxSchedulerOutput(BaseOutput):
 class FlaxSchedulerMixin:
     """
     Mixin containing common functions for the schedulers.
+
+    Class attributes:
+        - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+          `from_config` can be used from a class different than the one used to save the config (should be overridden
+          by parent class).
     """

     config_name = SCHEDULER_CONFIG_NAME
+    _compatibles = []
+    has_compatibles = True
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Dict[str, Any] = None,
+        subfolder: Optional[str] = None,
+        return_unused_kwargs=False,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a Scheduler class from a pre-defined JSON-file.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+                  organization name, like `google/ddpm-celebahq-256`.
+                - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`],
+                  e.g., `./my_model_directory/`.
+            subfolder (`str`, *optional*):
+                In case the relevant files are located inside a subfolder of the model repo (either remote in
+                huggingface.co or downloaded locally), you can specify the folder name here.
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                Whether kwargs that are not consumed by the Python class should be returned or not.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `transformers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+
+        <Tip>
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
+        </Tip>
+
+        """
+        config, kwargs = cls.load_config(
+            pretrained_model_name_or_path=pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+        )
+        scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)
+
+        if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
+            state = scheduler.create_state()
+
+        if return_unused_kwargs:
+            return scheduler, state, unused_kwargs
+
+        return scheduler, state
+
+    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+        """
+        Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
+        [`~FlaxSchedulerMixin.from_pretrained`] class method.
+
+        Args:
+            save_directory (`str` or `os.PathLike`):
+                Directory where the configuration JSON file will be saved (will be created if it does not exist).
+        """
+        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+    @property
+    def compatibles(self):
+        """
+        Returns all schedulers that are compatible with this scheduler
+
+        Returns:
+            `List[SchedulerMixin]`: List of compatible schedulers
+        """
+        return self._get_compatibles()
+
+    @classmethod
+    def _get_compatibles(cls):
+        compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
+        diffusers_library = importlib.import_module(__name__.split(".")[0])
+        compatible_classes = [
+            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
+        ]
+        return compatible_classes


 def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
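The Flax-side constant is derived mechanically from the torch-side list defined in `src/diffusers/utils/__init__.py` (shown further down). A quick sketch of what the comprehension evaluates to, with the list truncated for brevity:

```python
# Sketch of the naming convention behind _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["DDIMScheduler", "DDPMScheduler", "PNDMScheduler"]  # truncated

_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
    "Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
]
print(_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS)
# ['FlaxDDIMScheduler', 'FlaxDDPMScheduler', 'FlaxPNDMScheduler']
```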
src/diffusers/schedulers/scheduling_vq_diffusion.py

@@ -112,8 +112,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.

     For more details, see the original paper: https://arxiv.org/abs/2111.14822
src/diffusers/utils/__init__.py

@@ -72,3 +72,13 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
 DIFFUSERS_CACHE = default_cache_path
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+
+_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
+    "DDIMScheduler",
+    "DDPMScheduler",
+    "PNDMScheduler",
+    "LMSDiscreteScheduler",
+    "EulerDiscreteScheduler",
+    "EulerAncestralDiscreteScheduler",
+    "DPMSolverMultistepScheduler",
+]
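Each scheduler copies this shared list (`_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()`) rather than aliasing it, so mutating one class's `_compatibles` cannot affect the others. The new `compatibles` property then resolves the names to classes at runtime; a hedged usage sketch, assuming a diffusers build that includes this commit:

```python
# Sketch: inspect which scheduler classes a given scheduler declares as compatible.
from diffusers import PNDMScheduler

scheduler = PNDMScheduler()
print(sorted(cls.__name__ for cls in scheduler.compatibles))
# Expected to list the Stable Diffusion schedulers above (plus PNDMScheduler itself),
# skipping any name the installed diffusers version does not export.
```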
tests/models/test_models_unet_1d.py

@@ -18,13 +18,120 @@ import unittest
 import torch

 from diffusers import UNet1DModel
-from diffusers.utils import slow, torch_device
+from diffusers.utils import floats_tensor, slow, torch_device
+
+from ..test_modeling_common import ModelTesterMixin


 torch.backends.cuda.matmul.allow_tf32 = False


-class UnetModel1DTests(unittest.TestCase):
+class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
+    model_class = UNet1DModel
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_features = 14
+        seq_len = 16
+
+        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
+        time_step = torch.tensor([10] * batch_size).to(torch_device)
+
+        return {"sample": noise, "timestep": time_step}
+
+    @property
+    def input_shape(self):
+        return (4, 14, 16)
+
+    @property
+    def output_shape(self):
+        return (4, 14, 16)
+
+    def test_ema_training(self):
+        pass
+
+    def test_training(self):
+        pass
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_determinism(self):
+        super().test_determinism()
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_outputs_equivalence(self):
+        super().test_outputs_equivalence()
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_from_pretrained_save_pretrained(self):
+        super().test_from_pretrained_save_pretrained()
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_model_from_pretrained(self):
+        super().test_model_from_pretrained()
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_output(self):
+        super().test_output()
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "block_out_channels": (32, 64, 128, 256),
+            "in_channels": 14,
+            "out_channels": 14,
+            "time_embedding_type": "positional",
+            "use_timestep_embedding": True,
+            "flip_sin_to_cos": False,
+            "freq_shift": 1.0,
+            "out_block_type": "OutConv1DBlock",
+            "mid_block_type": "MidResTemporalBlock1D",
+            "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
+            "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
+            "act_fn": "mish",
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_from_pretrained_hub(self):
+        model, loading_info = UNet1DModel.from_pretrained(
+            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
+        )
+        self.assertIsNotNone(model)
+        self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+        model.to(torch_device)
+        image = model(**self.dummy_input)
+
+        assert image is not None, "Make sure output is not None"
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_output_pretrained(self):
+        model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        num_features = model.in_channels
+        seq_len = 16
+        noise = torch.randn((1, seq_len, num_features)).permute(0, 2, 1)
+        # match original, we can update values and remove
+        time_step = torch.full((num_features,), 0)
+
+        with torch.no_grad():
+            output = model(noise, time_step).sample.permute(0, 2, 1)
+
+        output_slice = output[0, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
+
+    def test_forward_with_norm_groups(self):
+        # Not implemented yet for this UNet
+        pass

     @slow
     def test_unet_1d_maestro(self):
         model_id = "harmonai/maestro-150k"

@@ -43,3 +150,127 @@ class UnetModel1DTests(unittest.TestCase):
         assert (output_sum - 224.0896).abs() < 4e-2
         assert (output_max - 0.0607).abs() < 4e-4
+
+
+class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
+    model_class = UNet1DModel
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_features = 14
+        seq_len = 16
+
+        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
+        time_step = torch.tensor([10] * batch_size).to(torch_device)
+
+        return {"sample": noise, "timestep": time_step}
+
+    @property
+    def input_shape(self):
+        return (4, 14, 16)
+
+    @property
+    def output_shape(self):
+        return (4, 14, 1)
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_determinism(self):
+        super().test_determinism()
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_outputs_equivalence(self):
+        super().test_outputs_equivalence()
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_from_pretrained_save_pretrained(self):
+        super().test_from_pretrained_save_pretrained()
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_model_from_pretrained(self):
+        super().test_model_from_pretrained()
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_output(self):
+        # UNetRL is a value-function is different output shape
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_ema_training(self):
+        pass
+
+    def test_training(self):
+        pass
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "in_channels": 14,
+            "out_channels": 14,
+            "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
+            "up_block_types": [],
+            "out_block_type": "ValueFunction",
+            "mid_block_type": "ValueFunctionMidBlock1D",
+            "block_out_channels": [32, 64, 128, 256],
+            "layers_per_block": 1,
+            "downsample_each_block": True,
+            "use_timestep_embedding": True,
+            "freq_shift": 1.0,
+            "flip_sin_to_cos": False,
+            "time_embedding_type": "positional",
+            "act_fn": "mish",
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_from_pretrained_hub(self):
+        value_function, vf_loading_info = UNet1DModel.from_pretrained(
+            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
+        )
+        self.assertIsNotNone(value_function)
+        self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
+
+        value_function.to(torch_device)
+        image = value_function(**self.dummy_input)
+
+        assert image is not None, "Make sure output is not None"
+
+    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+    def test_output_pretrained(self):
+        value_function, vf_loading_info = UNet1DModel.from_pretrained(
+            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
+        )
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        num_features = value_function.in_channels
+        seq_len = 14
+        noise = torch.randn((1, seq_len, num_features)).permute(0, 2, 1)
+        # match original, we can update values and remove
+        time_step = torch.full((num_features,), 0)
+
+        with torch.no_grad():
+            output = value_function(noise, time_step).sample
+
+        # fmt: off
+        expected_output_slice = torch.tensor([165.25] * seq_len)
+        # fmt: on
+        self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
+
+    def test_forward_with_norm_groups(self):
+        # Not implemented yet for this UNet
+        pass
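For orientation, the new `UNetRLModelTests` class exercises the value-function variant of `UNet1DModel`, which collapses the sequence down to a single value per batch element. A standalone sketch of that shape contract, assuming a diffusers build that already ships the `ValueFunction` out-block and `ValueFunctionMidBlock1D` used in this commit:

```python
# Sketch (not part of the commit): instantiate the value-function UNet1DModel
# configuration from the test above and check the scalar-per-sample output shape.
import torch

from diffusers import UNet1DModel

model = UNet1DModel(
    in_channels=14,
    out_channels=14,
    down_block_types=["DownResnetBlock1D"] * 4,
    up_block_types=[],
    out_block_type="ValueFunction",
    mid_block_type="ValueFunctionMidBlock1D",
    block_out_channels=[32, 64, 128, 256],
    layers_per_block=1,
    downsample_each_block=True,
    use_timestep_embedding=True,
    freq_shift=1.0,
    flip_sin_to_cos=False,
    time_embedding_type="positional",
    act_fn="mish",
)

sample = torch.randn(4, 14, 16)        # (batch, channels, sequence length)
timestep = torch.tensor([10] * 4)

with torch.no_grad():
    out = model(sample, timestep).sample

assert out.shape == (4, 1)             # one value per batch element
```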
tests/pipelines/dance_diffusion/test_dance_diffusion.py

@@ -44,6 +44,10 @@ class PipelineFastTests(unittest.TestCase):
             sample_rate=16_000,
             in_channels=2,
             out_channels=2,
+            flip_sin_to_cos=True,
+            use_timestep_embedding=False,
+            time_embedding_type="fourier",
+            mid_block_type="UNetMidBlock1D",
             down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"],
             up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"],
         )
tests/pipelines/ddim/test_ddim.py

@@ -75,7 +75,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         model_id = "google/ddpm-ema-bedroom-256"

         unet = UNet2DModel.from_pretrained(model_id)
-        scheduler = DDIMScheduler.from_config(model_id)
+        scheduler = DDIMScheduler.from_pretrained(model_id)

         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
         ddpm.to(torch_device)
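This one-line change is the user-facing consequence of the new mixin methods: scheduler configurations on the Hub are now loaded through `from_pretrained` rather than `from_config`. A brief migration sketch using the same checkpoint as the test (network access to that repo is assumed):

```python
# Sketch of the API migration enabled by this commit.
from diffusers import DDIMScheduler

# Before this commit:
# scheduler = DDIMScheduler.from_config("google/ddpm-ema-bedroom-256")

# After this commit:
scheduler = DDIMScheduler.from_pretrained("google/ddpm-ema-bedroom-256")
```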