Unverified commit 2e31a759 authored by Beinsezii, committed by GitHub

DPMSolverMultistep add `rescale_betas_zero_snr` (#7097)

* DPMMultistep rescale_betas_zero_snr

* DPM upcast samples in step()

* DPM rescale_betas_zero_snr UT

* DPMSolverMulti move sample upcast after model convert

Avoids having to re-use the dtype.

* Add a newline for Ruff
parent e51862bb
@@ -71,6 +71,43 @@ def betas_for_alpha_bar(
    return torch.tensor(betas, dtype=torch.float32)
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.FloatTensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
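As a sanity check on the algorithm above: shifting and scaling `alphas_bar_sqrt` pins the terminal cumulative alpha product at zero while leaving the first step essentially untouched. A minimal sketch, assuming the `rescale_zero_terminal_snr` function above is in scope; the schedule values are illustrative `scaled_linear`-style defaults, not pulled from this diff:

import torch

# Illustrative betas resembling Stable Diffusion's scaled_linear schedule
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2
rescaled = rescale_zero_terminal_snr(betas)

alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)
print(alphas_cumprod[0].item())   # barely changed from the original first step
print(alphas_cumprod[-1].item())  # ~0.0, i.e. zero terminal SNR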
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
@@ -144,6 +181,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
            Diffusion.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
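For orientation, a minimal usage sketch of the new flag; the beta values shown are the common Stable Diffusion defaults and are illustrative only. Note that, per the paper linked above, zero-terminal-SNR inference generally pairs with a model trained or finetuned with v-prediction:

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    rescale_betas_zero_snr=True,  # the flag added by this commit
)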
@@ -173,6 +214,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
    ):
        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
@@ -191,8 +233,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
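The `2**-24` clamp above exists because a true zero terminal SNR makes `alphas_cumprod[-1]` exactly zero, and the scheduler's VP-type sigma, proportional to `sqrt((1 - alphas_cumprod) / alphas_cumprod)`, then divides by zero. A minimal sketch of the failure and the fix; the ratio here is a simplification of the scheduler's internal sigma computation:

import torch

alphas_cumprod = torch.tensor([0.9999, 0.5, 0.0])
sigma = torch.sqrt(1 - alphas_cumprod) / torch.sqrt(alphas_cumprod)
print(sigma[-1])  # inf: alpha_t is exactly 0 at the terminal step

alphas_cumprod[-1] = 2**-24  # smallest positive fp16 subnormal
sigma = torch.sqrt(1 - alphas_cumprod) / torch.sqrt(alphas_cumprod)
print(sigma[-1])  # large but finite (~4096)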
@@ -895,9 +946,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
            )
        else:
            noise = None
@@ -912,6 +966,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1
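The upcast/downcast pair added in this hunk follows a common mixed-precision pattern: run the solver arithmetic in float32 to avoid fp16 rounding in `prev_sample`, then hand the result back in the caller's dtype. A schematic sketch of the dtype flow only; the update rule shown is a placeholder, not the scheduler's actual solver step:

import torch

def step_dtype_flow(sample: torch.Tensor, model_output: torch.Tensor) -> torch.Tensor:
    sample = sample.to(torch.float32)                   # upcast before the update math
    prev_sample = sample - 0.1 * model_output.float()   # placeholder update, done in fp32
    return prev_sample.to(model_output.dtype)           # cast back for the pipeline

x = torch.randn(1, 4, 8, 8, dtype=torch.float16)
eps = torch.randn_like(x)
assert step_dtype_flow(x, eps).dtype == torch.float16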
@@ -213,6 +213,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

    def test_rescale_betas_zero_snr(self):
        for rescale_betas_zero_snr in [True, False]:
            self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))