Unverified commit ac7b1716, authored by Cheng Lu and committed by GitHub
Browse files

Stabilize DPM++, especially for SDXL and SDE-DPM++ (#5541)



* stabilize dpmpp for sdxl by using euler at the final step

* add lu's uniform logsnr time steps

* add test

* fix check_copies

* fix tests

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 3fc10ded
...@@ -117,9 +117,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -117,9 +117,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, defaults to `True`): lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
euler_at_final (`bool`, defaults to `False`):
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
steps, but may sometimes result in blurring.
use_karras_sigmas (`bool`, *optional*, defaults to `False`): use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}. the sigmas are determined according to a sequence of noise levels {σi}.
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
Whether to use the uniform-logSNR step sizes proposed in Lu's DPM-Solver for the noise schedule during
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
`lambda(t)`.
lambda_min_clipped (`float`, defaults to `-inf`): lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule. cosine (`squaredcos_cap_v2`) noise schedule.
...@@ -154,7 +162,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -154,7 +162,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
algorithm_type: str = "dpmsolver++", algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint", solver_type: str = "midpoint",
lower_order_final: bool = True, lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"), lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None, variance_type: Optional[str] = None,
timestep_spacing: str = "linspace", timestep_spacing: str = "linspace",
...@@ -258,6 +268,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -258,6 +268,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_lu_lambdas:
lambdas = np.flip(log_sigmas.copy())
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
sigmas = np.exp(lambdas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else: else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
...@@ -354,6 +370,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -354,6 +370,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas return sigmas
def _convert_to_lu(self, in_lambdas: np.ndarray, num_inference_steps: int) -> np.ndarray:
    """Construct the uniform-logSNR noise schedule of Lu et al. (2022).

    Args:
        in_lambdas (`np.ndarray`):
            Sequence of values to interpolate over. Only the two endpoints are read: the first entry is
            taken as `lambda_max` and the last entry as `lambda_min`. Any array type exposing `.item()` on
            its elements works.
        num_inference_steps (`int`):
            Number of points to generate, endpoints included.

    Returns:
        `np.ndarray`: `num_inference_steps` values spaced linearly from `in_lambdas[0]` to `in_lambdas[-1]`.
    """
    lambda_min: float = in_lambdas[-1].item()
    lambda_max: float = in_lambdas[0].item()

    # The paper's schedule is uniform in lambda, i.e. a plain linear ramp between the endpoints. The
    # earlier formulation raised the endpoints and the result to powers of `rho = 1.0`; that was a numeric
    # no-op and would have produced NaN for negative inputs with any non-integer exponent, so it is
    # dropped here. The interpolation expression is kept as-is so results stay bit-for-bit the same.
    ramp = np.linspace(0, 1, num_inference_steps)
    return lambda_max + ramp * (lambda_min - lambda_max)
def convert_model_output( def convert_model_output(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
...@@ -787,8 +816,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -787,8 +816,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.step_index is None: if self.step_index is None:
self._init_step_index(timestep) self._init_step_index(timestep)
lower_order_final = ( # Improve numerical stability for small number of steps
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
) )
lower_order_second = ( lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
......
...@@ -117,6 +117,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -117,6 +117,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, defaults to `True`): lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
euler_at_final (`bool`, defaults to `False`):
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
steps, but may sometimes result in blurring.
use_karras_sigmas (`bool`, *optional*, defaults to `False`): use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}. the sigmas are determined according to a sequence of noise levels {σi}.
...@@ -154,6 +158,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -154,6 +158,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
algorithm_type: str = "dpmsolver++", algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint", solver_type: str = "midpoint",
lower_order_final: bool = True, lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"), lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None, variance_type: Optional[str] = None,
...@@ -804,8 +809,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -804,8 +809,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
if self.step_index is None: if self.step_index is None:
self._init_step_index(timestep) self._init_step_index(timestep)
lower_order_final = ( # Improve numerical stability for small number of steps
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
) )
lower_order_second = ( lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
......
...@@ -29,6 +29,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -29,6 +29,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
"algorithm_type": "dpmsolver++", "algorithm_type": "dpmsolver++",
"solver_type": "midpoint", "solver_type": "midpoint",
"lower_order_final": False, "lower_order_final": False,
"euler_at_final": False,
"lambda_min_clipped": -float("inf"), "lambda_min_clipped": -float("inf"),
"variance_type": None, "variance_type": None,
} }
...@@ -195,6 +196,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -195,6 +196,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
self.check_over_configs(lower_order_final=True) self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False) self.check_over_configs(lower_order_final=False)
def test_euler_at_final(self):
    """Run the shared config check with `euler_at_final` enabled, then disabled."""
    for flag in (True, False):
        self.check_over_configs(euler_at_final=flag)
def test_lambda_min_clipped(self): def test_lambda_min_clipped(self):
self.check_over_configs(lambda_min_clipped=-float("inf")) self.check_over_configs(lambda_min_clipped=-float("inf"))
self.check_over_configs(lambda_min_clipped=-5.1) self.check_over_configs(lambda_min_clipped=-5.1)
...@@ -258,6 +263,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -258,6 +263,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert abs(result_mean.item() - 0.2096) < 1e-3 assert abs(result_mean.item() - 0.2096) < 1e-3
def test_full_loop_with_lu_and_v_prediction(self):
    """Check the mean absolute output of a full loop run with Lu lambdas and v-prediction."""
    output = self.full_loop(prediction_type="v_prediction", use_lu_lambdas=True)
    mean_abs = output.abs().mean().item()

    # Reference value the full-loop output is compared against.
    expected_mean = 0.1554
    assert abs(mean_abs - expected_mean) < 1e-3
def test_switch(self): def test_switch(self):
# make sure that iterating over schedulers with same config names gives same results # make sure that iterating over schedulers with same config names gives same results
# for defaults # for defaults
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment