Unverified commit 3706aa33 authored by dg845, committed by GitHub

Add rescale_betas_zero_snr Argument to DDPMScheduler (#6305)



* Add rescale_betas_zero_snr argument to DDPMScheduler.

* Propagate rescale_betas_zero_snr changes to DDPMParallelScheduler.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent d4f10ea3
@@ -89,6 +89,43 @@ def betas_for_alpha_bar(
return torch.tensor(betas, dtype=torch.float32)
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
Args:
betas (`torch.FloatTensor`):
The betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: the rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
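For intuition, here is a minimal sketch (illustrative only, not part of the commit) that checks the defining property of Algorithm 1 on an assumed standard linear beta schedule: after rescaling, the cumulative alpha product at the final timestep is exactly zero (zero terminal SNR), while the first timestep is left unchanged.

import torch

# Hypothetical check, reusing the rescale_zero_terminal_snr defined above.
betas = torch.linspace(0.0001, 0.02, 1000)  # assumed linear schedule
rescaled = rescale_zero_terminal_snr(betas)

alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)
print(alphas_cumprod[-1].item())  # 0.0 -> terminal SNR is zero (betas[-1] == 1)
# The first step is preserved, matching the original schedule:
print(torch.allclose(alphas_cumprod[0], torch.cumprod(1.0 - betas, dim=0)[0]))  # True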
class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
`DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
@@ -131,6 +168,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -153,6 +194,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -171,6 +213,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
# Rescale for zero SNR
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
...
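As a usage note (a minimal sketch; all parameter values here are illustrative assumptions, not taken from the commit), the new flag can be enabled straight from the scheduler config. Pairing it with trailing timestep spacing, as the referenced paper also recommends, would look roughly like:

from diffusers import DDPMScheduler

# Assumed configuration: only rescale_betas_zero_snr is new in this commit.
scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_schedule="linear",
    timestep_spacing="trailing",  # the zero-SNR paper recommends trailing spacing
    rescale_betas_zero_snr=True,
)
scheduler.set_timesteps(num_inference_steps=50)
print(scheduler.alphas_cumprod[-1])  # ~0: the last training timestep carries no signal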
@@ -91,6 +91,43 @@ def betas_for_alpha_bar(
return torch.tensor(betas, dtype=torch.float32)
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
Args:
betas (`torch.FloatTensor`):
The betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: the rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
@@ -139,6 +176,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -163,6 +204,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -181,6 +223,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
# Rescale for zero SNR
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
...
@@ -68,6 +68,10 @@ class DDPMSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
def test_rescale_betas_zero_snr(self):
for rescale_betas_zero_snr in [True, False]:
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
...
@@ -82,6 +82,10 @@ class DDPMParallelSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
def test_rescale_betas_zero_snr(self):
for rescale_betas_zero_snr in [True, False]:
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
def test_batch_step_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
...