Unverified Commit ac61eefc authored by YiYi Xu, committed by GitHub

fix DPM Scheduler with `use_karras_sigmas` option (#6477)



* fix

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent f95615b8
@@ -128,6 +128,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
             `lambda(t)`.
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
+            is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
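For context, a minimal usage sketch (not part of this commit) showing how the new option is typically set from a pipeline; the checkpoint id is an arbitrary example:

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Load any pipeline; the checkpoint below is just an example.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# `from_config` accepts overrides, so the new option can be set directly.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_karras_sigmas=True,
    final_sigmas_type="zero",  # the new default; "sigma_min" restores the old final sigma
)

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]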
@@ -165,11 +168,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
+
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -207,6 +215,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
 
+        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
+            raise ValueError(
+                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+            )
+
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
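The two new guardrails interact: because `final_sigmas_type` defaults to `"zero"`, the deprecated non-++ solvers can no longer be constructed with default arguments. An illustrative sketch (exact warning and error text depend on the installed diffusers version):

from diffusers import DPMSolverMultistepScheduler

# Deprecated algorithm types still construct, but emit a deprecation warning.
# They must opt into "sigma_min": per the check above, only the ++
# (data-prediction) algorithms support stepping to sigma = 0.
sched = DPMSolverMultistepScheduler(algorithm_type="dpmsolver", final_sigmas_type="sigma_min")

# A non-++ solver combined with the new "zero" default raises immediately.
try:
    DPMSolverMultistepScheduler(algorithm_type="dpmsolver")  # final_sigmas_type="zero"
except ValueError as err:
    print(err)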
@@ -267,17 +280,24 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         elif self.config.use_lu_lambdas:
             lambdas = np.flip(log_sigmas.copy())
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
-            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas)
         self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
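This hunk is the substance of the fix: with `use_karras_sigmas=True`, the trailing sigma was previously a duplicate of the smallest Karras sigma; it now honors `final_sigmas_type` (default `"zero"`) for all branches. A standalone numpy sketch of the resulting schedule, with made-up `sigma_min`/`sigma_max` values and `rho=7.0` as in Karras et al. (2022):

import numpy as np

def karras_schedule_with_final(sigma_min, sigma_max, n, final_sigmas_type="zero", rho=7.0):
    # Same construction as `_convert_to_karras`: interpolate in sigma^(1/rho) space.
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    # Post-fix behavior: the appended final sigma honors `final_sigmas_type`
    # (in the scheduler, "sigma_min" uses the training schedule's smallest sigma).
    sigma_last = sigma_min if final_sigmas_type == "sigma_min" else 0.0
    return np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

print(karras_schedule_with_final(0.03, 14.6, 10))  # n + 1 values, ending in 0.0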
@@ -831,7 +851,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         # Improve numerical stability for small number of steps
         lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
-            self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
+            self.config.euler_at_final
+            or (self.config.lower_order_final and len(self.timesteps) < 15)
+            or self.config.final_sigmas_type == "zero"
         )
         lower_order_second = (
             (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
         )
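The extra clause exists because the log-SNR `lambda(t) = log(alpha_t / sigma_t)` diverges as the final sigma goes to 0, so the last update must fall back to a first-order step. A pure-Python restatement of the updated predicate (a sketch that mirrors the diff above):

def use_lower_order_final(step_index, num_steps, euler_at_final, lower_order_final, final_sigmas_type):
    # Mirrors the updated condition in `step`: force a first-order final step
    # when explicitly requested, for short schedules, or when sigma_last == 0.
    return (step_index == num_steps - 1) and (
        euler_at_final
        or (lower_order_final and num_steps < 15)
        or final_sigmas_type == "zero"
    )

assert use_lower_order_final(24, 25, False, True, "zero")  # True solely due to "zero"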
@@ -165,6 +165,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
+
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -783,7 +787,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = step_index
 
-    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
     def step(
         self,
         model_output: torch.FloatTensor,
@@ -108,7 +108,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
         algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
+            Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
             `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
             paper, and the `dpmsolver++` type implements the algorithms in the
             [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
@@ -122,6 +122,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
+            is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -150,9 +153,14 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
     ):
+        if algorithm_type == "dpmsolver":
+            deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+            deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
+
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -189,6 +197,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
 
+        if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
+            raise ValueError(
+                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+            )
+
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
@@ -267,11 +280,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
-            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
+            )
+        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -285,6 +305,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             )
             self.register_to_config(lower_order_final=True)
 
+        if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
+            logger.warn(
+                f"`final_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
+            )
+            self.register_to_config(lower_order_final=True)
+
         self.order_list = self.get_order_list(num_inference_steps)
 
         # add an index counter for schedulers that allow duplicated timesteps
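A sketch of the new override in action (illustrative; assumes the defaults added in this commit): the check runs inside `set_timesteps` and calls `register_to_config`, so the registered config flips even when the user passed `lower_order_final=False`:

from diffusers import DPMSolverSinglestepScheduler

# final_sigmas_type defaults to "zero", which this branch does not allow
# together with lower_order_final=False.
sched = DPMSolverSinglestepScheduler(lower_order_final=False)
sched.set_timesteps(num_inference_steps=10)
print(sched.config.lower_order_final)  # True, after the warning fires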
@@ -32,6 +32,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
             "euler_at_final": False,
             "lambda_min_clipped": -float("inf"),
             "variance_type": None,
+            "final_sigmas_type": "sigma_min",
         }
 
         config.update(**kwargs)
@@ -30,6 +30,7 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
             "solver_type": "midpoint",
             "lambda_min_clipped": -float("inf"),
             "variance_type": None,
+            "final_sigmas_type": "sigma_min",
         }
 
         config.update(**kwargs)
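Both test configs pin `"final_sigmas_type": "sigma_min"` so the existing numeric expectations, computed against the old trailing sigma, continue to hold. A hypothetical companion test for the new default might look like this (the test name and step count are assumptions, not part of the commit):

def test_final_sigmas_type_zero(self):
    # Hypothetical check: with the new default, the appended final sigma is 0.
    scheduler_config = self.get_scheduler_config(final_sigmas_type="zero")
    scheduler = self.scheduler_classes[0](**scheduler_config)
    scheduler.set_timesteps(num_inference_steps=10)
    assert scheduler.sigmas[-1] == 0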