Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
769f0be8
Unverified
Commit
769f0be8
authored
Dec 02, 2022
by
Patrick von Platen
Committed by
GitHub
Dec 02, 2022
Browse files
Finalize 2nd order schedulers (#1503)
* up * up * finish * finish * up * up * finish
parent
4f596599
Changes
26
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
930 additions
and
1 deletion
+930
-1
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
...users/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+324
-0
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+299
-0
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+4
-1
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+4
-0
src/diffusers/utils/dummy_pt_objects.py
src/diffusers/utils/dummy_pt_objects.py
+30
-0
tests/test_scheduler.py
tests/test_scheduler.py
+269
-0
No files found.
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
0 → 100644
View file @
769f0be8
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from
.scheduling_utils
import
SchedulerMixin
,
SchedulerOutput
class
KDPM2AncestralDiscreteScheduler
(
SchedulerMixin
,
ConfigMixin
):
"""
Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022).
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
order
=
2
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
:
int
=
1000
,
beta_start
:
float
=
0.00085
,
# sensible defaults
beta_end
:
float
=
0.012
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]]
=
None
,
prediction_type
:
str
=
"epsilon"
,
):
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
# this schedule is very specific to the latent diffusion model.
self
.
betas
=
(
torch
.
linspace
(
beta_start
**
0.5
,
beta_end
**
0.5
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
**
2
)
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
self
.
alphas
=
1.0
-
self
.
betas
self
.
alphas_cumprod
=
torch
.
cumprod
(
self
.
alphas
,
dim
=
0
)
# set all values
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
def
index_for_timestep
(
self
,
timestep
):
indices
=
(
self
.
timesteps
==
timestep
).
nonzero
()
if
self
.
state_in_first_order
:
pos
=
-
1
else
:
pos
=
0
return
indices
[
pos
].
item
()
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
],
)
->
torch
.
FloatTensor
:
"""
Args:
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
step_index
=
self
.
index_for_timestep
(
timestep
)
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
step_index
]
else
:
sigma
=
self
.
sigmas_interpol
[
step_index
-
1
]
sample
=
sample
/
((
sigma
**
2
+
1
)
**
0.5
)
return
sample
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
,
num_train_timesteps
:
Optional
[
int
]
=
None
,
):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self
.
num_inference_steps
=
num_inference_steps
num_train_timesteps
=
num_train_timesteps
or
self
.
config
.
num_train_timesteps
timesteps
=
np
.
linspace
(
0
,
num_train_timesteps
-
1
,
num_inference_steps
,
dtype
=
float
)[::
-
1
].
copy
()
sigmas
=
np
.
array
(((
1
-
self
.
alphas_cumprod
)
/
self
.
alphas_cumprod
)
**
0.5
)
self
.
log_sigmas
=
torch
.
from_numpy
(
np
.
log
(
sigmas
)).
to
(
device
)
sigmas
=
np
.
interp
(
timesteps
,
np
.
arange
(
0
,
len
(
sigmas
)),
sigmas
)
sigmas
=
np
.
concatenate
([
sigmas
,
[
0.0
]]).
astype
(
np
.
float32
)
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
device
=
device
)
# compute up and down sigmas
sigmas_next
=
sigmas
.
roll
(
-
1
)
sigmas_next
[
-
1
]
=
0.0
sigmas_up
=
(
sigmas_next
**
2
*
(
sigmas
**
2
-
sigmas_next
**
2
)
/
sigmas
**
2
)
**
0.5
sigmas_down
=
(
sigmas_next
**
2
-
sigmas_up
**
2
)
**
0.5
sigmas_down
[
-
1
]
=
0.0
# compute interpolated sigmas
sigmas_interpol
=
sigmas
.
log
().
lerp
(
sigmas_down
.
log
(),
0.5
).
exp
()
sigmas_interpol
[
-
2
:]
=
0.0
# set sigmas
self
.
sigmas
=
torch
.
cat
([
sigmas
[:
1
],
sigmas
[
1
:].
repeat_interleave
(
2
),
sigmas
[
-
1
:]])
self
.
sigmas_interpol
=
torch
.
cat
(
[
sigmas_interpol
[:
1
],
sigmas_interpol
[
1
:].
repeat_interleave
(
2
),
sigmas_interpol
[
-
1
:]]
)
self
.
sigmas_up
=
torch
.
cat
([
sigmas_up
[:
1
],
sigmas_up
[
1
:].
repeat_interleave
(
2
),
sigmas_up
[
-
1
:]])
self
.
sigmas_down
=
torch
.
cat
([
sigmas_down
[:
1
],
sigmas_down
[
1
:].
repeat_interleave
(
2
),
sigmas_down
[
-
1
:]])
# standard deviation of the initial noise distribution
self
.
init_noise_sigma
=
self
.
sigmas
.
max
()
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
)
timesteps_interpol
=
self
.
sigma_to_t
(
sigmas_interpol
).
to
(
device
)
interleaved_timesteps
=
torch
.
stack
((
timesteps_interpol
[:
-
2
,
None
],
timesteps
[
1
:,
None
]),
dim
=-
1
).
flatten
()
timesteps
=
torch
.
cat
([
timesteps
[:
1
],
interleaved_timesteps
])
if
str
(
device
).
startswith
(
"mps"
):
# mps does not support float64
self
.
timesteps
=
timesteps
.
to
(
device
,
dtype
=
torch
.
float32
)
else
:
self
.
timesteps
=
timesteps
self
.
sample
=
None
def
sigma_to_t
(
self
,
sigma
):
# get log sigma
log_sigma
=
sigma
.
log
()
# get distribution
dists
=
log_sigma
-
self
.
log_sigmas
[:,
None
]
# get sigmas range
low_idx
=
dists
.
ge
(
0
).
cumsum
(
dim
=
0
).
argmax
(
dim
=
0
).
clamp
(
max
=
self
.
log_sigmas
.
shape
[
0
]
-
2
)
high_idx
=
low_idx
+
1
low
=
self
.
log_sigmas
[
low_idx
]
high
=
self
.
log_sigmas
[
high_idx
]
# interpolate sigmas
w
=
(
low
-
log_sigma
)
/
(
low
-
high
)
w
=
w
.
clamp
(
0
,
1
)
# transform interpolation to time range
t
=
(
1
-
w
)
*
low_idx
+
w
*
high_idx
t
=
t
.
view
(
sigma
.
shape
)
return
t
@
property
def
state_in_first_order
(
self
):
return
self
.
sample
is
None
def
step
(
self
,
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
Union
[
float
,
torch
.
FloatTensor
],
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
generator
:
Optional
[
torch
.
Generator
]
=
None
,
return_dict
:
bool
=
True
,
)
->
Union
[
SchedulerOutput
,
Tuple
]:
"""
Args:
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
(`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
step_index
=
self
.
index_for_timestep
(
timestep
)
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
step_index
]
sigma_interpol
=
self
.
sigmas_interpol
[
step_index
]
sigma_up
=
self
.
sigmas_up
[
step_index
]
sigma_down
=
self
.
sigmas_down
[
step_index
-
1
]
else
:
# 2nd order / KPDM2's method
sigma
=
self
.
sigmas
[
step_index
-
1
]
sigma_interpol
=
self
.
sigmas_interpol
[
step_index
-
1
]
sigma_up
=
self
.
sigmas_up
[
step_index
-
1
]
sigma_down
=
self
.
sigmas_down
[
step_index
-
1
]
# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
# passing it to the model which requires a change in API
gamma
=
0
sigma_hat
=
sigma
*
(
gamma
+
1
)
# Note: sigma_hat == sigma for now
device
=
model_output
.
device
if
device
.
type
==
"mps"
:
# randn does not work reproducibly on mps
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
device
=
"cpu"
,
generator
=
generator
).
to
(
device
)
else
:
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
device
=
device
,
generator
=
generator
).
to
(
device
)
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if
self
.
config
.
prediction_type
==
"epsilon"
:
sigma_input
=
sigma_hat
if
self
.
state_in_first_order
else
sigma_interpol
pred_original_sample
=
sample
-
sigma_input
*
model_output
elif
self
.
config
.
prediction_type
==
"v_prediction"
:
sigma_input
=
sigma_hat
if
self
.
state_in_first_order
else
sigma_interpol
pred_original_sample
=
model_output
*
(
-
sigma_input
/
(
sigma_input
**
2
+
1
)
**
0.5
)
+
(
sample
/
(
sigma_input
**
2
+
1
)
)
else
:
raise
ValueError
(
f
"prediction_type given as
{
self
.
config
.
prediction_type
}
must be one of `epsilon`, or `v_prediction`"
)
if
self
.
state_in_first_order
:
# 2. Convert to an ODE derivative for 1st order
derivative
=
(
sample
-
pred_original_sample
)
/
sigma_hat
# 3. delta timestep
dt
=
sigma_interpol
-
sigma_hat
# store for 2nd order step
self
.
sample
=
sample
self
.
dt
=
dt
prev_sample
=
sample
+
derivative
*
dt
else
:
# DPM-Solver-2
# 2. Convert to an ODE derivative for 2nd order
derivative
=
(
sample
-
pred_original_sample
)
/
sigma_interpol
# 3. delta timestep
dt
=
sigma_down
-
sigma_hat
sample
=
self
.
sample
self
.
sample
=
None
prev_sample
=
sample
+
derivative
*
dt
prev_sample
=
prev_sample
+
noise
*
sigma_up
if
not
return_dict
:
return
(
prev_sample
,)
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
def
add_noise
(
self
,
original_samples
:
torch
.
FloatTensor
,
noise
:
torch
.
FloatTensor
,
timesteps
:
torch
.
FloatTensor
,
)
->
torch
.
FloatTensor
:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self
.
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
# mps does not support float64
self
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
else
:
self
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
)
for
t
in
timesteps
]
sigma
=
self
.
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
sigma
=
sigma
.
unsqueeze
(
-
1
)
noisy_samples
=
original_samples
+
noise
*
sigma
return
noisy_samples
def
__len__
(
self
):
return
self
.
config
.
num_train_timesteps
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
0 → 100644
View file @
769f0be8
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from
.scheduling_utils
import
SchedulerMixin
,
SchedulerOutput
class
KDPM2DiscreteScheduler
(
SchedulerMixin
,
ConfigMixin
):
"""
Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022).
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
order
=
2
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
:
int
=
1000
,
beta_start
:
float
=
0.00085
,
# sensible defaults
beta_end
:
float
=
0.012
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]]
=
None
,
prediction_type
:
str
=
"epsilon"
,
):
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
# this schedule is very specific to the latent diffusion model.
self
.
betas
=
(
torch
.
linspace
(
beta_start
**
0.5
,
beta_end
**
0.5
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
**
2
)
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
self
.
alphas
=
1.0
-
self
.
betas
self
.
alphas_cumprod
=
torch
.
cumprod
(
self
.
alphas
,
dim
=
0
)
# set all values
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
def
index_for_timestep
(
self
,
timestep
):
indices
=
(
self
.
timesteps
==
timestep
).
nonzero
()
if
self
.
state_in_first_order
:
pos
=
-
1
else
:
pos
=
0
return
indices
[
pos
].
item
()
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
],
)
->
torch
.
FloatTensor
:
"""
Args:
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
step_index
=
self
.
index_for_timestep
(
timestep
)
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
step_index
]
else
:
sigma
=
self
.
sigmas_interpol
[
step_index
]
sample
=
sample
/
((
sigma
**
2
+
1
)
**
0.5
)
return
sample
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
,
num_train_timesteps
:
Optional
[
int
]
=
None
,
):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self
.
num_inference_steps
=
num_inference_steps
num_train_timesteps
=
num_train_timesteps
or
self
.
config
.
num_train_timesteps
timesteps
=
np
.
linspace
(
0
,
num_train_timesteps
-
1
,
num_inference_steps
,
dtype
=
float
)[::
-
1
].
copy
()
sigmas
=
np
.
array
(((
1
-
self
.
alphas_cumprod
)
/
self
.
alphas_cumprod
)
**
0.5
)
self
.
log_sigmas
=
torch
.
from_numpy
(
np
.
log
(
sigmas
)).
to
(
device
)
sigmas
=
np
.
interp
(
timesteps
,
np
.
arange
(
0
,
len
(
sigmas
)),
sigmas
)
sigmas
=
np
.
concatenate
([
sigmas
,
[
0.0
]]).
astype
(
np
.
float32
)
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
device
=
device
)
# interpolate sigmas
sigmas_interpol
=
sigmas
.
log
().
lerp
(
sigmas
.
roll
(
1
).
log
(),
0.5
).
exp
()
self
.
sigmas
=
torch
.
cat
([
sigmas
[:
1
],
sigmas
[
1
:].
repeat_interleave
(
2
),
sigmas
[
-
1
:]])
self
.
sigmas_interpol
=
torch
.
cat
(
[
sigmas_interpol
[:
1
],
sigmas_interpol
[
1
:].
repeat_interleave
(
2
),
sigmas_interpol
[
-
1
:]]
)
# standard deviation of the initial noise distribution
self
.
init_noise_sigma
=
self
.
sigmas
.
max
()
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
)
# interpolate timesteps
timesteps_interpol
=
self
.
sigma_to_t
(
sigmas_interpol
).
to
(
device
)
interleaved_timesteps
=
torch
.
stack
((
timesteps_interpol
[
1
:
-
1
,
None
],
timesteps
[
1
:,
None
]),
dim
=-
1
).
flatten
()
timesteps
=
torch
.
cat
([
timesteps
[:
1
],
interleaved_timesteps
])
if
str
(
device
).
startswith
(
"mps"
):
# mps does not support float64
self
.
timesteps
=
timesteps
.
to
(
torch
.
float32
)
else
:
self
.
timesteps
=
timesteps
self
.
sample
=
None
def
sigma_to_t
(
self
,
sigma
):
# get log sigma
log_sigma
=
sigma
.
log
()
# get distribution
dists
=
log_sigma
-
self
.
log_sigmas
[:,
None
]
# get sigmas range
low_idx
=
dists
.
ge
(
0
).
cumsum
(
dim
=
0
).
argmax
(
dim
=
0
).
clamp
(
max
=
self
.
log_sigmas
.
shape
[
0
]
-
2
)
high_idx
=
low_idx
+
1
low
=
self
.
log_sigmas
[
low_idx
]
high
=
self
.
log_sigmas
[
high_idx
]
# interpolate sigmas
w
=
(
low
-
log_sigma
)
/
(
low
-
high
)
w
=
w
.
clamp
(
0
,
1
)
# transform interpolation to time range
t
=
(
1
-
w
)
*
low_idx
+
w
*
high_idx
t
=
t
.
view
(
sigma
.
shape
)
return
t
@
property
def
state_in_first_order
(
self
):
return
self
.
sample
is
None
def
step
(
self
,
model_output
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
Union
[
float
,
torch
.
FloatTensor
],
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
return_dict
:
bool
=
True
,
)
->
Union
[
SchedulerOutput
,
Tuple
]:
"""
Args:
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
(`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
step_index
=
self
.
index_for_timestep
(
timestep
)
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
step_index
]
sigma_interpol
=
self
.
sigmas_interpol
[
step_index
+
1
]
sigma_next
=
self
.
sigmas
[
step_index
+
1
]
else
:
# 2nd order / KDPM2's method
sigma
=
self
.
sigmas
[
step_index
-
1
]
sigma_interpol
=
self
.
sigmas_interpol
[
step_index
]
sigma_next
=
self
.
sigmas
[
step_index
]
# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
# passing it to the model which requires a change in API
gamma
=
0
sigma_hat
=
sigma
*
(
gamma
+
1
)
# Note: sigma_hat == sigma for now
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if
self
.
config
.
prediction_type
==
"epsilon"
:
sigma_input
=
sigma_hat
if
self
.
state_in_first_order
else
sigma_interpol
pred_original_sample
=
sample
-
sigma_input
*
model_output
elif
self
.
config
.
prediction_type
==
"v_prediction"
:
sigma_input
=
sigma_hat
if
self
.
state_in_first_order
else
sigma_interpol
pred_original_sample
=
model_output
*
(
-
sigma_input
/
(
sigma_input
**
2
+
1
)
**
0.5
)
+
(
sample
/
(
sigma_input
**
2
+
1
)
)
else
:
raise
ValueError
(
f
"prediction_type given as
{
self
.
config
.
prediction_type
}
must be one of `epsilon`, or `v_prediction`"
)
if
self
.
state_in_first_order
:
# 2. Convert to an ODE derivative for 1st order
derivative
=
(
sample
-
pred_original_sample
)
/
sigma_hat
# 3. delta timestep
dt
=
sigma_interpol
-
sigma_hat
# store for 2nd order step
self
.
sample
=
sample
else
:
# DPM-Solver-2
# 2. Convert to an ODE derivative for 2nd order
derivative
=
(
sample
-
pred_original_sample
)
/
sigma_interpol
# 3. delta timestep
dt
=
sigma_next
-
sigma_hat
sample
=
self
.
sample
self
.
sample
=
None
prev_sample
=
sample
+
derivative
*
dt
if
not
return_dict
:
return
(
prev_sample
,)
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
def
add_noise
(
self
,
original_samples
:
torch
.
FloatTensor
,
noise
:
torch
.
FloatTensor
,
timesteps
:
torch
.
FloatTensor
,
)
->
torch
.
FloatTensor
:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self
.
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
# mps does not support float64
self
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
else
:
self
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
)
for
t
in
timesteps
]
sigma
=
self
.
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
sigma
=
sigma
.
unsqueeze
(
-
1
)
noisy_samples
=
original_samples
+
noise
*
sigma
return
noisy_samples
def
__len__
(
self
):
return
self
.
config
.
num_train_timesteps
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
769f0be8
...
@@ -64,7 +64,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -64,7 +64,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
769f0be8
...
@@ -82,6 +82,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -82,6 +82,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
otherwise it uses the value of alpha at step 0.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
steps_offset (`int`, default `0`):
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
...
...
src/diffusers/utils/dummy_pt_objects.py
View file @
769f0be8
...
@@ -407,6 +407,36 @@ class KarrasVeScheduler(metaclass=DummyObject):
...
@@ -407,6 +407,36 @@ class KarrasVeScheduler(metaclass=DummyObject):
requires_backends
(
cls
,
[
"torch"
])
requires_backends
(
cls
,
[
"torch"
])
class
KDPM2AncestralDiscreteScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
@
classmethod
def
from_config
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
])
@
classmethod
def
from_pretrained
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
])
class
KDPM2DiscreteScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
@
classmethod
def
from_config
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
])
@
classmethod
def
from_pretrained
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
])
class
PNDMScheduler
(
metaclass
=
DummyObject
):
class
PNDMScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
_backends
=
[
"torch"
]
...
...
tests/test_scheduler.py
View file @
769f0be8
...
@@ -32,6 +32,8 @@ from diffusers import (
...
@@ -32,6 +32,8 @@ from diffusers import (
EulerDiscreteScheduler
,
EulerDiscreteScheduler
,
HeunDiscreteScheduler
,
HeunDiscreteScheduler
,
IPNDMScheduler
,
IPNDMScheduler
,
KDPM2AncestralDiscreteScheduler
,
KDPM2DiscreteScheduler
,
LMSDiscreteScheduler
,
LMSDiscreteScheduler
,
PNDMScheduler
,
PNDMScheduler
,
ScoreSdeVeScheduler
,
ScoreSdeVeScheduler
,
...
@@ -2197,3 +2199,270 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
...
@@ -2197,3 +2199,270 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
# CUDA
# CUDA
assert
abs
(
result_sum
.
item
()
-
0.1233
)
<
1e-2
assert
abs
(
result_sum
.
item
()
-
0.1233
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.0002
)
<
1e-3
assert
abs
(
result_mean
.
item
()
-
0.0002
)
<
1e-3
class
KDPM2DiscreteSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
KDPM2DiscreteScheduler
,)
num_inference_steps
=
10
def
get_scheduler_config
(
self
,
**
kwargs
):
config
=
{
"num_train_timesteps"
:
1100
,
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
}
config
.
update
(
**
kwargs
)
return
config
def
test_timesteps
(
self
):
for
timesteps
in
[
10
,
50
,
100
,
1000
]:
self
.
check_over_configs
(
num_train_timesteps
=
timesteps
)
def
test_betas
(
self
):
for
beta_start
,
beta_end
in
zip
([
0.00001
,
0.0001
,
0.001
],
[
0.0002
,
0.002
,
0.02
]):
self
.
check_over_configs
(
beta_start
=
beta_start
,
beta_end
=
beta_end
)
def
test_schedules
(
self
):
for
schedule
in
[
"linear"
,
"scaled_linear"
]:
self
.
check_over_configs
(
beta_schedule
=
schedule
)
def
test_prediction_type
(
self
):
for
prediction_type
in
[
"epsilon"
,
"v_prediction"
]:
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
def
test_full_loop_with_v_prediction
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
(
prediction_type
=
"v_prediction"
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
self
.
num_inference_steps
)
model
=
self
.
dummy_model
()
sample
=
self
.
dummy_sample_deter
*
scheduler
.
init_noise_sigma
sample
=
sample
.
to
(
torch_device
)
for
i
,
t
in
enumerate
(
scheduler
.
timesteps
):
sample
=
scheduler
.
scale_model_input
(
sample
,
t
)
model_output
=
model
(
sample
,
t
)
output
=
scheduler
.
step
(
model_output
,
t
,
sample
)
sample
=
output
.
prev_sample
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
if
torch_device
in
[
"cpu"
,
"mps"
]:
assert
abs
(
result_sum
.
item
()
-
4.6934e-07
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
6.1112e-10
)
<
1e-3
else
:
# CUDA
assert
abs
(
result_sum
.
item
()
-
4.693428650170972e-07
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.0002
)
<
1e-3
def
test_full_loop_no_noise
(
self
):
if
torch_device
==
"mps"
:
return
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
self
.
num_inference_steps
)
model
=
self
.
dummy_model
()
sample
=
self
.
dummy_sample_deter
*
scheduler
.
init_noise_sigma
sample
=
sample
.
to
(
torch_device
)
for
i
,
t
in
enumerate
(
scheduler
.
timesteps
):
sample
=
scheduler
.
scale_model_input
(
sample
,
t
)
model_output
=
model
(
sample
,
t
)
output
=
scheduler
.
step
(
model_output
,
t
,
sample
)
sample
=
output
.
prev_sample
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
if
torch_device
in
[
"cpu"
,
"mps"
]:
assert
abs
(
result_sum
.
item
()
-
20.4125
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.0266
)
<
1e-3
else
:
# CUDA
assert
abs
(
result_sum
.
item
()
-
20.4125
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.0266
)
<
1e-3
def
test_full_loop_device
(
self
):
if
torch_device
==
"mps"
:
return
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
self
.
num_inference_steps
,
device
=
torch_device
)
model
=
self
.
dummy_model
()
sample
=
self
.
dummy_sample_deter
.
to
(
torch_device
)
*
scheduler
.
init_noise_sigma
for
t
in
scheduler
.
timesteps
:
sample
=
scheduler
.
scale_model_input
(
sample
,
t
)
model_output
=
model
(
sample
,
t
)
output
=
scheduler
.
step
(
model_output
,
t
,
sample
)
sample
=
output
.
prev_sample
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
if
str
(
torch_device
).
startswith
(
"cpu"
):
# The following sum varies between 148 and 156 on mps. Why?
assert
abs
(
result_sum
.
item
()
-
20.4125
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.0266
)
<
1e-3
else
:
# CUDA
assert
abs
(
result_sum
.
item
()
-
20.4125
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.0266
)
<
1e-3
class
KDPM2AncestralDiscreteSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
KDPM2AncestralDiscreteScheduler
,)
num_inference_steps
=
10
def
get_scheduler_config
(
self
,
**
kwargs
):
config
=
{
"num_train_timesteps"
:
1100
,
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
}
config
.
update
(
**
kwargs
)
return
config
def
test_timesteps
(
self
):
for
timesteps
in
[
10
,
50
,
100
,
1000
]:
self
.
check_over_configs
(
num_train_timesteps
=
timesteps
)
def
test_betas
(
self
):
for
beta_start
,
beta_end
in
zip
([
0.00001
,
0.0001
,
0.001
],
[
0.0002
,
0.002
,
0.02
]):
self
.
check_over_configs
(
beta_start
=
beta_start
,
beta_end
=
beta_end
)
def
test_schedules
(
self
):
for
schedule
in
[
"linear"
,
"scaled_linear"
]:
self
.
check_over_configs
(
beta_schedule
=
schedule
)
def
test_full_loop_no_noise
(
self
):
if
torch_device
==
"mps"
:
return
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
self
.
num_inference_steps
)
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
model
=
self
.
dummy_model
()
sample
=
self
.
dummy_sample_deter
*
scheduler
.
init_noise_sigma
sample
=
sample
.
to
(
torch_device
)
for
i
,
t
in
enumerate
(
scheduler
.
timesteps
):
sample
=
scheduler
.
scale_model_input
(
sample
,
t
)
model_output
=
model
(
sample
,
t
)
output
=
scheduler
.
step
(
model_output
,
t
,
sample
,
generator
=
generator
)
sample
=
output
.
prev_sample
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
if
torch_device
in
[
"cpu"
,
"mps"
]:
assert
abs
(
result_sum
.
item
()
-
13849.3945
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
18.0331
)
<
5e-3
else
:
# CUDA
assert
abs
(
result_sum
.
item
()
-
13913.0449
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
18.1159
)
<
5e-3
def
test_prediction_type
(
self
):
for
prediction_type
in
[
"epsilon"
,
"v_prediction"
]:
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
def
test_full_loop_with_v_prediction
(
self
):
if
torch_device
==
"mps"
:
return
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
(
prediction_type
=
"v_prediction"
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
self
.
num_inference_steps
)
model
=
self
.
dummy_model
()
sample
=
self
.
dummy_sample_deter
*
scheduler
.
init_noise_sigma
sample
=
sample
.
to
(
torch_device
)
if
torch_device
==
"mps"
:
# device type MPS is not supported for torch.Generator() api.
generator
=
torch
.
manual_seed
(
0
)
else
:
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
for
i
,
t
in
enumerate
(
scheduler
.
timesteps
):
sample
=
scheduler
.
scale_model_input
(
sample
,
t
)
model_output
=
model
(
sample
,
t
)
output
=
scheduler
.
step
(
model_output
,
t
,
sample
,
generator
=
generator
)
sample
=
output
.
prev_sample
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
if
torch_device
in
[
"cpu"
,
"mps"
]:
assert
abs
(
result_sum
.
item
()
-
328.9970
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.4284
)
<
1e-3
else
:
# CUDA
assert
abs
(
result_sum
.
item
()
-
327.8027
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.4268
)
<
1e-3
def
test_full_loop_device
(
self
):
if
torch_device
==
"mps"
:
return
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
self
.
num_inference_steps
,
device
=
torch_device
)
if
torch_device
==
"mps"
:
# device type MPS is not supported for torch.Generator() api.
generator
=
torch
.
manual_seed
(
0
)
else
:
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
model
=
self
.
dummy_model
()
sample
=
self
.
dummy_sample_deter
.
to
(
torch_device
)
*
scheduler
.
init_noise_sigma
for
t
in
scheduler
.
timesteps
:
sample
=
scheduler
.
scale_model_input
(
sample
,
t
)
model_output
=
model
(
sample
,
t
)
output
=
scheduler
.
step
(
model_output
,
t
,
sample
,
generator
=
generator
)
sample
=
output
.
prev_sample
result_sum
=
torch
.
sum
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
if
str
(
torch_device
).
startswith
(
"cpu"
):
assert
abs
(
result_sum
.
item
()
-
13849.3945
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
18.0331
)
<
5e-3
else
:
# CUDA
assert
abs
(
result_sum
.
item
()
-
13913.0332
)
<
1e-1
assert
abs
(
result_mean
.
item
()
-
18.1159
)
<
1e-3
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment