Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
769f0be8
Unverified
Commit
769f0be8
authored
Dec 02, 2022
by
Patrick von Platen
Committed by
GitHub
Dec 02, 2022
Browse files
Finalize 2nd order schedulers (#1503)
* up * up * finish * finish * up * up * finish
parent
4f596599
Changes
26
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
90 additions
and
35 deletions
+90
-35
docs/source/api/schedulers.mdx
docs/source/api/schedulers.mdx
+27
-1
examples/community/sd_text2img_k_diffusion.py
examples/community/sd_text2img_k_diffusion.py
+14
-4
src/diffusers/__init__.py
src/diffusers/__init__.py
+2
-0
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
...ffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+1
-1
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
...pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
...rs/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
...s/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
...le_diffusion/pipeline_stable_diffusion_image_variation.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
...nes/stable_diffusion/pipeline_stable_diffusion_img2img.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
...ble_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
...nes/stable_diffusion/pipeline_stable_diffusion_upscale.py
+1
-1
src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
...s/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+1
-1
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+3
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+4
-4
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+4
-3
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+4
-3
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
...ffusers/schedulers/scheduling_euler_ancestral_discrete.py
+4
-0
src/diffusers/schedulers/scheduling_euler_discrete.py
src/diffusers/schedulers/scheduling_euler_discrete.py
+4
-0
src/diffusers/schedulers/scheduling_heun_discrete.py
src/diffusers/schedulers/scheduling_heun_discrete.py
+14
-9
No files found.
docs/source/api/schedulers.mdx
View file @
769f0be8
...
...
@@ -76,6 +76,33 @@ Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [im
[[autodoc]] DPMSolverMultistepScheduler
#### Heun scheduler inspired by Karras et. al paper
Algorithm 1 of [Karras et. al](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
[[autodoc]] HeunDiscreteScheduler
#### DPM Discrete Scheduler inspired by Karras et. al paper
Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
[[autodoc]] KDPM2DiscreteScheduler
#### DPM Discrete Scheduler with ancestral sampling inspired by Karras et. al paper
Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
[[autodoc]] KDPM2AncestralDiscreteScheduler
#### Variance exploding, stochastic sampling from Karras et. al
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
...
...
@@ -86,7 +113,6 @@ Original paper can be found [here](https://arxiv.org/abs/2006.11239).
Original implementation can be found [here](https://arxiv.org/abs/2206.00364).
[[autodoc]] LMSDiscreteScheduler
#### Pseudo numerical methods for diffusion models (PNDM)
...
...
examples/community/sd_text2img_k_diffusion.py
View file @
769f0be8
...
...
@@ -21,7 +21,7 @@ from diffusers import LMSDiscreteScheduler
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionPipelineOutput
from
diffusers.utils
import
is_accelerate_available
,
logging
from
k_diffusion.external
import
CompVisDenoiser
from
k_diffusion.external
import
CompVisDenoiser
,
CompVisVDenoiser
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -33,7 +33,12 @@ class ModelWrapper:
self
.
alphas_cumprod
=
alphas_cumprod
def
apply_model
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
(
*
args
,
**
kwargs
).
sample
if
len
(
args
)
==
3
:
encoder_hidden_states
=
args
[
-
1
]
args
=
args
[:
2
]
if
kwargs
.
get
(
"cond"
,
None
)
is
not
None
:
encoder_hidden_states
=
kwargs
.
pop
(
"cond"
)
return
self
.
model
(
*
args
,
encoder_hidden_states
=
encoder_hidden_states
,
**
kwargs
).
sample
class
StableDiffusionPipeline
(
DiffusionPipeline
):
...
...
@@ -63,6 +68,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components
=
[
"safety_checker"
,
"feature_extractor"
]
def
__init__
(
self
,
...
...
@@ -99,6 +105,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
)
model
=
ModelWrapper
(
unet
,
scheduler
.
alphas_cumprod
)
if
scheduler
.
prediction_type
==
"v_prediction"
:
self
.
k_diffusion_model
=
CompVisVDenoiser
(
model
)
else
:
self
.
k_diffusion_model
=
CompVisDenoiser
(
model
)
def
set_sampler
(
self
,
scheduler_type
:
str
):
...
...
@@ -417,6 +426,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
# 4. Prepare timesteps
self
.
scheduler
.
set_timesteps
(
num_inference_steps
,
device
=
text_embeddings
.
device
)
sigmas
=
self
.
scheduler
.
sigmas
sigmas
=
sigmas
.
to
(
text_embeddings
.
dtype
)
# 5. Prepare latent variables
num_channels_latents
=
self
.
unet
.
in_channels
...
...
@@ -437,7 +447,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
def
model_fn
(
x
,
t
):
latent_model_input
=
torch
.
cat
([
x
]
*
2
)
noise_pred
=
self
.
k_diffusion_model
(
latent_model_input
,
t
,
encoder_hidden_states
=
text_embeddings
)
noise_pred
=
self
.
k_diffusion_model
(
latent_model_input
,
t
,
cond
=
text_embeddings
)
noise_pred_uncond
,
noise_pred_text
=
noise_pred
.
chunk
(
2
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_pred_text
-
noise_pred_uncond
)
...
...
src/diffusers/__init__.py
View file @
769f0be8
...
...
@@ -49,6 +49,8 @@ if is_torch_available():
HeunDiscreteScheduler
,
IPNDMScheduler
,
KarrasVeScheduler
,
KDPM2AncestralDiscreteScheduler
,
KDPM2DiscreteScheduler
,
PNDMScheduler
,
RePaintScheduler
,
SchedulerMixin
,
...
...
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
View file @
769f0be8
...
...
@@ -558,7 +558,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# call the callback, if provided
if
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
:
if
i
==
len
(
timesteps
)
-
1
or
(
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
)
:
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
...
...
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
View file @
769f0be8
...
...
@@ -580,7 +580,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# call the callback, if provided
if
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
:
if
i
==
len
(
timesteps
)
-
1
or
(
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
)
:
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
View file @
769f0be8
...
...
@@ -666,7 +666,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
).
prev_sample
# call the callback, if provided
if
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
:
if
i
==
len
(
timesteps
)
-
1
or
(
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
)
:
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
View file @
769f0be8
...
...
@@ -557,7 +557,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# call the callback, if provided
if
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
:
if
i
==
len
(
timesteps
)
-
1
or
(
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
)
:
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
View file @
769f0be8
...
...
@@ -440,7 +440,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# call the callback, if provided
if
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
:
if
i
==
len
(
timesteps
)
-
1
or
(
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
)
:
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
View file @
769f0be8
...
...
@@ -587,7 +587,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# call the callback, if provided
if
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
:
if
i
==
len
(
timesteps
)
-
1
or
(
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
)
:
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
769f0be8
...
...
@@ -701,7 +701,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# call the callback, if provided
if
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
:
if
i
==
len
(
timesteps
)
-
1
or
(
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
)
:
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
View file @
769f0be8
...
...
@@ -602,7 +602,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
latents
=
(
init_latents_proper
*
mask
)
+
(
latents
*
(
1
-
mask
))
# call the callback, if provided
if
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
:
if
i
==
len
(
timesteps
)
-
1
or
(
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
)
:
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
View file @
769f0be8
...
...
@@ -515,7 +515,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# call the callback, if provided
if
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
:
if
i
==
len
(
timesteps
)
-
1
or
(
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
)
:
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
...
...
src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
View file @
769f0be8
...
...
@@ -711,7 +711,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# call the callback, if provided
if
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
:
if
i
==
len
(
timesteps
)
-
1
or
(
(
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
)
:
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
...
...
src/diffusers/schedulers/__init__.py
View file @
769f0be8
...
...
@@ -22,8 +22,10 @@ if is_torch_available():
from
.scheduling_dpmsolver_multistep
import
DPMSolverMultistepScheduler
from
.scheduling_euler_ancestral_discrete
import
EulerAncestralDiscreteScheduler
from
.scheduling_euler_discrete
import
EulerDiscreteScheduler
from
.scheduling_heun
import
HeunDiscreteScheduler
from
.scheduling_heun
_discrete
import
HeunDiscreteScheduler
from
.scheduling_ipndm
import
IPNDMScheduler
from
.scheduling_k_dpm_2_ancestral_discrete
import
KDPM2AncestralDiscreteScheduler
from
.scheduling_k_dpm_2_discrete
import
KDPM2DiscreteScheduler
from
.scheduling_karras_ve
import
KarrasVeScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_repaint
import
RePaintScheduler
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
769f0be8
...
...
@@ -106,10 +106,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type (`str`, default `epsilon`):
indicates whether the model
predict
s
the noise
(epsilon), or the samples. One of `ep
si
l
on
`, `sample`.
`v-prediction` is not supported for this scheduler.
prediction_type (`str`, default `epsilon`
, optional
):
prediction type of the scheduler function, one of `epsilon` (
predict
ing
the noise
of the diffu
sion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
769f0be8
...
...
@@ -99,9 +99,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
769f0be8
...
...
@@ -87,9 +87,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
solver_order (`int`, default `2`):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
or `v-prediction`.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
...
...
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
View file @
769f0be8
...
...
@@ -64,6 +64,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""
...
...
src/diffusers/schedulers/scheduling_euler_discrete.py
View file @
769f0be8
...
...
@@ -65,6 +65,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""
...
...
src/diffusers/schedulers/scheduling_heun.py
→
src/diffusers/schedulers/scheduling_heun
_discrete
.py
View file @
769f0be8
...
...
@@ -24,14 +24,16 @@ from .scheduling_utils import SchedulerMixin, SchedulerOutput
class
HeunDiscreteScheduler
(
SchedulerMixin
,
ConfigMixin
):
"""
Args:
Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original
k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
...
...
@@ -40,7 +42,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
...
...
@@ -77,7 +82,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
def
index_for_timestep
(
self
,
timestep
):
indices
=
(
self
.
timesteps
==
timestep
).
nonzero
()
if
self
.
state_in_first_order
:
pos
=
0
if
indices
.
shape
[
0
]
<
2
else
1
pos
=
-
1
else
:
pos
=
0
return
indices
[
pos
].
item
()
...
...
@@ -132,7 +137,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
init_noise_sigma
=
self
.
sigmas
.
max
()
timesteps
=
torch
.
from_numpy
(
timesteps
)
timesteps
=
torch
.
cat
([
timesteps
[:
1
],
timesteps
[
1
:].
repeat_interleave
(
2
)
,
timesteps
[
-
1
:]
])
timesteps
=
torch
.
cat
([
timesteps
[:
1
],
timesteps
[
1
:].
repeat_interleave
(
2
)])
if
str
(
device
).
startswith
(
"mps"
):
# mps does not support float64
...
...
@@ -199,9 +204,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
)
if
self
.
state_in_first_order
:
# 2. Convert to an ODE derivative
# 2. Convert to an ODE derivative
for 1st order
derivative
=
(
sample
-
pred_original_sample
)
/
sigma_hat
# 3.
1st order derivative
# 3.
delta timestep
dt
=
sigma_next
-
sigma_hat
# store for 2nd order step
...
...
@@ -213,7 +218,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
derivative
=
(
sample
-
pred_original_sample
)
/
sigma_next
derivative
=
(
self
.
prev_derivative
+
derivative
)
/
2
# 3.
Retrieve 1st order derivativ
e
# 3.
take prev timestep & sampl
e
dt
=
self
.
dt
sample
=
self
.
sample
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment