Unverified Commit aa190259 authored by Beinsezii, committed by GitHub

UniPC Multistep add `rescale_betas_zero_snr` (#7531)

* UniPC Multistep add `rescale_betas_zero_snr`

Same patch as DPM and Euler with the patched final alpha cumprod

BF16 doesn't seem to break down, I think because UniPC already upcasts during
some phases? We could still force an upcast, since it only costs ≈ 0.005 it/s
for me, but the difference in output is very small. A better endeavor might be
upcasting in step() and removing all the other upcasts elsewhere? (rough
sketch below)

* UniPC ZSNR UT

* Re-add `rescale_betas_zsnr` doc oops
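
For reference, the "upcast once in step()" idea mentioned above could look roughly like this. A sketch only, not the merged patch; `_unipc_update` is a hypothetical stand-in for the existing UniPC predictor/corrector math:

```python
from diffusers.schedulers.scheduling_utils import SchedulerOutput

# Sketch of forcing a single float32 upcast at the top of step(), so the
# scattered per-phase upcasts could be removed. Not the merged patch;
# `self._unipc_update` is a hypothetical placeholder for the existing
# predictor/corrector math.
def step(self, model_output, timestep, sample, return_dict=True):
    input_dtype = sample.dtype
    model_output = model_output.float()  # upcast once on the way in...
    sample = sample.float()
    prev_sample = self._unipc_update(model_output, timestep, sample)
    prev_sample = prev_sample.to(input_dtype)  # ...cast back on the way out
    if not return_dict:
        return (prev_sample,)
    return SchedulerOutput(prev_sample=prev_sample)
```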
parent 19ab04ff
...@@ -71,6 +71,43 @@ def betas_for_alpha_bar(
    return torch.tensor(betas, dtype=torch.float32)
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas
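
As a quick sanity check of what Algorithm 1 achieves, an illustrative snippet (not part of the patch; the schedule values are SD's scaled-linear defaults):

```python
import torch

# Illustrative check, not part of the patch: with SD's scaled-linear schedule,
# the terminal alpha_cumprod is small but nonzero (~4.7e-3), i.e. SNR > 0.
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(alphas_cumprod[-1])  # ~0.0047 -> terminal SNR is still positive

# After rescaling, the terminal alpha_cumprod (and hence SNR) is exactly zero.
rescaled_cumprod = torch.cumprod(1.0 - rescale_zero_terminal_snr(betas), dim=0)
print(rescaled_cumprod[-1])  # 0.0
```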
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
...@@ -130,6 +167,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -157,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -157,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
timestep_spacing: str = "linspace", timestep_spacing: str = "linspace",
steps_offset: int = 0, steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
...@@ -171,8 +213,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
...
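
To see why the terminal `alphas_cumprod` is patched to `2**-24` rather than left at exactly 0: with a true zero, the corresponding sigma diverges. An illustrative check, not code from the patch:

```python
import torch

# With a true zero terminal alpha_cumprod, the VP-schedule sigma diverges;
# 2**-24 (fp16's smallest positive subnormal) keeps it large but finite.
# Illustrative only, not code from the patch.
alpha_bar_T = torch.tensor(0.0)
print(torch.sqrt((1 - alpha_bar_T) / alpha_bar_T))  # tensor(inf)

alpha_bar_T = torch.tensor(2**-24)
print(torch.sqrt((1 - alpha_bar_T) / alpha_bar_T))  # ~tensor(4096.) -- finite
```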
...@@ -180,6 +180,10 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)
    def test_rescale_betas_zero_snr(self):
        for rescale_betas_zero_snr in [True, False]:
            self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
    def test_solver_order_and_type(self):
        for solver_type in ["bh1", "bh2"]:
            for order in [1, 2, 3]:
...
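
To try the new flag on an existing pipeline, typical usage would look like the following (hypothetical sketch; `from_config` accepts config overrides, and the model id is just an example):

```python
from diffusers import DiffusionPipeline, UniPCMultistepScheduler

# Hypothetical usage sketch: swap UniPC in with zero-terminal-SNR rescaling.
# Most useful with v-prediction checkpoints trained with ZSNR
# (https://arxiv.org/abs/2305.08891); model id is just an example.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True
)
```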