"docs/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "169d088adb755a237a64bf70973374768ea1fc50"
Unverified commit c4a8979f, authored by hlky and committed by GitHub

Add beta sigmas to other schedulers and update docs (#9538)

parent f9fd5114
@@ -52,6 +52,7 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso
| sgm_uniform | init with `timestep_spacing="trailing"` |
| simple | init with `timestep_spacing="trailing"` |
| exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` |
| beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` |
All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.
...
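For context on the new `beta` row in the table above, a minimal usage sketch follows; the pipeline class and model id are illustrative placeholders, not part of this commit, and `scipy` must be installed because beta sigmas rely on `scipy.stats`:

```python
# Hedged sketch: switch an existing pipeline's scheduler to the new beta sigmas.
# "runwayml/stable-diffusion-v1-5" is only a placeholder model id.
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="linspace",
    use_beta_sigmas=True,
)
image = pipe("an astronaut riding a horse").images[0]
```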
@@ -22,10 +22,14 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
if is_scipy_available():
import scipy.stats
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -113,6 +117,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -141,11 +148,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
-if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-    raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+if self.config.use_beta_sigmas and not is_scipy_available():
+    raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -263,6 +275,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -396,6 +411,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
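To see what the `_convert_to_beta` helper added above actually produces, here is a standalone sketch that mirrors its list comprehension with assumed toy values for the sigma range and step count (these values are not taken from the diff):

```python
# Mirrors _convert_to_beta above: beta.ppf maps evenly spaced quantiles to sigmas
# that are spaced more finely near sigma_max and sigma_min than in the middle.
import numpy as np
import scipy.stats

alpha, beta = 0.6, 0.6            # defaults used by _convert_to_beta
sigma_min, sigma_max = 0.1, 10.0  # assumed toy values
num_inference_steps = 5

sigmas = [
    sigma_min + ppf * (sigma_max - sigma_min)
    for ppf in [
        scipy.stats.beta.ppf(t, alpha, beta)
        for t in 1 - np.linspace(0, 1, num_inference_steps)
    ]
]
print(sigmas)  # decreasing from 10.0 to 0.1, with smaller steps near both ends
```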
def convert_model_output(
self,
model_output: torch.Tensor,
...
@@ -21,11 +21,15 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
if is_scipy_available():
import scipy.stats
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -163,6 +167,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
@@ -209,6 +216,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
@@ -217,8 +225,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
-if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-    raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+if self.config.use_beta_sigmas and not is_scipy_available():
+    raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -337,6 +349,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
if timesteps is not None and self.config.use_exponential_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
if timesteps is not None and self.config.use_beta_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
if timesteps is not None:
timesteps = np.array(timesteps).astype(np.int64)
@@ -388,6 +402,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -542,6 +559,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,
...
@@ -21,11 +21,15 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
if is_scipy_available():
import scipy.stats
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -126,6 +130,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
@@ -161,13 +168,18 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
-if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-    raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+if self.config.use_beta_sigmas and not is_scipy_available():
+    raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -219,6 +231,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.use_karras_sigmas = use_karras_sigmas
self.use_exponential_sigmas = use_exponential_sigmas
self.use_beta_sigmas = use_beta_sigmas
@property
def step_index(self):
@@ -276,6 +289,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_max = (
@@ -416,6 +432,38 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
def convert_model_output(
self,
...
@@ -20,9 +20,14 @@ import torch
import torchsde
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
if is_scipy_available():
import scipy.stats
class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
@@ -162,6 +167,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
noise_sampler_seed (`int`, *optional*, defaults to `None`):
The random seed to use for the noise sampler. If `None`, a random seed is generated.
timestep_spacing (`str`, defaults to `"linspace"`):
@@ -185,12 +193,17 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
noise_sampler_seed: Optional[int] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
-if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-    raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+if self.config.use_beta_sigmas and not is_scipy_available():
+    raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -349,6 +362,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)
@@ -451,6 +467,38 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
@property
def state_in_first_order(self):
return self.sample is None
...
@@ -21,11 +21,14 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate, logging
+from ..utils import deprecate, is_scipy_available, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
if is_scipy_available():
import scipy.stats
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -125,6 +128,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
@@ -157,12 +163,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
):
-if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-    raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+if self.config.use_beta_sigmas and not is_scipy_available():
+    raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if algorithm_type == "dpmsolver":
deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
@@ -307,6 +318,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")
if timesteps is not None and self.config.use_exponential_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
if timesteps is not None and self.config.use_beta_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
num_inference_steps = num_inference_steps or len(timesteps)
self.num_inference_steps = num_inference_steps
@@ -333,6 +346,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -484,6 +500,38 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,
...
@@ -19,9 +19,14 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
if is_scipy_available():
import scipy.stats
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -99,6 +104,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -120,13 +128,18 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
clip_sample: Optional[bool] = False,
clip_sample_range: float = 1.0,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
-if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-    raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+if self.config.use_beta_sigmas and not is_scipy_available():
+    raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -258,6 +271,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
if timesteps is not None and self.config.use_exponential_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
if timesteps is not None and self.config.use_beta_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
num_inference_steps = num_inference_steps or len(timesteps)
self.num_inference_steps = num_inference_steps
@@ -296,6 +311,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -386,6 +404,38 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
@property
def state_in_first_order(self):
return self.dt is None
...
@@ -19,10 +19,15 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import is_scipy_available
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
if is_scipy_available():
import scipy.stats
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -93,6 +98,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -117,12 +125,17 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
-if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-    raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+if self.config.use_beta_sigmas and not is_scipy_available():
+    raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -258,6 +271,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -376,6 +392,38 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
@property
def state_in_first_order(self):
return self.sample is None
...
@@ -19,9 +19,14 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
if is_scipy_available():
import scipy.stats
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -92,6 +97,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -116,12 +124,17 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
-if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-    raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+if self.config.use_beta_sigmas and not is_scipy_available():
+    raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -257,6 +270,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -389,6 +405,38 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def step(
self,
model_output: Union[torch.Tensor, np.ndarray],
...
@@ -17,6 +17,7 @@ from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import scipy.stats
import torch
from scipy import integrate
@@ -113,6 +114,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -137,12 +141,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
timestep_spacing: str = "linspace", timestep_spacing: str = "linspace",
steps_offset: int = 0, steps_offset: int = 0,
): ):
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -297,6 +304,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -392,6 +402,38 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def step(
self,
model_output: torch.Tensor,
...
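A hypothetical usage sketch for the LMSDiscreteScheduler change above: re-create the scheduler from an existing pipeline's config with `use_beta_sigmas=True`. The checkpoint id and prompt are placeholders, not part of this PR.

```python
# Hypothetical usage sketch (placeholder model id and prompt): enable beta sigmas
# on LMSDiscreteScheduler by overriding the flag when rebuilding it from the
# pipeline's current scheduler config.
import torch
from diffusers import DiffusionPipeline, LMSDiscreteScheduler

pipe = DiffusionPipeline.from_pretrained("some-org/some-model", torch_dtype=torch.float16)  # placeholder checkpoint
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config, use_beta_sigmas=True)
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=20).images[0]
```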
@@ -22,11 +22,15 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate, is_scipy_available
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
if is_scipy_available():
import scipy.stats
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -124,6 +128,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
@@ -159,13 +166,18 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -292,6 +304,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -425,6 +440,38 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,
...
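The SASolverScheduler constructor above now enforces two guards: beta sigmas require scipy, and the three sigma-schedule flags are mutually exclusive. A small sketch of the expected behaviour, assuming scipy is installed:

```python
# Sketch of the new constructor guards (assumes scipy is installed).
from diffusers import SASolverScheduler

SASolverScheduler(use_beta_sigmas=True)  # OK: exactly one sigma-schedule flag is set

try:
    SASolverScheduler(use_beta_sigmas=True, use_karras_sigmas=True)
except ValueError as err:
    # "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`,
    # `config.use_karras_sigmas` can be used."
    print(err)
```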
@@ -22,10 +22,14 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate, is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
if is_scipy_available():
import scipy.stats
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -161,6 +165,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -198,13 +205,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
solver_p: SchedulerMixin = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
rescale_betas_zero_snr: bool = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -337,6 +349,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
@@ -480,6 +495,38 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,
...
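As a final sketch, constructing UniPCMultistepScheduler with beta sigmas enabled and inspecting the resulting schedule; the step count is arbitrary and only for illustration.

```python
# Sketch: build UniPCMultistepScheduler with use_beta_sigmas=True and look at the
# sigma schedule it produces after set_timesteps.
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler(timestep_spacing="linspace", use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.sigmas)     # sigmas cluster toward both ends of the schedule
print(scheduler.timesteps)
```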