Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
554b374d
Commit
554b374d
authored
Nov 15, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
d5ab55e4
a0520193
Changes
76
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
530 additions
and
70 deletions
+530
-70
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
...ffusers/schedulers/scheduling_euler_ancestral_discrete.py
+4
-11
src/diffusers/schedulers/scheduling_euler_discrete.py
src/diffusers/schedulers/scheduling_euler_discrete.py
+4
-11
src/diffusers/schedulers/scheduling_ipndm.py
src/diffusers/schedulers/scheduling_ipndm.py
+2
-2
src/diffusers/schedulers/scheduling_karras_ve.py
src/diffusers/schedulers/scheduling_karras_ve.py
+2
-2
src/diffusers/schedulers/scheduling_karras_ve_flax.py
src/diffusers/schedulers/scheduling_karras_ve_flax.py
+2
-2
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+4
-11
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+10
-3
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+4
-10
src/diffusers/schedulers/scheduling_pndm_flax.py
src/diffusers/schedulers/scheduling_pndm_flax.py
+10
-3
src/diffusers/schedulers/scheduling_repaint.py
src/diffusers/schedulers/scheduling_repaint.py
+2
-2
src/diffusers/schedulers/scheduling_sde_ve.py
src/diffusers/schedulers/scheduling_sde_ve.py
+2
-2
src/diffusers/schedulers/scheduling_sde_ve_flax.py
src/diffusers/schedulers/scheduling_sde_ve_flax.py
+2
-2
src/diffusers/schedulers/scheduling_sde_vp.py
src/diffusers/schedulers/scheduling_sde_vp.py
+2
-2
src/diffusers/schedulers/scheduling_utils.py
src/diffusers/schedulers/scheduling_utils.py
+111
-0
src/diffusers/schedulers/scheduling_utils_flax.py
src/diffusers/schedulers/scheduling_utils_flax.py
+119
-2
src/diffusers/schedulers/scheduling_vq_diffusion.py
src/diffusers/schedulers/scheduling_vq_diffusion.py
+2
-2
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+10
-0
tests/models/test_models_unet_1d.py
tests/models/test_models_unet_1d.py
+233
-2
tests/pipelines/dance_diffusion/test_dance_diffusion.py
tests/pipelines/dance_diffusion/test_dance_diffusion.py
+4
-0
tests/pipelines/ddim/test_ddim.py
tests/pipelines/ddim/test_ddim.py
+1
-1
No files found.
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
View file @
554b374d
...
...
@@ -19,7 +19,7 @@ import numpy as np
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
,
logging
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
BaseOutput
,
logging
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -52,8 +52,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
...
...
@@ -67,14 +67,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes
=
[
"DDIMScheduler"
,
"DDPMScheduler"
,
"LMSDiscreteScheduler"
,
"PNDMScheduler"
,
"EulerDiscreteScheduler"
,
"DPMSolverMultistepScheduler"
,
]
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
register_to_config
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_euler_discrete.py
View file @
554b374d
...
...
@@ -19,7 +19,7 @@ import numpy as np
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
,
logging
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
BaseOutput
,
logging
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -53,8 +53,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
...
...
@@ -68,14 +68,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes
=
[
"DDIMScheduler"
,
"DDPMScheduler"
,
"LMSDiscreteScheduler"
,
"PNDMScheduler"
,
"EulerAncestralDiscreteScheduler"
,
"DPMSolverMultistepScheduler"
,
]
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
register_to_config
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_ipndm.py
View file @
554b374d
...
...
@@ -28,8 +28,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
...
...
src/diffusers/schedulers/scheduling_karras_ve.py
View file @
554b374d
...
...
@@ -56,8 +56,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
...
...
src/diffusers/schedulers/scheduling_karras_ve_flax.py
View file @
554b374d
...
...
@@ -67,8 +67,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
...
...
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
554b374d
...
...
@@ -21,7 +21,7 @@ import torch
from
scipy
import
integrate
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
BaseOutput
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -52,8 +52,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
...
...
@@ -67,14 +67,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes
=
[
"DDIMScheduler"
,
"DDPMScheduler"
,
"PNDMScheduler"
,
"EulerDiscreteScheduler"
,
"EulerAncestralDiscreteScheduler"
,
"DPMSolverMultistepScheduler"
,
]
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
register_to_config
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
View file @
554b374d
...
...
@@ -20,7 +20,12 @@ import jax.numpy as jnp
from
scipy
import
integrate
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils_flax
import
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
from
.scheduling_utils_flax
import
(
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
,
)
@
flax
.
struct
.
dataclass
...
...
@@ -49,8 +54,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
...
...
@@ -63,6 +68,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
property
def
has_state
(
self
):
return
True
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
554b374d
...
...
@@ -21,6 +21,7 @@ import numpy as np
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from
.scheduling_utils
import
SchedulerMixin
,
SchedulerOutput
...
...
@@ -60,8 +61,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
...
...
@@ -88,14 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes
=
[
"DDIMScheduler"
,
"DDPMScheduler"
,
"LMSDiscreteScheduler"
,
"EulerDiscreteScheduler"
,
"EulerAncestralDiscreteScheduler"
,
"DPMSolverMultistepScheduler"
,
]
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
register_to_config
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_pndm_flax.py
View file @
554b374d
...
...
@@ -23,7 +23,12 @@ import jax
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils_flax
import
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
from
.scheduling_utils_flax
import
(
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
,
)
def
betas_for_alpha_bar
(
num_diffusion_timesteps
:
int
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
...
...
@@ -87,8 +92,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
...
...
@@ -114,6 +119,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
stable diffusion.
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
property
def
has_state
(
self
):
return
True
...
...
src/diffusers/schedulers/scheduling_repaint.py
View file @
554b374d
...
...
@@ -77,8 +77,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf
...
...
src/diffusers/schedulers/scheduling_sde_ve.py
View file @
554b374d
...
...
@@ -50,8 +50,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
...
...
src/diffusers/schedulers/scheduling_sde_ve_flax.py
View file @
554b374d
...
...
@@ -64,8 +64,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
...
...
src/diffusers/schedulers/scheduling_sde_vp.py
View file @
554b374d
...
...
@@ -29,8 +29,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more information, see the original paper: https://arxiv.org/abs/2011.13456
...
...
src/diffusers/schedulers/scheduling_utils.py
View file @
554b374d
...
...
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
importlib
import
os
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
torch
...
...
@@ -38,6 +41,114 @@ class SchedulerOutput(BaseOutput):
class SchedulerMixin:
    """
    Mixin containing common functions for the schedulers.

    The methods below call `load_config`, `from_config` and `save_config`, none of which are defined here; they are
    expected to be provided by the class this mixin is combined with.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
          `from_config` can be used from a class different than the one used to save the config (should be overridden
          by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    _compatibles = []
    has_compatibles = True

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
                      organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing the scheduler configurations saved using
                      [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `transformers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

        <Tip>

         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
         models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
        use this method in a firewalled environment.

        </Tip>

        """
        # Always ask `load_config` for the leftover kwargs so they can be handed on to `from_config`;
        # the caller's `return_unused_kwargs` choice only applies to the final `from_config` call.
        config, kwargs = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~SchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to `save_config`.
            kwargs:
                Additional keyword arguments, forwarded unchanged to `save_config`.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        """
        Resolve this class's own name plus every name in `_compatibles` to attributes of the top-level package.

        Names that the package does not expose are skipped. The result lists this class first, followed by the
        entries of `_compatibles` in declaration order, without duplicates.
        """
        # `dict.fromkeys` removes duplicates while keeping insertion order; a `set` would make the
        # order of the returned list vary between interpreter runs.
        candidate_names = list(dict.fromkeys([cls.__name__, *cls._compatibles]))
        # First component of this module's dotted name, i.e. the root package the classes are looked up on.
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        return [
            getattr(diffusers_library, name) for name in candidate_names if hasattr(diffusers_library, name)
        ]
src/diffusers/schedulers/scheduling_utils_flax.py
View file @
554b374d
...
...
@@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
importlib
import
os
from
dataclasses
import
dataclass
from
typing
import
Tuple
from
typing
import
Any
,
Dict
,
Optional
,
Tuple
,
Union
import
jax.numpy
as
jnp
from
..utils
import
BaseOutput
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
BaseOutput
SCHEDULER_CONFIG_NAME
=
"scheduler_config.json"
# Same entries as `_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS`, each with a "Flax" prefix added to the name.
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
    "Flax" + scheduler_name for scheduler_name in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
]
@
dataclass
...
...
@@ -39,9 +42,123 @@ class FlaxSchedulerOutput(BaseOutput):
class FlaxSchedulerMixin:
    """
    Mixin containing common functions for the schedulers.

    The methods below call `load_config`, `from_config` and `save_config`, none of which are defined here; they are
    expected to be provided by the class this mixin is combined with.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
          `from_config` can be used from a class different than the one used to save the config (should be overridden
          by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    _compatibles = []
    has_compatibles = True

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Instantiate a Scheduler class from a pre-defined JSON-file.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
                      organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing the scheduler configuration saved using
                      [`~FlaxSchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `transformers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

        Returns:
            `(scheduler, state)`, or `(scheduler, state, unused_kwargs)` when `return_unused_kwargs` is `True`.
            `state` is the result of `scheduler.create_state()` when the scheduler defines `create_state` and its
            `has_state` attribute is truthy; otherwise it is `None`.

        <Tip>

         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
         models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
        use this method in a firewalled environment.

        </Tip>

        """
        # `subfolder` must be forwarded explicitly: it is a named parameter here, so it is not part of
        # `kwargs` and would otherwise be dropped silently.
        config, kwargs = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            **kwargs,
        )
        # Always collect the unused kwargs; whether they are returned is decided below.
        scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)

        # Default to `None` so schedulers without state do not hit an unbound local at the return.
        state = None
        if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
            state = scheduler.create_state()

        if return_unused_kwargs:
            return scheduler, state, unused_kwargs

        return scheduler, state

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~FlaxSchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to `save_config`.
            kwargs:
                Additional keyword arguments, forwarded unchanged to `save_config`.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[FlaxSchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        """
        Resolve this class's own name plus every name in `_compatibles` to attributes of the top-level package.

        Names that the package does not expose are skipped. The result lists this class first, followed by the
        entries of `_compatibles` in declaration order, without duplicates.
        """
        # `dict.fromkeys` removes duplicates while keeping insertion order; a `set` would make the
        # order of the returned list vary between interpreter runs.
        candidate_names = list(dict.fromkeys([cls.__name__, *cls._compatibles]))
        # First component of this module's dotted name, i.e. the root package the classes are looked up on.
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        return [
            getattr(diffusers_library, name) for name in candidate_names if hasattr(diffusers_library, name)
        ]
def
broadcast_to_shape_from_left
(
x
:
jnp
.
ndarray
,
shape
:
Tuple
[
int
])
->
jnp
.
ndarray
:
...
...
src/diffusers/schedulers/scheduling_vq_diffusion.py
View file @
554b374d
...
...
@@ -112,8 +112,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2111.14822
...
...
src/diffusers/utils/__init__.py
View file @
554b374d
...
...
@@ -72,3 +72,13 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE
=
default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME
=
"diffusers_modules"
HF_MODULES_CACHE
=
os
.
getenv
(
"HF_MODULES_CACHE"
,
os
.
path
.
join
(
hf_cache_home
,
"modules"
))
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
=
[
"DDIMScheduler"
,
"DDPMScheduler"
,
"PNDMScheduler"
,
"LMSDiscreteScheduler"
,
"EulerDiscreteScheduler"
,
"EulerAncestralDiscreteScheduler"
,
"DPMSolverMultistepScheduler"
,
]
tests/models/test_models_unet_1d.py
View file @
554b374d
...
...
@@ -18,13 +18,120 @@ import unittest
import
torch
from
diffusers
import
UNet1DModel
from
diffusers.utils
import
slow
,
torch_device
from
diffusers.utils
import
floats_tensor
,
slow
,
torch_device
from
..test_modeling_common
import
ModelTesterMixin
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
UnetModel1DTests
(
unittest
.
TestCase
):
class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
    """Common-model tests for ``UNet1DModel`` built from 1D ResNet down/up blocks.

    The configuration under test (see ``prepare_init_args_and_inputs_for_common``)
    uses the ``"mish"`` activation, so every test that actually runs the model is
    skipped when ``torch_device`` is ``"mps"``.
    """

    model_class = UNet1DModel

    @property
    def dummy_input(self):
        """Return the keyword arguments for one forward pass.

        Returns:
            dict: ``"sample"`` is a ``(4, 14, 16)`` tensor produced by
            ``floats_tensor`` and ``"timestep"`` is a length-4 tensor filled with
            ``10``; both are moved to ``torch_device``.
        """
        batch_size = 4
        num_features = 14
        seq_len = 16

        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        """Shape of ``dummy_input["sample"]``: (batch, features, sequence length)."""
        return (4, 14, 16)

    @property
    def output_shape(self):
        """Declared output shape; identical to ``input_shape`` for this configuration."""
        return (4, 14, 16)

    def test_ema_training(self):
        # Intentionally a no-op for this model (presumably overriding the mixin's test).
        pass

    def test_training(self):
        # Intentionally a no-op for this model (presumably overriding the mixin's test).
        pass

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_determinism(self):
        # Re-declared only to attach the MPS skip; the logic lives in the parent class.
        super().test_determinism()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_outputs_equivalence(self):
        # Re-declared only to attach the MPS skip; the logic lives in the parent class.
        super().test_outputs_equivalence()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_from_pretrained_save_pretrained(self):
        # Re-declared only to attach the MPS skip; the logic lives in the parent class.
        super().test_from_pretrained_save_pretrained()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_model_from_pretrained(self):
        # Re-declared only to attach the MPS skip; the logic lives in the parent class.
        super().test_model_from_pretrained()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_output(self):
        # Re-declared only to attach the MPS skip; the logic lives in the parent class.
        super().test_output()

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)`` used to build and call the model.

        ``init_dict`` holds the ``UNet1DModel`` constructor keyword arguments and
        ``inputs_dict`` is ``self.dummy_input``.
        """
        init_dict = {
            "block_out_channels": (32, 64, 128, 256),
            "in_channels": 14,
            "out_channels": 14,
            "time_embedding_type": "positional",
            "use_timestep_embedding": True,
            "flip_sin_to_cos": False,
            "freq_shift": 1.0,
            "out_block_type": "OutConv1DBlock",
            "mid_block_type": "MidResTemporalBlock1D",
            "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
            "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
            "act_fn": "mish",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_from_pretrained_hub(self):
        # Load the "unet" subfolder of the hub checkpoint and check that no keys are missing.
        model, loading_info = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        output = model(**self.dummy_input)

        # unittest assertion rather than a bare `assert`, so the check survives `python -O`.
        self.assertIsNotNone(output, "Make sure output is not None")

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_output_pretrained(self):
        # Regression check of the pretrained "unet" checkpoint against hard-coded reference values.
        model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")

        # Seed after loading so the noise drawn below is reproducible.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        num_features = model.in_channels
        seq_len = 16
        noise = torch.randn((1, seq_len, num_features)).permute(
            0, 2, 1
        )  # match original, we can update values and remove
        # NOTE(review): the timestep tensor has `num_features` entries although the batch size is 1;
        # presumably kept to reproduce the reference values below -- confirm before changing.
        time_step = torch.full((num_features,), 0)

        with torch.no_grad():
            output = model(noise, time_step).sample.permute(0, 2, 1)

        # Compare only the trailing 3x3 corner of the first batch element.
        output_slice = output[0, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))

    def test_forward_with_norm_groups(self):
        # Not implemented yet for this UNet
        pass
@
slow
def
test_unet_1d_maestro
(
self
):
model_id
=
"harmonai/maestro-150k"
...
...
@@ -43,3 +150,127 @@ class UnetModel1DTests(unittest.TestCase):
assert
(
output_sum
-
224.0896
).
abs
()
<
4e-2
assert
(
output_max
-
0.0607
).
abs
()
<
4e-4
class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
    """Common-model tests for ``UNet1DModel`` configured as a value function.

    The configuration under test (see ``prepare_init_args_and_inputs_for_common``)
    has no up blocks and uses the ``"ValueFunction"`` output block, so
    ``test_output`` is overridden to expect a ``(batch_size, 1)`` result. The
    ``"mish"`` activation is used, hence every test that runs the model is
    skipped when ``torch_device`` is ``"mps"``.
    """

    model_class = UNet1DModel

    @property
    def dummy_input(self):
        """Return the keyword arguments for one forward pass.

        Returns:
            dict: ``"sample"`` is a ``(4, 14, 16)`` tensor produced by
            ``floats_tensor`` and ``"timestep"`` is a length-4 tensor filled with
            ``10``; both are moved to ``torch_device``.
        """
        batch_size = 4
        num_features = 14
        seq_len = 16

        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        """Shape of ``dummy_input["sample"]``: (batch, features, sequence length)."""
        return (4, 14, 16)

    @property
    def output_shape(self):
        """Declared output shape.

        NOTE(review): ``test_output`` below expects ``(batch_size, 1)``, which does
        not match this 3-tuple -- confirm which tests still consume this property.
        """
        return (4, 14, 1)

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_determinism(self):
        # Re-declared only to attach the MPS skip; the logic lives in the parent class.
        super().test_determinism()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_outputs_equivalence(self):
        # Re-declared only to attach the MPS skip; the logic lives in the parent class.
        super().test_outputs_equivalence()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_from_pretrained_save_pretrained(self):
        # Re-declared only to attach the MPS skip; the logic lives in the parent class.
        super().test_from_pretrained_save_pretrained()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_model_from_pretrained(self):
        # Re-declared only to attach the MPS skip; the logic lives in the parent class.
        super().test_model_from_pretrained()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_output(self):
        # UNetRL is a value function, so its output shape differs from the input shape.
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            # Unwrap dict-like model outputs to the underlying tensor.
            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        # One value per batch element is expected.
        expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_ema_training(self):
        # Intentionally a no-op for this model (presumably overriding the mixin's test).
        pass

    def test_training(self):
        # Intentionally a no-op for this model (presumably overriding the mixin's test).
        pass

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)`` used to build and call the model.

        ``init_dict`` holds the ``UNet1DModel`` constructor keyword arguments for
        the value-function variant and ``inputs_dict`` is ``self.dummy_input``.
        """
        init_dict = {
            "in_channels": 14,
            "out_channels": 14,
            "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
            "up_block_types": [],
            "out_block_type": "ValueFunction",
            "mid_block_type": "ValueFunctionMidBlock1D",
            "block_out_channels": [32, 64, 128, 256],
            "layers_per_block": 1,
            "downsample_each_block": True,
            "use_timestep_embedding": True,
            "freq_shift": 1.0,
            "flip_sin_to_cos": False,
            "time_embedding_type": "positional",
            "act_fn": "mish",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_from_pretrained_hub(self):
        # Load the "value_function" subfolder of the hub checkpoint and check that no keys are missing.
        value_function, vf_loading_info = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
        )
        self.assertIsNotNone(value_function)
        self.assertEqual(len(vf_loading_info["missing_keys"]), 0)

        value_function.to(torch_device)
        output = value_function(**self.dummy_input)

        # unittest assertion rather than a bare `assert`, so the check survives `python -O`.
        self.assertIsNotNone(output, "Make sure output is not None")

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_output_pretrained(self):
        # Regression check of the pretrained value function against a hard-coded reference value.
        # Loading info is not inspected here, so it is not requested (matches UNet1DModelTests).
        value_function = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", subfolder="value_function"
        )

        # Seed after loading so the noise drawn below is reproducible.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        num_features = value_function.in_channels
        seq_len = 14
        noise = torch.randn((1, seq_len, num_features)).permute(
            0, 2, 1
        )  # match original, we can update values and remove
        # NOTE(review): the timestep tensor has `num_features` entries although the batch size is 1;
        # presumably kept to reproduce the reference value below -- confirm before changing.
        time_step = torch.full((num_features,), 0)

        with torch.no_grad():
            output = value_function(noise, time_step).sample

        # NOTE(review): `expected_output_slice` is 1-D with `seq_len` entries; the comparison
        # presumably relies on broadcasting against `output` -- verify the shapes if this fails.
        # fmt: off
        expected_output_slice = torch.tensor([165.25] * seq_len)
        # fmt: on
        self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))

    def test_forward_with_norm_groups(self):
        # Not implemented yet for this UNet
        pass
tests/pipelines/dance_diffusion/test_dance_diffusion.py
View file @
554b374d
...
...
@@ -44,6 +44,10 @@ class PipelineFastTests(unittest.TestCase):
sample_rate
=
16_000
,
in_channels
=
2
,
out_channels
=
2
,
flip_sin_to_cos
=
True
,
use_timestep_embedding
=
False
,
time_embedding_type
=
"fourier"
,
mid_block_type
=
"UNetMidBlock1D"
,
down_block_types
=
[
"DownBlock1DNoSkip"
]
+
[
"DownBlock1D"
]
+
[
"AttnDownBlock1D"
],
up_block_types
=
[
"AttnUpBlock1D"
]
+
[
"UpBlock1D"
]
+
[
"UpBlock1DNoSkip"
],
)
...
...
tests/pipelines/ddim/test_ddim.py
View file @
554b374d
...
...
@@ -75,7 +75,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
model_id
=
"google/ddpm-ema-bedroom-256"
unet
=
UNet2DModel
.
from_pretrained
(
model_id
)
scheduler
=
DDIMScheduler
.
from_
config
(
model_id
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
model_id
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
.
to
(
torch_device
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment