Unverified Commit 1c6ede93 authored by hlky, committed by GitHub

[Schedulers] Add beta sigmas / beta noise schedule (#9509)

Add beta sigmas / beta noise schedule
parent aa3c46d9
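
Enabling the new schedule from user code is a one-line config override. A minimal sketch, assuming `scipy` is installed; the checkpoint is illustrative, and any pipeline built on `EulerDiscreteScheduler` works the same way:

```python
# Sketch: swap in the beta noise schedule on an existing pipeline.
# Assumes scipy is installed; the checkpoint name is only an example.
import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# from_config reuses the current scheduler config and overrides the new flag.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, use_beta_sigmas=True)
```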
@@ -20,11 +20,14 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -160,6 +163,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -189,6 +195,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -197,8 +204,12 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
-        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
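
The two new guards above are straightforward to exercise; a quick sketch of the failure modes (each call is independent, behavior as added in this hunk):

```python
from diffusers import EulerDiscreteScheduler

# ValueError: the beta, exponential, and Karras schedules are mutually exclusive.
EulerDiscreteScheduler(use_beta_sigmas=True, use_karras_sigmas=True)

# ImportError when scipy is not available, since beta sigmas need scipy.stats.
EulerDiscreteScheduler(use_beta_sigmas=True)
```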
@@ -241,6 +252,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
         self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
         self._step_index = None
         self._begin_index = None
@@ -340,6 +352,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
         if timesteps is not None and self.config.use_exponential_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
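
Like the existing Karras and exponential guards, this makes user-supplied timesteps mutually exclusive with the beta schedule; roughly (assuming scipy is installed so construction succeeds):

```python
# Sketch: custom timesteps are rejected once use_beta_sigmas is set.
from diffusers import EulerDiscreteScheduler

sched = EulerDiscreteScheduler(use_beta_sigmas=True)
sched.set_timesteps(timesteps=[999, 500, 1])  # raises ValueError
```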
@@ -408,6 +422,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
@@ -502,6 +520,37 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas
 
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
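
The core of the change is the quantile mapping in `_convert_to_beta`. A standalone sketch of the same computation outside the scheduler (the helper name and the sigma values are illustrative, not library API):

```python
# Standalone sketch of the beta-sigma mapping added above.
import numpy as np
import scipy.stats

def beta_sigmas(sigma_min: float, sigma_max: float, n: int, alpha: float = 0.6, beta: float = 0.6):
    # Evenly spaced quantiles in [0, 1], descending, pushed through the
    # Beta(alpha, beta) inverse CDF (ppf), then rescaled to [sigma_min, sigma_max].
    quantiles = 1 - np.linspace(0, 1, n)
    ppf = scipy.stats.beta.ppf(quantiles, alpha, beta)
    return sigma_min + ppf * (sigma_max - sigma_min)

print(beta_sigmas(0.03, 14.6, 8).round(3))  # descends from sigma_max to sigma_min
```

With the default `alpha = beta = 0.6`, the Beta density is U-shaped, so the quantiles, and with them the sigma steps, cluster near both ends of the range rather than being spaced uniformly; this appears to match the paper's aim of spending more sampling steps at the highest and lowest noise levels.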