Unverified Commit 8ac6de96 authored by StAlKeR7779's avatar StAlKeR7779 Committed by GitHub
Browse files

DPM++ third order fixes (#9104)



* Fix wrong output on 3n-1 steps count

* Add sde handling to 3 order

* make

* copies

---------
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent 2be66e6a
...@@ -338,8 +338,8 @@ else: ...@@ -338,8 +338,8 @@ else:
"StableDiffusion3ControlNetPipeline", "StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline", "StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline", "StableDiffusion3InpaintPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3PAGImg2ImgPipeline", "StableDiffusion3PAGImg2ImgPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3Pipeline", "StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline", "StableDiffusionAdapterPipeline",
"StableDiffusionAttendAndExcitePipeline", "StableDiffusionAttendAndExcitePipeline",
......
...@@ -889,6 +889,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -889,6 +889,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
model_output_list: List[torch.Tensor], model_output_list: List[torch.Tensor],
*args, *args,
sample: torch.Tensor = None, sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -967,6 +968,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -967,6 +968,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
) )
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None): def index_for_timestep(self, timestep, schedule_timesteps=None):
...@@ -1073,7 +1083,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -1073,7 +1083,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
else: else:
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
if self.lower_order_nums < self.config.solver_order: if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1 self.lower_order_nums += 1
......
...@@ -764,6 +764,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -764,6 +764,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
model_output_list: List[torch.Tensor], model_output_list: List[torch.Tensor],
*args, *args,
sample: torch.Tensor = None, sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -842,6 +843,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -842,6 +843,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
) )
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t return x_t
def _init_step_index(self, timestep): def _init_step_index(self, timestep):
......
...@@ -264,6 +264,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -264,6 +264,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
orders = [1, 2] * (steps // 2) orders = [1, 2] * (steps // 2)
elif order == 1: elif order == 1:
orders = [1] * steps orders = [1] * steps
if self.config.final_sigmas_type == "zero":
orders[-1] = 1
return orders return orders
@property @property
...@@ -812,6 +816,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -812,6 +816,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
model_output_list: List[torch.Tensor], model_output_list: List[torch.Tensor],
*args, *args,
sample: torch.Tensor = None, sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -909,6 +914,23 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -909,6 +914,23 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
) )
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1_1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) + (-2.0 * h)) / (-2.0 * h) ** 2 - 0.5)) * D2
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t return x_t
def singlestep_dpm_solver_update( def singlestep_dpm_solver_update(
...@@ -970,7 +992,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -970,7 +992,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
elif order == 2: elif order == 2:
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise) return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
elif order == 3: elif order == 3:
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample, noise=noise)
else: else:
raise ValueError(f"Order must be 1, 2, 3, got {order}") raise ValueError(f"Order must be 1, 2, 3, got {order}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment