"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1b1d6444c69d8a8a9f037500bc54da735fd547d5"
Unverified commit 02247941, authored by Cheng Lu, committed by GitHub

Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314)



* fix multistep dpmsolver for cosine schedule (deepfloyd-if)

* fix a typo

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule

* add test, fix style

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2dd40850
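
For context, a minimal usage sketch of the fixed multistep scheduler with a cosine-schedule model that predicts variance. The checkpoint id and the specific parameter values below are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch (not from this commit): pair DPMSolverMultistepScheduler
# with a model trained on a cosine (squaredcos_cap_v2) noise schedule, such as
# DeepFloyd IF, which also predicts the variance of its output distribution.
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    lambda_min_clipped=-5.1,        # clip lambda(t) for numerical stability
    variance_type="learned_range",  # the scheduler keeps only the "mean" channels
)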
@@ -118,6 +118,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
             noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
             of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+        lambda_min_clipped (`float`, default `-inf`):
+            The clipping threshold for the minimum value of lambda(t), used for numerical stability. This is
+            critical for the cosine (squaredcos_cap_v2) noise schedule.
+        variance_type (`str`, *optional*):
+            Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
+            guided-diffusion (https://github.com/openai/guided-diffusion) predicts both the mean and the variance of
+            the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is
+            based on diffusion ODEs.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -140,6 +151,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
+        lambda_min_clipped: float = -float("inf"),
+        variance_type: Optional[str] = None,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -187,7 +200,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.lower_order_nums = 0
         self.use_karras_sigmas = use_karras_sigmas

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -197,8 +210,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
+        # Clipping the minimum of all lambda(t) for numerical stability.
+        # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
+        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
         timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
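
To make the `searchsorted` trick above concrete, here is a small standalone sketch with made-up lambda values (the real `self.lambda_t` is computed by the scheduler from its noise schedule):

# Standalone sketch of the clipping above, with illustrative values.
# lambda(t) decreases with t for these schedules, so flipping it yields the
# ascending order that torch.searchsorted needs for its binary search.
import torch

lambda_t = torch.tensor([2.0, 1.0, 0.0, -3.0, -8.0])  # made-up lambda(t) values
lambda_min_clipped = -5.1

# Number of trailing timesteps whose lambda falls below the threshold.
clipped_idx = torch.searchsorted(torch.flip(lambda_t, [0]), lambda_min_clipped)
print(clipped_idx)  # tensor(1) -> the last timestep (lambda = -8.0) is clipped away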
@@ -320,9 +336,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         Returns:
             `torch.FloatTensor`: the converted model output.
         """
         # DPM-Solver++ needs to solve an integral of the data prediction model.
         if self.config.algorithm_type == "dpmsolver++":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                 x0_pred = (sample - sigma_t * model_output) / alpha_t
             elif self.config.prediction_type == "sample":
@@ -343,6 +363,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 return model_output
             elif self.config.prediction_type == "sample":
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
......
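
As a brief aside on the `model_output[:, :3]` slicing introduced above: models with `variance_type="learned_range"` emit twice the image channels, with the predicted variance stacked after the mean along the channel axis. A minimal sketch, assuming a 3-channel image model (the shapes are illustrative):

# Sketch: a learned-variance model returns [mean | variance] along the channel
# axis; DPM-Solver is built on diffusion ODEs and only consumes the mean part.
import torch

model_output = torch.randn(1, 6, 8, 8)  # 3 mean channels + 3 variance channels
mean_only = model_output[:, :3]         # what convert_model_output keeps
assert mean_only.shape == (1, 3, 8, 8)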
@@ -113,6 +113,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final (`bool`, default `True`):
             whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable
             this to use up all the function evaluations.
+        lambda_min_clipped (`float`, default `-inf`):
+            The clipping threshold for the minimum value of lambda(t), used for numerical stability. This is
+            critical for the cosine (squaredcos_cap_v2) noise schedule.
+        variance_type (`str`, *optional*):
+            Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
+            guided-diffusion (https://github.com/openai/guided-diffusion) predicts both the mean and the variance of
+            the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is
+            based on diffusion ODEs.
     """
@@ -135,6 +146,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
+        lambda_min_clipped: float = -float("inf"),
+        variance_type: Optional[str] = None,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -226,8 +239,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         self.num_inference_steps = num_inference_steps
+        # Clipping the minimum of all lambda(t) for numerical stability.
+        # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
+        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
         timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
@@ -297,6 +313,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         # DPM-Solver++ needs to solve an integral of the data prediction model.
         if self.config.algorithm_type == "dpmsolver++":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                 x0_pred = (sample - sigma_t * model_output) / alpha_t
             elif self.config.prediction_type == "sample":
@@ -317,6 +336,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 return model_output
             elif self.config.prediction_type == "sample":
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
......
@@ -29,6 +29,8 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
             "algorithm_type": "dpmsolver++",
             "solver_type": "midpoint",
             "lower_order_final": False,
+            "lambda_min_clipped": -float("inf"),
+            "variance_type": None,
         }

         config.update(**kwargs)
@@ -187,6 +189,14 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         self.check_over_configs(lower_order_final=True)
         self.check_over_configs(lower_order_final=False)

+    def test_lambda_min_clipped(self):
+        self.check_over_configs(lambda_min_clipped=-float("inf"))
+        self.check_over_configs(lambda_min_clipped=-5.1)
+
+    def test_variance_type(self):
+        self.check_over_configs(variance_type=None)
+        self.check_over_configs(variance_type="learned_range")
+
     def test_inference_steps(self):
         for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
             self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
......
@@ -28,6 +28,8 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
             "sample_max_value": 1.0,
             "algorithm_type": "dpmsolver++",
             "solver_type": "midpoint",
+            "lambda_min_clipped": -float("inf"),
+            "variance_type": None,
         }

         config.update(**kwargs)
@@ -179,6 +181,14 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
         self.check_over_configs(lower_order_final=True)
         self.check_over_configs(lower_order_final=False)

+    def test_lambda_min_clipped(self):
+        self.check_over_configs(lambda_min_clipped=-float("inf"))
+        self.check_over_configs(lambda_min_clipped=-5.1)
+
+    def test_variance_type(self):
+        self.check_over_configs(variance_type=None)
+        self.check_over_configs(variance_type="learned_range")
+
     def test_inference_steps(self):
         for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
             self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
......