Unverified Commit 27522b58 authored by Cheng Lu's avatar Cheng Lu Committed by GitHub
Browse files

Add the SDE variant of DPM-Solver and DPM-Solver++ (#3344)

* add SDE variant of DPM-Solver and DPM-Solver++

* add test

* fix typo

* fix typo
parent 8d4c7d0e
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
...@@ -70,6 +71,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -70,6 +71,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion). stable-diffusion).
We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse
diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the
second-order `sde-dpmsolver++`.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
...@@ -103,10 +108,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -103,10 +108,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`. `algorithm_type="dpmsolver++"`.
algorithm_type (`str`, default `dpmsolver++`): algorithm_type (`str`, default `dpmsolver++`):
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or
algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
https://arxiv.org/abs/2211.01095. We recommend using `dpmsolver++` with `solver_order=2` for guided the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend using
sampling (e.g. stable-diffusion). `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
solver_type (`str`, default `midpoint`): solver_type (`str`, default `midpoint`):
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
the sample quality, especially for a small number of steps. We empirically find that `midpoint` solvers are the sample quality, especially for a small number of steps. We empirically find that `midpoint` solvers are
...@@ -180,7 +185,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -180,7 +185,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0 self.init_noise_sigma = 1.0
# settings for DPM-Solver # settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]: if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis": if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++") self.register_to_config(algorithm_type="dpmsolver++")
else: else:
...@@ -212,7 +217,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -212,7 +217,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
# Clipping the minimum of all lambda(t) for numerical stability. # Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule. # This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped) clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
timesteps = ( timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1] .round()[::-1][:-1]
...@@ -338,10 +343,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -338,10 +343,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
# DPM-Solver++ needs to solve an integral of the data prediction model. # DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++": if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output. # DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]: if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3] model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t x0_pred = (sample - sigma_t * model_output) / alpha_t
...@@ -360,33 +365,42 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -360,33 +365,42 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
x0_pred = self._threshold_sample(x0_pred) x0_pred = self._threshold_sample(x0_pred)
return x0_pred return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model. # DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output. # DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]: if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3] epsilon = model_output[:, :3]
return model_output else:
epsilon = model_output
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverMultistepScheduler." " `v_prediction` for the DPMSolverMultistepScheduler."
) )
if self.config.thresholding:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t
return epsilon
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
timestep: int, timestep: int,
prev_timestep: int, prev_timestep: int,
sample: torch.FloatTensor, sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the first-order DPM-Solver (equivalent to DDIM). One step for the first-order DPM-Solver (equivalent to DDIM).
...@@ -411,6 +425,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -411,6 +425,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
x_t = (
(alpha_t / alpha_s) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t return x_t
def multistep_dpm_solver_second_order_update( def multistep_dpm_solver_second_order_update(
...@@ -419,6 +447,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -419,6 +447,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
timestep_list: List[int], timestep_list: List[int],
prev_timestep: int, prev_timestep: int,
sample: torch.FloatTensor, sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the second-order multistep DPM-Solver. One step for the second-order multistep DPM-Solver.
...@@ -470,6 +499,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -470,6 +499,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
- (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
) )
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * (torch.exp(h) - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t return x_t
def multistep_dpm_solver_third_order_update( def multistep_dpm_solver_third_order_update(
...@@ -532,6 +593,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -532,6 +593,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.FloatTensor,
generator=None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
...@@ -574,12 +636,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -574,12 +636,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output self.model_outputs[-1] = model_output
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
else:
noise = None
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) prev_sample = self.dpm_solver_first_order_update(
model_output, timestep, prev_timestep, sample, noise=noise
)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
timestep_list = [self.timesteps[step_index - 1], timestep] timestep_list = [self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_second_order_update( prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
) )
else: else:
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
......
...@@ -167,16 +167,20 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -167,16 +167,20 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
self.check_over_configs(prediction_type=prediction_type) self.check_over_configs(prediction_type=prediction_type)
def test_solver_order_and_type(self): def test_solver_order_and_type(self):
for algorithm_type in ["dpmsolver", "dpmsolver++"]: for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
for solver_type in ["midpoint", "heun"]: for solver_type in ["midpoint", "heun"]:
for order in [1, 2, 3]: for order in [1, 2, 3]:
for prediction_type in ["epsilon", "sample"]: for prediction_type in ["epsilon", "sample"]:
self.check_over_configs( if algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
solver_order=order, if order == 3:
solver_type=solver_type, continue
prediction_type=prediction_type, else:
algorithm_type=algorithm_type, self.check_over_configs(
) solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
sample = self.full_loop( sample = self.full_loop(
solver_order=order, solver_order=order,
solver_type=solver_type, solver_type=solver_type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment