Unverified Commit 02247941 authored by Cheng Lu, committed by GitHub

Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314)



* fix multistep dpmsolver for cosine schedule (deepfloyd-if)

* fix a typo

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule

* add test, fix style

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2dd40850
......@@ -118,6 +118,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
lambda_min_clipped (`float`, default `-inf`):
the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
the cosine (squaredcos_cap_v2) noise schedule.
variance_type (`str`, *optional*):
Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
guided-diffusion (https://github.com/openai/guided-diffusion) predicts both the mean and variance of the
Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based
on diffusion ODEs.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
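
For context, a minimal usage sketch of the two new config options (values are illustrative; `-5.1` is the clipping value exercised by the new tests below):

```python
# Minimal usage sketch (values are illustrative): configure the scheduler for a
# model trained on the cosine ("squaredcos_cap_v2") noise schedule that also
# predicts variance, e.g. DeepFloyd IF.
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    beta_schedule="squaredcos_cap_v2",  # cosine noise schedule
    lambda_min_clipped=-5.1,            # clip lambda(t) from below for numerical stability
    variance_type="learned_range",      # model output carries mean + variance channels
)
scheduler.set_timesteps(num_inference_steps=25)
```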
......@@ -140,6 +151,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
solver_type: str = "midpoint",
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
......@@ -187,7 +200,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.lower_order_nums = 0
self.use_karras_sigmas = use_karras_sigmas
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
......@@ -197,8 +210,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for the cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
timesteps = (
- np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+ np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
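
To illustrate what the clipping computes, here is a standalone sketch of the same logic, approximating the cosine `alphas_cumprod` from its closed form rather than from the scheduler's betas:

```python
# Standalone sketch of the lambda clipping above. For the cosine schedule,
# lambda(t) = log(alpha_t) - log(sigma_t) (half log-SNR) plunges towards -inf as
# t approaches num_train_timesteps, so the last timesteps are numerically fragile.
import math

import torch

num_train_timesteps = 1000
t = torch.arange(num_train_timesteps, dtype=torch.float64)
# closed-form cosine alpha_bar as used by "squaredcos_cap_v2" (beta capping ignored here)
alphas_cumprod = torch.cos(((t + 1) / num_train_timesteps + 0.008) / 1.008 * math.pi / 2) ** 2
alpha_t, sigma_t = alphas_cumprod.sqrt(), (1 - alphas_cumprod).sqrt()
lambda_t = alpha_t.log() - sigma_t.log()  # monotonically decreasing in t

# lambda_t is descending, so flip it to an ascending sequence for searchsorted;
# clipped_idx counts the trailing timesteps whose lambda(t) falls below the threshold.
lambda_min_clipped = -5.1
clipped_idx = torch.searchsorted(torch.flip(lambda_t, [0]), lambda_min_clipped)
last_timestep = num_train_timesteps - 1 - clipped_idx  # new upper end of the timestep grid
```

The sampling grid then spans `[0, last_timestep]` instead of `[0, num_train_timesteps - 1]`, skipping the tail of the cosine schedule where lambda(t) diverges.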
......@@ -320,9 +336,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Returns:
`torch.FloatTensor`: the converted model output.
"""
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
......@@ -343,6 +363,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
return model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
......
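
The new `model_output[:, :3]` slicing assumes a 3-channel image model whose network predicts the mean and variance stacked along the channel axis. A shape-level sketch of that, plus the epsilon-to-x0 conversion performed by the `dpmsolver++` branch:

```python
# Shape-level sketch of the variance handling plus the epsilon -> x0 conversion
# in the "dpmsolver++" branch (assumes a 3-channel image model such as
# guided-diffusion, whose UNet outputs 2*C channels stacked as [mean | variance]):
import torch

batch, channels, h, w = 2, 3, 64, 64
model_output = torch.randn(batch, 2 * channels, h, w)
sample = torch.randn(batch, channels, h, w)

# DPM-Solver is derived from the probability-flow ODE, which only involves the
# predicted mean (the noise estimate); the learned-variance half is dropped.
eps = model_output[:, :channels]

# Under the VP parameterization x_t = alpha_t * x0 + sigma_t * eps, the data
# prediction used by DPM-Solver++ is recovered by inverting for x0.
alpha_t, sigma_t = 0.9, (1 - 0.9**2) ** 0.5  # illustrative values with alpha^2 + sigma^2 = 1
x0_pred = (sample - sigma_t * eps) / alpha_t
assert x0_pred.shape == (batch, channels, h, w)
```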
......@@ -113,6 +113,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable
this to use up all the function evaluations.
lambda_min_clipped (`float`, default `-inf`):
the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
the cosine (squaredcos_cap_v2) noise schedule.
variance_type (`str`, *optional*):
Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
guided-diffusion (https://github.com/openai/guided-diffusion) predicts both the mean and variance of the
Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based
on diffusion ODEs.
"""
......@@ -135,6 +146,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
......@@ -226,8 +239,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for the cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
timesteps = (
- np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+ np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
......@@ -297,6 +313,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
......@@ -317,6 +336,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
return model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
......
......@@ -29,6 +29,8 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
"algorithm_type": "dpmsolver++",
"solver_type": "midpoint",
"lower_order_final": False,
"lambda_min_clipped": -float("inf"),
"variance_type": None,
}
config.update(**kwargs)
......@@ -187,6 +189,14 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)
def test_lambda_min_clipped(self):
self.check_over_configs(lambda_min_clipped=-float("inf"))
self.check_over_configs(lambda_min_clipped=-5.1)
def test_variance_type(self):
self.check_over_configs(variance_type=None)
self.check_over_configs(variance_type="learned_range")
def test_inference_steps(self):
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
......
......@@ -28,6 +28,8 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
"sample_max_value": 1.0,
"algorithm_type": "dpmsolver++",
"solver_type": "midpoint",
"lambda_min_clipped": -float("inf"),
"variance_type": None,
}
config.update(**kwargs)
......@@ -179,6 +181,14 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)
def test_lambda_min_clipped(self):
self.check_over_configs(lambda_min_clipped=-float("inf"))
self.check_over_configs(lambda_min_clipped=-5.1)
def test_variance_type(self):
self.check_over_configs(variance_type=None)
self.check_over_configs(variance_type="learned_range")
def test_inference_steps(self):
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
......
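
Putting the pieces together, a hedged end-to-end sketch of how these options would be used with DeepFloyd IF (the model id is real but gated, so this assumes access has been granted; exact pipeline defaults may differ):

```python
# End-to-end sketch (illustrative): swap the pipeline's default scheduler for
# DPM-Solver++ configured for IF's cosine schedule and learned-variance output.
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    algorithm_type="dpmsolver++",
    lambda_min_clipped=-5.1,
    variance_type="learned_range",
)
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
```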