Unverified Commit e87bf629 authored by Tolga Cangöz's avatar Tolga Cangöz Committed by GitHub
Browse files

[`Cont'd`] Add the SDE variant of ~~DPM-Solver~~ and DPM-Solver++ to DPM Single Step (#8269)



* Add the SDE variant of DPM-Solver and DPM-Solver++ to DPM Single Step


---------
Co-authored-by: default avatarcmdr2 <secondary.cmdr2@gmail.com>
parent 3b37fefe
...@@ -22,6 +22,7 @@ import torch ...@@ -22,6 +22,7 @@ import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate, logging from ..utils import deprecate, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
...@@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`. `algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `dpmsolver++`): algorithm_type (`str`, defaults to `dpmsolver++`):
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver`
algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
Stable Diffusion. sampling like in Stable Diffusion.
solver_type (`str`, defaults to `midpoint`): solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
...@@ -186,7 +187,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -186,7 +187,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0 self.init_noise_sigma = 1.0
# settings for DPM-Solver # settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]: if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
if algorithm_type == "deis": if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++") self.register_to_config(algorithm_type="dpmsolver++")
else: else:
...@@ -197,7 +198,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -197,7 +198,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
else: else:
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero": if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
raise ValueError( raise ValueError(
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead." f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
) )
...@@ -493,10 +494,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -493,10 +494,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
) )
# DPM-Solver++ needs to solve an integral of the data prediction model. # DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++": if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output. # DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]: if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3] model_output = model_output[:, :3]
sigma = self.sigmas[self.step_index] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
...@@ -517,34 +518,43 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -517,34 +518,43 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
x0_pred = self._threshold_sample(x0_pred) x0_pred = self._threshold_sample(x0_pred)
return x0_pred return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model. # DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type == "dpmsolver":
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output. # DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]: if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3] epsilon = model_output[:, :3]
return model_output else:
epsilon = model_output
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
sigma = self.sigmas[self.step_index] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = (sample - alpha_t * model_output) / sigma_t epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = alpha_t * model_output + sigma_t * sample epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverSinglestepScheduler." " `v_prediction` for the DPMSolverSinglestepScheduler."
) )
if self.config.thresholding:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t
return epsilon
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
self, self,
model_output: torch.Tensor, model_output: torch.Tensor,
*args, *args,
sample: torch.Tensor = None, sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -594,6 +604,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -594,6 +604,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t return x_t
def singlestep_dpm_solver_second_order_update( def singlestep_dpm_solver_second_order_update(
...@@ -601,6 +618,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -601,6 +618,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
model_output_list: List[torch.Tensor], model_output_list: List[torch.Tensor],
*args, *args,
sample: torch.Tensor = None, sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -688,6 +706,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -688,6 +706,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
- (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
) )
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t return x_t
def singlestep_dpm_solver_third_order_update( def singlestep_dpm_solver_third_order_update(
...@@ -800,6 +834,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -800,6 +834,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
*args, *args,
sample: torch.Tensor = None, sample: torch.Tensor = None,
order: int = None, order: int = None,
noise: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -848,9 +883,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -848,9 +883,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
) )
if order == 1: if order == 1:
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample) return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise)
elif order == 2: elif order == 2:
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample) return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
elif order == 3: elif order == 3:
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
else: else:
...@@ -894,6 +929,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -894,6 +929,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.Tensor, sample: torch.Tensor,
generator=None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
...@@ -929,6 +965,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -929,6 +965,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output self.model_outputs[-1] = model_output
if self.config.algorithm_type == "sde-dpmsolver++":
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
else:
noise = None
order = self.order_list[self.step_index] order = self.order_list[self.step_index]
# For img2img denoising might start with order>1 which is not possible # For img2img denoising might start with order>1 which is not possible
...@@ -940,9 +983,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -940,9 +983,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
if order == 1: if order == 1:
self.sample = sample self.sample = sample
prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order) prev_sample = self.singlestep_dpm_solver_update(
self.model_outputs, sample=self.sample, order=order, noise=noise
)
# upon completion increase step index by one # upon completion increase step index by one, noise=noise
self._step_index += 1 self._step_index += 1
if not return_dict: if not return_dict:
......
...@@ -194,16 +194,20 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest): ...@@ -194,16 +194,20 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
self.check_over_configs(prediction_type=prediction_type) self.check_over_configs(prediction_type=prediction_type)
def test_solver_order_and_type(self): def test_solver_order_and_type(self):
for algorithm_type in ["dpmsolver", "dpmsolver++"]: for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
for solver_type in ["midpoint", "heun"]: for solver_type in ["midpoint", "heun"]:
for order in [1, 2, 3]: for order in [1, 2, 3]:
for prediction_type in ["epsilon", "sample"]: for prediction_type in ["epsilon", "sample"]:
self.check_over_configs( if algorithm_type == "sde-dpmsolver++":
solver_order=order, if order == 3:
solver_type=solver_type, continue
prediction_type=prediction_type, else:
algorithm_type=algorithm_type, self.check_over_configs(
) solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
sample = self.full_loop( sample = self.full_loop(
solver_order=order, solver_order=order,
solver_type=solver_type, solver_type=solver_type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment