Unverified Commit f0c81562 authored by Beinsezii's avatar Beinsezii Committed by GitHub
Browse files

Add `final_sigma_zero` to UniPCMultistep (#7517)

* Add `final_sigma_zero` to UniPCMultistep

Effectively the same trick as DDIM's `set_alpha_to_one` and
DPM's `final_sigma_type='zero'`.
Currently False by default but maybe this should be True?

* `final_sigma_zero: bool` -> `final_sigmas_type: str`

Should 1:1 match DPM Multistep now.

* Set `final_sigmas_type='sigma_min'` in UniPC UTs
parent 9d20ed37
...@@ -127,6 +127,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -127,6 +127,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0): steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families. An offset added to the inference steps, as required by some model families.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -153,6 +156,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -153,6 +156,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace", timestep_spacing: str = "linspace",
steps_offset: int = 0, steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
...@@ -265,10 +269,25 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -265,10 +269,25 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.flip(sigmas).copy() sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
else: else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas) self.sigmas = torch.from_numpy(sigmas)
......
...@@ -24,6 +24,7 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest): ...@@ -24,6 +24,7 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
"beta_schedule": "linear", "beta_schedule": "linear",
"solver_order": 2, "solver_order": 2,
"solver_type": "bh2", "solver_type": "bh2",
"final_sigmas_type": "sigma_min",
} }
config.update(**kwargs) config.update(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment