Unverified commit b934215d, authored by YiYi Xu, committed by GitHub
Browse files

[scheduler] support custom `timesteps` and `sigmas` (#7817)



* support custom sigmas and timesteps, dpm euler

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 5ed3abd3
......@@ -140,6 +140,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
......@@ -155,14 +156,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
......@@ -173,6 +178,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
......@@ -841,6 +856,7 @@ class StableDiffusionXLAdapterPipeline(
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
......@@ -900,6 +916,10 @@ class StableDiffusionXLAdapterPipeline(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
......@@ -1101,7 +1121,9 @@ class StableDiffusionXLAdapterPipeline(
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
......
......@@ -68,7 +68,7 @@ else:
_import_structure["scheduling_tcd"] = ["TCDScheduler"]
_import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
_import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
_import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
_import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"]
_import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"]
try:
......@@ -163,7 +163,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .scheduling_tcd import TCDScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_unipc_multistep import UniPCMultistepScheduler
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
try:
......
......@@ -303,7 +303,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
timesteps: Optional[List[int]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
......@@ -312,33 +317,54 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support an arbitrary timestep schedule. If `None`, timesteps will be generated
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` must be
`None`, and the `timestep_spacing` attribute will be ignored.
"""
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = last_timestep // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
if num_inference_steps is None and timesteps is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
if num_inference_steps is not None and timesteps is not None:
raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
if timesteps is not None and self.config.use_karras_sigmas:
raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
if timesteps is not None and self.config.use_lu_lambdas:
raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
if timesteps is not None:
timesteps = np.array(timesteps).astype(np.int64)
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = last_timestep // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
(np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
......
......@@ -274,7 +274,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
timesteps: Optional[List[int]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
......@@ -283,17 +288,33 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
strategy of equal spacing between timesteps is used. If `timesteps` is
passed, `num_inference_steps` must be `None`.
"""
if num_inference_steps is None and timesteps is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
if num_inference_steps is not None and timesteps is not None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
if timesteps is not None and self.config.use_karras_sigmas:
raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")
num_inference_steps = num_inference_steps or len(timesteps)
self.num_inference_steps = num_inference_steps
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
if timesteps is not None:
timesteps = np.array(timesteps).astype(np.int64)
else:
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
......
......@@ -167,6 +167,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
......@@ -189,6 +192,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
timestep_type: str = "discrete", # can be "discrete" or "continuous"
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
......@@ -296,7 +300,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = True
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
......@@ -305,60 +315,111 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
must be `None`, and `timestep_spacing` attribute will be ignored.
sigmas (`List[float]`, *optional*):
Custom sigmas used to support an arbitrary timestep schedule. If `None`, timesteps and sigmas
will be generated based on the relevant scheduler attributes. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the
custom sigmas schedule.
"""
self.num_inference_steps = num_inference_steps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
::-1
].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
else:
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` should be set.")
if num_inference_steps is None and timesteps is None and sigmas is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.")
if num_inference_steps is not None and (timesteps is not None or sigmas is not None):
raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
if timesteps is not None and self.config.use_karras_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
if (
timesteps is not None
and self.config.timestep_type == "continuous"
and self.config.prediction_type == "v_prediction"
):
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
"Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
if num_inference_steps is None:
num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
self.num_inference_steps = num_inference_steps
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
else:
raise ValueError(
f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
" 'linear' or 'log_linear'"
)
if sigmas is not None:
log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
sigmas = np.array(sigmas).astype(np.float32)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
if timesteps is not None:
timesteps = np.array(timesteps).astype(np.float32)
else:
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(
0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
(np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
else:
raise ValueError(
f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
" 'linear' or 'log_linear'"
)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
# TODO: Support the full EDM scalings for all prediction types and timestep types
if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
else:
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
......
......@@ -224,9 +224,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
timesteps: Optional[List[int]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
......@@ -236,30 +237,47 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
num_train_timesteps (`int`, *optional*):
The number of diffusion steps used when training the model. If `None`, the default
`num_train_timesteps` attribute is used.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
must be `None`, and `timestep_spacing` attribute will be ignored.
"""
if num_inference_steps is None and timesteps is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
if num_inference_steps is not None and timesteps is not None:
raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
if timesteps is not None and self.config.use_karras_sigmas:
raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
num_inference_steps = num_inference_steps or len(timesteps)
self.num_inference_steps = num_inference_steps
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
if timesteps is not None:
timesteps = np.array(timesteps, dtype=np.float32)
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
......
......@@ -48,6 +48,15 @@ class KarrasDiffusionSchedulers(Enum):
EDMEulerScheduler = 15
# Preset sampling schedules, keyed by "<model family>Timesteps" or "<model family>Sigmas".
# Each "...Timesteps" list holds 10 descending integer timesteps; each "...Sigmas" list holds
# 11 descending sigma values whose last entry is 0.0. The video entry only provides sigmas.
# The pipeline tests in this change pass these lists as the `timesteps` / `sigmas` call arguments.
# NOTE(review): "Ays" presumably stands for "Align Your Steps" -- confirm against the upstream docs.
AysSchedules = {
"StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24],
"StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0],
"StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13],
"StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0],
"StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0],
}
@dataclass
class SchedulerOutput(BaseOutput):
"""
......
......@@ -259,6 +259,44 @@ class StableDiffusionPipelineFastTests(
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_ays(self):
    """Compare a default 10-step run with runs driven by the preset `AysSchedules`.

    The timestep-based and sigma-based presets must yield (almost) the same
    images, and each must differ from the scheduler's default 10-step run.
    """
    from diffusers.schedulers import AysSchedules

    ays_timesteps = AysSchedules["StableDiffusionTimesteps"]
    ays_sigmas = AysSchedules["StableDiffusionSigmas"]

    device = "cpu"  # ensure determinism for the device-dependent torch.Generator

    sd_pipe = StableDiffusionPipeline(**self.get_dummy_components(time_cond_proj_dim=256))
    sd_pipe.scheduler = EulerDiscreteScheduler.from_config(sd_pipe.scheduler.config)
    sd_pipe = sd_pipe.to(torch_device)
    sd_pipe.set_progress_bar_config(disable=None)

    def generate(**overrides):
        # Every run starts from fresh dummy inputs, then applies its own overrides.
        call_kwargs = self.get_dummy_inputs(device)
        for key, value in overrides.items():
            call_kwargs[key] = value
        return sd_pipe(**call_kwargs).images

    def max_abs_diff(lhs, rhs):
        # Largest element-wise deviation between two image batches.
        return np.abs(lhs.flatten() - rhs.flatten()).max()

    output = generate(num_inference_steps=10)
    output_ts = generate(num_inference_steps=None, timesteps=ays_timesteps)
    output_sigmas = generate(num_inference_steps=None, sigmas=ays_sigmas)

    assert max_abs_diff(output_sigmas, output_ts) < 1e-3, "ays timesteps and ays sigmas should have the same outputs"
    assert max_abs_diff(output, output_ts) > 1e-3, "use ays timesteps should have different outputs"
    assert max_abs_diff(output, output_sigmas) > 1e-3, "use ays sigmas should have different outputs"
def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
......
......@@ -214,6 +214,44 @@ class StableDiffusionXLPipelineFastTests(
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_ays(self):
    """Compare a default 10-step SDXL run with runs driven by the preset `AysSchedules`.

    The timestep-based and sigma-based presets must yield (almost) the same
    images, and each must differ from the scheduler's default 10-step run.
    """
    from diffusers.schedulers import AysSchedules

    ays_timesteps = AysSchedules["StableDiffusionXLTimesteps"]
    ays_sigmas = AysSchedules["StableDiffusionXLSigmas"]

    device = "cpu"  # ensure determinism for the device-dependent torch.Generator

    sd_pipe = StableDiffusionXLPipeline(**self.get_dummy_components(time_cond_proj_dim=256))
    sd_pipe.scheduler = EulerDiscreteScheduler.from_config(sd_pipe.scheduler.config)
    sd_pipe = sd_pipe.to(torch_device)
    sd_pipe.set_progress_bar_config(disable=None)

    def generate(**overrides):
        # Every run starts from fresh dummy inputs, then applies its own overrides.
        call_kwargs = self.get_dummy_inputs(device)
        for key, value in overrides.items():
            call_kwargs[key] = value
        return sd_pipe(**call_kwargs).images

    def max_abs_diff(lhs, rhs):
        # Largest element-wise deviation between two image batches.
        return np.abs(lhs.flatten() - rhs.flatten()).max()

    output = generate(num_inference_steps=10)
    output_ts = generate(num_inference_steps=None, timesteps=ays_timesteps)
    output_sigmas = generate(num_inference_steps=None, sigmas=ays_sigmas)

    assert max_abs_diff(output_sigmas, output_ts) < 1e-3, "ays timesteps and ays sigmas should have the same outputs"
    assert max_abs_diff(output, output_ts) > 1e-3, "use ays timesteps should have different outputs"
    assert max_abs_diff(output, output_sigmas) > 1e-3, "use ays sigmas should have different outputs"
def test_stable_diffusion_xl_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components)
......
......@@ -111,9 +111,33 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
generator = torch.manual_seed(0)
for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
return sample
def full_loop_custom_timesteps(self, **config):
    """Run a full denoising loop with the timesteps passed in explicitly.

    A first scheduler derives its default 10-step schedule. A second,
    freshly built scheduler is then configured with that schedule through
    the `timesteps` argument and used for the loop.

    Args:
        **config: overrides forwarded to `self.get_scheduler_config`.

    Returns:
        The sample after stepping through every timestep.
    """
    scheduler_cls = self.scheduler_classes[0]
    scheduler_kwargs = self.get_scheduler_config(**config)

    # Derive the reference schedule from the default spacing.
    reference_scheduler = scheduler_cls(**scheduler_kwargs)
    reference_scheduler.set_timesteps(10)
    custom_timesteps = reference_scheduler.timesteps

    # reset the timesteps using `timesteps`
    scheduler = scheduler_cls(**scheduler_kwargs)
    scheduler.set_timesteps(num_inference_steps=None, timesteps=custom_timesteps)

    generator = torch.manual_seed(0)
    model = self.dummy_model()
    sample = self.dummy_sample_deter

    for t in scheduler.timesteps:
        residual = model(sample, t)
        step_output = scheduler.step(residual, t, sample, generator=generator)
        sample = step_output.prev_sample

    return sample
......@@ -309,10 +333,28 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert sample.dtype == torch.float16
def test_duplicated_timesteps(self, **config):
def test_duplicated_timesteps(self):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
assert len(scheduler.timesteps) == scheduler.num_inference_steps
def test_custom_timesteps(self):
    """Check that explicit `timesteps` reproduce the default-schedule result.

    Each combination of algorithm type, prediction type and final sigma
    type is run through `full_loop` and `full_loop_custom_timesteps`, and
    the two samples must agree to within 1e-5 in summed absolute difference.
    """
    for algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
        for prediction_type in ["epsilon", "sample", "v_prediction"]:
            for final_sigmas_type in ["sigma_min", "zero"]:
                # Same configuration for both loops so only the timestep source differs.
                loop_config = {
                    "algorithm_type": algorithm_type,
                    "prediction_type": prediction_type,
                    "final_sigmas_type": final_sigmas_type,
                }
                sample = self.full_loop(**loop_config)
                sample_custom_timesteps = self.full_loop_custom_timesteps(**loop_config)

                total_deviation = torch.sum(torch.abs(sample - sample_custom_timesteps))
                assert (
                    total_deviation < 1e-5
                ), f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
......@@ -103,14 +103,31 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
return sample
def full_loop_custom_timesteps(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
scheduler.set_timesteps(num_inference_steps)
timesteps = scheduler.timesteps
# reset the timesteps using`timesteps`
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps)
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
......@@ -307,3 +324,21 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 269.2187) < 1e-2, f" expected result sum 269.2187, but get {result_sum}"
assert abs(result_mean.item() - 0.3505) < 1e-3, f" expected result mean 0.3505, but get {result_mean}"
def test_custom_timesteps(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
for lower_order_final in [True, False]:
for final_sigmas_type in ["sigma_min", "zero"]:
sample = self.full_loop(
prediction_type=prediction_type,
lower_order_final=lower_order_final,
final_sigmas_type=final_sigmas_type,
)
sample_custom_timesteps = self.full_loop_custom_timesteps(
prediction_type=prediction_type,
lower_order_final=lower_order_final,
final_sigmas_type=final_sigmas_type,
)
assert (
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
......@@ -49,12 +49,13 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
for rescale_betas_zero_snr in [True, False]:
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
def test_full_loop_no_noise(self):
def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
num_inference_steps = self.num_inference_steps
scheduler.set_timesteps(num_inference_steps)
generator = torch.manual_seed(0)
......@@ -69,19 +70,46 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
output = scheduler.step(model_output, t, sample, generator=generator)
sample = output.prev_sample
return sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
def full_loop_custom_timesteps(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
assert abs(result_sum.item() - 10.0807) < 1e-2
assert abs(result_mean.item() - 0.0131) < 1e-3
num_inference_steps = self.num_inference_steps
scheduler.set_timesteps(num_inference_steps)
timesteps = scheduler.timesteps
# reset the timesteps using `timesteps`
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps)
def test_full_loop_with_v_prediction(self):
generator = torch.manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
for i, t in enumerate(scheduler.timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample, generator=generator)
sample = output.prev_sample
return sample
def full_loop_custom_sigmas(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
num_inference_steps = self.num_inference_steps
scheduler.set_timesteps(num_inference_steps)
sigmas = scheduler.sigmas
# reset the timesteps using `sigmas`
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps=None, sigmas=sigmas)
generator = torch.manual_seed(0)
......@@ -96,6 +124,19 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
output = scheduler.step(model_output, t, sample, generator=generator)
sample = output.prev_sample
return sample
def test_full_loop_no_noise(self):
sample = self.full_loop()
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 10.0807) < 1e-2
assert abs(result_mean.item() - 0.0131) < 1e-3
def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
......@@ -189,3 +230,36 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 57062.9297) < 1e-2, f" expected result sum 57062.9297, but get {result_sum}"
assert abs(result_mean.item() - 74.3007) < 1e-3, f" expected result mean 74.3007, but get {result_mean}"
def test_custom_timesteps(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
for interpolation_type in ["linear", "log_linear"]:
for final_sigmas_type in ["sigma_min", "zero"]:
sample = self.full_loop(
prediction_type=prediction_type,
interpolation_type=interpolation_type,
final_sigmas_type=final_sigmas_type,
)
sample_custom_timesteps = self.full_loop_custom_timesteps(
prediction_type=prediction_type,
interpolation_type=interpolation_type,
final_sigmas_type=final_sigmas_type,
)
assert (
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
def test_custom_sigmas(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
for final_sigmas_type in ["sigma_min", "zero"]:
sample = self.full_loop(
prediction_type=prediction_type,
final_sigmas_type=final_sigmas_type,
)
sample_custom_timesteps = self.full_loop_custom_sigmas(
prediction_type=prediction_type,
final_sigmas_type=final_sigmas_type,
)
assert (
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
), f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
......@@ -41,12 +41,13 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
for prediction_type in ["epsilon", "v_prediction", "sample"]:
self.check_over_configs(prediction_type=prediction_type)
def test_full_loop_no_noise(self):
def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
num_inference_steps = self.num_inference_steps
scheduler.set_timesteps(num_inference_steps)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
......@@ -60,23 +61,20 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
if torch_device in ["cpu", "mps"]:
assert abs(result_sum.item() - 0.1233) < 1e-2
assert abs(result_mean.item() - 0.0002) < 1e-3
else:
# CUDA
assert abs(result_sum.item() - 0.1233) < 1e-2
assert abs(result_mean.item() - 0.0002) < 1e-3
return sample
def test_full_loop_with_v_prediction(self):
def full_loop_custom_timesteps(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
num_inference_steps = self.num_inference_steps
scheduler.set_timesteps(num_inference_steps)
timesteps = scheduler.timesteps
timesteps = torch.cat([timesteps[:1], timesteps[1::2]])
# reset the timesteps using `timesteps`
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
......@@ -90,6 +88,23 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
return sample
def test_full_loop_no_noise(self):
sample = self.full_loop()
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
if torch_device in ["cpu", "mps"]:
assert abs(result_sum.item() - 0.1233) < 1e-2
assert abs(result_mean.item() - 0.0002) < 1e-3
else:
# CUDA
assert abs(result_sum.item() - 0.1233) < 1e-2
assert abs(result_mean.item() - 0.0002) < 1e-3
def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
......@@ -189,3 +204,18 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 75074.8906) < 1e-2, f" expected result sum 75074.8906, but get {result_sum}"
assert abs(result_mean.item() - 97.7538) < 1e-3, f" expected result mean 97.7538, but get {result_mean}"
def test_custom_timesteps(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
for timestep_spacing in ["linspace", "leading"]:
sample = self.full_loop(
prediction_type=prediction_type,
timestep_spacing=timestep_spacing,
)
sample_custom_timesteps = self.full_loop_custom_timesteps(
prediction_type=prediction_type,
timestep_spacing=timestep_spacing,
)
assert (
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment