Unverified Commit b934215d authored by YiYi Xu, committed by GitHub

[scheduler] support custom `timesteps` and `sigmas` (#7817)



* support custom `sigmas` and `timesteps` for the DPM and Euler schedulers

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 5ed3abd3
@@ -140,6 +140,7 @@ def retrieve_timesteps(
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
     """
@@ -155,14 +156,18 @@ def retrieve_timesteps(
         device (`str` or `torch.device`, *optional*):
             The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
         timesteps (`List[int]`, *optional*):
-            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-            must be `None`.
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
+            passed, `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
         the second element is the number of inference steps.
     """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values.")
     if timesteps is not None:
         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accepts_timesteps:
@@ -173,6 +178,16 @@ def retrieve_timesteps(
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigma schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
     else:
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
@@ -841,6 +856,7 @@ class StableDiffusionXLAdapterPipeline(
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
+        sigmas: List[float] = None,
         denoising_end: Optional[float] = None,
         guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -900,6 +916,10 @@ class StableDiffusionXLAdapterPipeline(
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used.
             denoising_end (`float`, *optional*):
                 When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                 completed before it is intentionally prematurely terminated. As a result, the returned sample will
@@ -1101,7 +1121,9 @@ class StableDiffusionXLAdapterPipeline(
         )
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps, sigmas
+        )
         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
......
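With `sigmas` threaded through the pipeline into `retrieve_timesteps`, a caller can hand the adapter pipeline a complete sigma schedule instead of a step count. A minimal usage sketch, not taken from the PR itself: the model ids are illustrative, and the blank control image is a stand-in for a real canny map.

import torch
from PIL import Image
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerDiscreteScheduler

# Model ids are illustrative; any SDXL checkpoint/adapter pair should work.
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

# Eleven sigmas (with a trailing 0.0) describe ten denoising steps.
# num_inference_steps stays None so the custom schedule wins; exactly one of
# num_inference_steps, timesteps, or sigmas may be set per call.
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
control = Image.new("RGB", (1024, 1024))  # placeholder control image
image = pipe(
    "a sketch of a lighthouse", image=control, num_inference_steps=None, sigmas=sigmas
).images[0]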
@@ -68,7 +68,7 @@ else:
     _import_structure["scheduling_tcd"] = ["TCDScheduler"]
     _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
     _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
-    _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
+    _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"]
     _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"]
 try:
@@ -163,7 +163,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .scheduling_tcd import TCDScheduler
     from .scheduling_unclip import UnCLIPScheduler
     from .scheduling_unipc_multistep import UniPCMultistepScheduler
-    from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+    from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler
 try:
......
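With the re-export above, the schedule table becomes importable from the public schedulers namespace; the new tests further down rely on exactly this import. A quick sanity check:

from diffusers.schedulers import AysSchedules

# The five named schedules defined in scheduling_utils (see the hunk below).
print(sorted(AysSchedules.keys()))
# ['StableDiffusionSigmas', 'StableDiffusionTimesteps', 'StableDiffusionVideoSigmas',
#  'StableDiffusionXLSigmas', 'StableDiffusionXLTimesteps']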
@@ -303,7 +303,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
-    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -312,7 +317,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support an arbitrary timestep schedule. If `None`, timesteps will be
+                generated based on the `timestep_spacing` attribute. If `timesteps` is passed,
+                `num_inference_steps` must be `None`, and the `timestep_spacing` attribute will be ignored.
         """
+        if num_inference_steps is None and timesteps is None:
+            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
+        if timesteps is not None and self.config.use_karras_sigmas:
+            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
+        if timesteps is not None and self.config.use_lu_lambdas:
+            raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
+
+        if timesteps is not None:
+            timesteps = np.array(timesteps).astype(np.int64)
+        else:
             # Clipping the minimum of all lambda(t) for numerical stability.
             # This is critical for cosine (squaredcos_cap_v2) noise schedule.
             clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
@@ -321,13 +342,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
            if self.config.timestep_spacing == "linspace":
-                timesteps = (
-                    np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
-                )
+                timesteps = (
+                    np.linspace(0, last_timestep - 1, num_inference_steps + 1)
+                    .round()[::-1][:-1]
+                    .copy()
+                    .astype(np.int64)
+                )
            elif self.config.timestep_spacing == "leading":
                step_ratio = last_timestep // (num_inference_steps + 1)
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
-                timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
+                timesteps = (
+                    (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
+                )
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = self.config.num_train_timesteps / num_inference_steps
......
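A short sketch of exercising the new branch in isolation; the timestep values are arbitrary descending integers chosen only for illustration, not a recommended schedule.

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(num_train_timesteps=1000)
# Custom schedule: mutually exclusive with num_inference_steps, and
# incompatible with use_karras_sigmas / use_lu_lambdas per the checks above.
scheduler.set_timesteps(timesteps=[999, 850, 736, 645, 545, 455, 343, 233, 124, 24])
print(scheduler.timesteps)    # the custom schedule as a tensor
print(len(scheduler.sigmas))  # one trailing sigma more than there are timesteps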
@@ -274,7 +274,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -283,8 +288,24 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, the default
+                timestep spacing strategy of the scheduler is used. If `timesteps` is passed,
+                `num_inference_steps` must be `None`.
         """
+        if num_inference_steps is None and timesteps is None:
+            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
+        if timesteps is not None and self.config.use_karras_sigmas:
+            raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")
+
+        num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
+        if timesteps is not None:
+            timesteps = np.array(timesteps).astype(np.int64)
+        else:
             # Clipping the minimum of all lambda(t) for numerical stability.
             # This is critical for cosine (squaredcos_cap_v2) noise schedule.
             clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
......
@@ -167,6 +167,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright
             and dark samples instead of limiting it to samples with medium brightness. Loosely related to
             [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
     """
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -189,6 +192,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep_type: str = "discrete",  # can be "discrete" or "continuous"
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
+        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -296,7 +300,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = True
         return sample
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+        sigmas: Optional[List[float]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -305,25 +315,67 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support an arbitrary timestep schedule. If `None`, timesteps will be
+                generated based on the `timestep_spacing` attribute. If `timesteps` is passed,
+                `num_inference_steps` and `sigmas` must be `None`, and the `timestep_spacing` attribute will be
+                ignored.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas used to support an arbitrary timestep schedule. If `None`, timesteps and sigmas will
+                be generated based on the relevant scheduler attributes. If `sigmas` is passed,
+                `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be derived from the
+                custom sigma schedule.
         """
+        if timesteps is not None and sigmas is not None:
+            raise ValueError("Only one of `timesteps` or `sigmas` should be set.")
+        if num_inference_steps is None and timesteps is None and sigmas is None:
+            raise ValueError("Must pass exactly one of `num_inference_steps`, `timesteps`, or `sigmas`.")
+        if num_inference_steps is not None and (timesteps is not None or sigmas is not None):
+            raise ValueError("Can only pass one of `num_inference_steps`, `timesteps`, or `sigmas`.")
+        if timesteps is not None and self.config.use_karras_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
+        if (
+            timesteps is not None
+            and self.config.timestep_type == "continuous"
+            and self.config.prediction_type == "v_prediction"
+        ):
+            raise ValueError(
+                "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and"
+                " `config.prediction_type = 'v_prediction'`."
+            )
+
+        if num_inference_steps is None:
+            num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
         self.num_inference_steps = num_inference_steps
+
+        if sigmas is not None:
+            log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
+            sigmas = np.array(sigmas).astype(np.float32)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
+        else:
+            if timesteps is not None:
+                timesteps = np.array(timesteps).astype(np.float32)
+            else:
                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
                if self.config.timestep_spacing == "linspace":
-                    timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
-                        ::-1
-                    ].copy()
+                    timesteps = np.linspace(
+                        0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
+                    )[::-1].copy()
                elif self.config.timestep_spacing == "leading":
                    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
                    # creates integer timesteps by multiplying by ratio
                    # casting to int to avoid issues when num_inference_step is power of 3
-                    timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+                    timesteps = (
+                        (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+                    )
                    timesteps += self.config.steps_offset
                elif self.config.timestep_spacing == "trailing":
                    step_ratio = self.config.num_train_timesteps / self.num_inference_steps
                    # creates integer timesteps by multiplying by ratio
                    # casting to int to avoid issues when num_inference_step is power of 3
-                    timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+                    timesteps = (
+                        (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+                    )
                    timesteps -= 1
                else:
                    raise ValueError(
@@ -332,7 +384,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            log_sigmas = np.log(sigmas)
            if self.config.interpolation_type == "linear":
                sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            elif self.config.interpolation_type == "log_linear":
@@ -347,18 +398,28 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero' or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
         # TODO: Support the full EDM scalings for all prediction types and timestep types
         if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
-            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
+            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
         else:
             self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
         self._step_index = None
         self._begin_index = None
-        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
     def _sigma_to_t(self, sigma, log_sigmas):
         # get log sigma
......
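Since custom sigmas fully determine the schedule here, the matching timesteps are recovered through `_sigma_to_t`, and no extra final sigma is appended to a user-provided list. A minimal sketch; the sigma values are the SD1.5 AYS schedule added elsewhere in this commit, used purely as an example:

from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(num_train_timesteps=1000)
sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0]
scheduler.set_timesteps(sigmas=sigmas)

# Eleven sigmas (ending in 0.0) yield ten steps/timesteps.
assert len(scheduler.timesteps) == len(sigmas) - 1
print(scheduler.sigmas)     # the custom sigmas, as a float32 tensor
print(scheduler.timesteps)  # non-integer timesteps derived from the sigmas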
@@ -224,9 +224,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def set_timesteps(
         self,
-        num_inference_steps: int,
+        num_inference_steps: Optional[int] = None,
         device: Union[str, torch.device] = None,
         num_train_timesteps: Optional[int] = None,
+        timesteps: Optional[List[int]] = None,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -236,11 +237,28 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+            num_train_timesteps (`int`, *optional*):
+                The number of diffusion steps used when training the model. If `None`, the default
+                `num_train_timesteps` attribute is used.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
+                generated based on the `timestep_spacing` attribute. If `timesteps` is passed,
+                `num_inference_steps` must be `None`, and the `timestep_spacing` attribute will be ignored.
         """
+        if num_inference_steps is None and timesteps is None:
+            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
+        if timesteps is not None and self.config.use_karras_sigmas:
+            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
+
+        num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
         num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
+        if timesteps is not None:
+            timesteps = np.array(timesteps, dtype=np.float32)
+        else:
             # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
             if self.config.timestep_spacing == "linspace":
                 timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
......
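One Heun-specific wrinkle, which the test changes further down also work around: the scheduler internally repeats interior timesteps for its second-order correction step, so a custom list should contain each timestep once. A sketch with illustrative values:

from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(timesteps=[999, 800, 600, 400, 200, 0])
print(scheduler.timesteps)
# The first entry appears once and later entries are duplicated for the
# correction step, e.g. 999, 800, 800, 600, 600, 400, 400, ...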
@@ -48,6 +48,15 @@ class KarrasDiffusionSchedulers(Enum):
     EDMEulerScheduler = 15
+AysSchedules = {
+    "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24],
+    "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0],
+    "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13],
+    "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0],
+    "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0],
+}
+
 @dataclass
 class SchedulerOutput(BaseOutput):
     """
......
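These entries presumably encode the 10-step "Align Your Steps" schedules (ten timesteps, or eleven sigmas ending in 0.0) for SD1.5, SDXL, and SVD. A hedged usage sketch with an SDXL pipeline; the checkpoint id is illustrative:

import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from diffusers.schedulers import AysSchedules

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

# num_inference_steps=None defers the step count to the custom schedule.
sigmas = AysSchedules["StableDiffusionXLSigmas"]
image = pipe("a cinematic photo of a lighthouse", num_inference_steps=None, sigmas=sigmas).images[0]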
@@ -259,6 +259,44 @@ class StableDiffusionPipelineFastTests(
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+    def test_stable_diffusion_ays(self):
+        from diffusers.schedulers import AysSchedules
+
+        timestep_schedule = AysSchedules["StableDiffusionTimesteps"]
+        sigma_schedule = AysSchedules["StableDiffusionSigmas"]
+
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionPipeline(**components)
+        sd_pipe.scheduler = EulerDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 10
+        output = sd_pipe(**inputs).images
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = None
+        inputs["timesteps"] = timestep_schedule
+        output_ts = sd_pipe(**inputs).images
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = None
+        inputs["sigmas"] = sigma_schedule
+        output_sigmas = sd_pipe(**inputs).images
+
+        assert (
+            np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
+        ), "AYS timesteps and AYS sigmas should produce the same outputs"
+        assert (
+            np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
+        ), "AYS timesteps should produce outputs different from the default schedule"
+        assert (
+            np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
+        ), "AYS sigmas should produce outputs different from the default schedule"
     def test_stable_diffusion_prompt_embeds(self):
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionPipeline(**components)
......
@@ -214,6 +214,44 @@ class StableDiffusionXLPipelineFastTests(
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+    def test_stable_diffusion_ays(self):
+        from diffusers.schedulers import AysSchedules
+
+        timestep_schedule = AysSchedules["StableDiffusionXLTimesteps"]
+        sigma_schedule = AysSchedules["StableDiffusionXLSigmas"]
+
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe.scheduler = EulerDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 10
+        output = sd_pipe(**inputs).images
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = None
+        inputs["timesteps"] = timestep_schedule
+        output_ts = sd_pipe(**inputs).images
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = None
+        inputs["sigmas"] = sigma_schedule
+        output_sigmas = sd_pipe(**inputs).images
+
+        assert (
+            np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
+        ), "AYS timesteps and AYS sigmas should produce the same outputs"
+        assert (
+            np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
+        ), "AYS timesteps should produce outputs different from the default schedule"
+        assert (
+            np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
+        ), "AYS sigmas should produce outputs different from the default schedule"
     def test_stable_diffusion_xl_prompt_embeds(self):
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**components)
......
@@ -111,9 +111,33 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         sample = self.dummy_sample_deter
         scheduler.set_timesteps(num_inference_steps)
+        generator = torch.manual_seed(0)
         for i, t in enumerate(scheduler.timesteps):
             residual = model(sample, t)
-            sample = scheduler.step(residual, t, sample).prev_sample
+            sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
         return sample
+
+    def full_loop_custom_timesteps(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+        num_inference_steps = 10
+        scheduler.set_timesteps(num_inference_steps)
+        timesteps = scheduler.timesteps
+        # reset the timesteps using `timesteps`
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps)
+        generator = torch.manual_seed(0)
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
+        return sample
@@ -309,10 +333,28 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         assert sample.dtype == torch.float16
-    def test_duplicated_timesteps(self, **config):
+    def test_duplicated_timesteps(self):
         for scheduler_class in self.scheduler_classes:
-            scheduler_config = self.get_scheduler_config(**config)
+            scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
             scheduler.set_timesteps(scheduler.config.num_train_timesteps)
             assert len(scheduler.timesteps) == scheduler.num_inference_steps
+
+    def test_custom_timesteps(self):
+        for algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
+            for prediction_type in ["epsilon", "sample", "v_prediction"]:
+                for final_sigmas_type in ["sigma_min", "zero"]:
+                    sample = self.full_loop(
+                        algorithm_type=algorithm_type,
+                        prediction_type=prediction_type,
+                        final_sigmas_type=final_sigmas_type,
+                    )
+                    sample_custom_timesteps = self.full_loop_custom_timesteps(
+                        algorithm_type=algorithm_type,
+                        prediction_type=prediction_type,
+                        final_sigmas_type=final_sigmas_type,
+                    )
+                    assert (
+                        torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
+                    ), f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
@@ -103,14 +103,31 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
         scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)
         num_inference_steps = 10
         model = self.dummy_model()
         sample = self.dummy_sample_deter
         scheduler.set_timesteps(num_inference_steps)
         for i, t in enumerate(scheduler.timesteps):
             residual = model(sample, t)
             sample = scheduler.step(residual, t, sample).prev_sample
         return sample
+
+    def full_loop_custom_timesteps(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+        num_inference_steps = 10
+        scheduler.set_timesteps(num_inference_steps)
+        timesteps = scheduler.timesteps
+        # reset the timesteps using `timesteps`
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps)
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+        return sample
@@ -307,3 +324,21 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 269.2187) < 1e-2, f"expected result sum 269.2187, but got {result_sum}"
         assert abs(result_mean.item() - 0.3505) < 1e-3, f"expected result mean 0.3505, but got {result_mean}"
+
+    def test_custom_timesteps(self):
+        for prediction_type in ["epsilon", "sample", "v_prediction"]:
+            for lower_order_final in [True, False]:
+                for final_sigmas_type in ["sigma_min", "zero"]:
+                    sample = self.full_loop(
+                        prediction_type=prediction_type,
+                        lower_order_final=lower_order_final,
+                        final_sigmas_type=final_sigmas_type,
+                    )
+                    sample_custom_timesteps = self.full_loop_custom_timesteps(
+                        prediction_type=prediction_type,
+                        lower_order_final=lower_order_final,
+                        final_sigmas_type=final_sigmas_type,
+                    )
+                    assert (
+                        torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
+                    ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
@@ -49,12 +49,13 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
         for rescale_betas_zero_snr in [True, False]:
             self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
-    def test_full_loop_no_noise(self):
+    def full_loop(self, **config):
         scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config()
+        scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)
-        scheduler.set_timesteps(self.num_inference_steps)
+        num_inference_steps = self.num_inference_steps
+        scheduler.set_timesteps(num_inference_steps)
         generator = torch.manual_seed(0)
@@ -69,19 +70,46 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
             output = scheduler.step(model_output, t, sample, generator=generator)
             sample = output.prev_sample
+        return sample
-        result_sum = torch.sum(torch.abs(sample))
-        result_mean = torch.mean(torch.abs(sample))
-
-        assert abs(result_sum.item() - 10.0807) < 1e-2
-        assert abs(result_mean.item() - 0.0131) < 1e-3
-
-    def test_full_loop_with_v_prediction(self):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
-        scheduler = scheduler_class(**scheduler_config)
-        scheduler.set_timesteps(self.num_inference_steps)
+    def full_loop_custom_timesteps(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = self.num_inference_steps
+        scheduler.set_timesteps(num_inference_steps)
+        timesteps = scheduler.timesteps
+        # reset the timesteps using `timesteps`
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps)
+
+        generator = torch.manual_seed(0)
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+            model_output = model(sample, t)
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+        return sample
+
+    def full_loop_custom_sigmas(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = self.num_inference_steps
+        scheduler.set_timesteps(num_inference_steps)
+        sigmas = scheduler.sigmas
+        # reset the timesteps using `sigmas`
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler.set_timesteps(num_inference_steps=None, sigmas=sigmas)
         generator = torch.manual_seed(0)
@@ -96,6 +124,19 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
             output = scheduler.step(model_output, t, sample, generator=generator)
             sample = output.prev_sample
+        return sample
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+        assert abs(result_sum.item() - 10.0807) < 1e-2
+        assert abs(result_mean.item() - 0.0131) < 1e-3
+
+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
@@ -189,3 +230,36 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 57062.9297) < 1e-2, f"expected result sum 57062.9297, but got {result_sum}"
         assert abs(result_mean.item() - 74.3007) < 1e-3, f"expected result mean 74.3007, but got {result_mean}"
+
+    def test_custom_timesteps(self):
+        for prediction_type in ["epsilon", "sample", "v_prediction"]:
+            for interpolation_type in ["linear", "log_linear"]:
+                for final_sigmas_type in ["sigma_min", "zero"]:
+                    sample = self.full_loop(
+                        prediction_type=prediction_type,
+                        interpolation_type=interpolation_type,
+                        final_sigmas_type=final_sigmas_type,
+                    )
+                    sample_custom_timesteps = self.full_loop_custom_timesteps(
+                        prediction_type=prediction_type,
+                        interpolation_type=interpolation_type,
+                        final_sigmas_type=final_sigmas_type,
+                    )
+                    assert (
+                        torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
+                    ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
+
+    def test_custom_sigmas(self):
+        for prediction_type in ["epsilon", "sample", "v_prediction"]:
+            for final_sigmas_type in ["sigma_min", "zero"]:
+                sample = self.full_loop(
+                    prediction_type=prediction_type,
+                    final_sigmas_type=final_sigmas_type,
+                )
+                sample_custom_sigmas = self.full_loop_custom_sigmas(
+                    prediction_type=prediction_type,
+                    final_sigmas_type=final_sigmas_type,
+                )
+                assert (
+                    torch.sum(torch.abs(sample - sample_custom_sigmas)) < 1e-5
+                ), f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
@@ -41,12 +41,13 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
         for prediction_type in ["epsilon", "v_prediction", "sample"]:
             self.check_over_configs(prediction_type=prediction_type)
-    def test_full_loop_no_noise(self):
+    def full_loop(self, **config):
         scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config()
+        scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)
-        scheduler.set_timesteps(self.num_inference_steps)
+        num_inference_steps = self.num_inference_steps
+        scheduler.set_timesteps(num_inference_steps)
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -60,23 +61,20 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
             output = scheduler.step(model_output, t, sample)
             sample = output.prev_sample
-        result_sum = torch.sum(torch.abs(sample))
-        result_mean = torch.mean(torch.abs(sample))
-
-        if torch_device in ["cpu", "mps"]:
-            assert abs(result_sum.item() - 0.1233) < 1e-2
-            assert abs(result_mean.item() - 0.0002) < 1e-3
-        else:
-            # CUDA
-            assert abs(result_sum.item() - 0.1233) < 1e-2
-            assert abs(result_mean.item() - 0.0002) < 1e-3
+        return sample
-    def test_full_loop_with_v_prediction(self):
+    def full_loop_custom_timesteps(self, **config):
         scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
+        scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)
-        scheduler.set_timesteps(self.num_inference_steps)
+        num_inference_steps = self.num_inference_steps
+        scheduler.set_timesteps(num_inference_steps)
+        timesteps = scheduler.timesteps
+        # Heun repeats interior timesteps for its second-order correction, so
+        # de-duplicate them before passing the list back in as a custom schedule
+        timesteps = torch.cat([timesteps[:1], timesteps[1::2]])
+        # reset the timesteps using `timesteps`
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps)
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -90,6 +88,23 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
             output = scheduler.step(model_output, t, sample)
             sample = output.prev_sample
+        return sample
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+        if torch_device in ["cpu", "mps"]:
+            assert abs(result_sum.item() - 0.1233) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
+        else:
+            # CUDA
+            assert abs(result_sum.item() - 0.1233) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
+
+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
@@ -189,3 +204,18 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 75074.8906) < 1e-2, f"expected result sum 75074.8906, but got {result_sum}"
         assert abs(result_mean.item() - 97.7538) < 1e-3, f"expected result mean 97.7538, but got {result_mean}"
+
+    def test_custom_timesteps(self):
+        for prediction_type in ["epsilon", "sample", "v_prediction"]:
+            for timestep_spacing in ["linspace", "leading"]:
+                sample = self.full_loop(
+                    prediction_type=prediction_type,
+                    timestep_spacing=timestep_spacing,
+                )
+                sample_custom_timesteps = self.full_loop_custom_timesteps(
+                    prediction_type=prediction_type,
+                    timestep_spacing=timestep_spacing,
+                )
+                assert (
+                    torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
+                ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"