Commit 256e0106 authored by David El Malih, committed by GitHub

Improve docstrings and type hints in scheduling_deis_multistep.py (#12796)

* feat: Add `flow_prediction` to `prediction_type`, introduce `use_flow_sigmas`, `flow_shift`, `use_dynamic_shifting`, and `time_shift_type` parameters, and refine type hints for various arguments.

* style: reformat argument wrapping in `_convert_to_beta` and `index_for_timestep` method signatures.
parent 8430ac2a
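
A note on the `Literal` hints this commit introduces: they are static annotations only, so Python itself does not reject an unlisted string at call time. The following is a minimal sketch of how the allowed values could be read back from such a hint and checked explicitly. The alias `BetaSchedule` and the helpers `validate_choice` and `make_config` are hypothetical names used for illustration and are not part of the scheduler.

```python
from typing import Literal, get_args

# Mirrors the annotation added in the diff; the alias name is hypothetical.
BetaSchedule = Literal["linear", "scaled_linear", "squaredcos_cap_v2"]


def validate_choice(value: str, literal_type) -> str:
    """Hypothetical helper: raise if `value` is not one of the Literal's options."""
    allowed = get_args(literal_type)
    if value not in allowed:
        raise ValueError(f"{value!r} is not one of {allowed}")
    return value


def make_config(beta_schedule: BetaSchedule = "linear") -> dict:
    # Literal is only a static hint; nothing checks it at call time,
    # so an explicit check is needed if runtime validation is wanted.
    return {"beta_schedule": validate_choice(beta_schedule, BetaSchedule)}
```

A static type checker would flag `make_config("cosine")` from the annotation alone, while the explicit check above is what raises at runtime in this sketch.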
@@ -84,33 +84,35 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving. methods the library implements for all schedulers such as loading and saving.
Args: Args:
num_train_timesteps (`int`, defaults to 1000): num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model. The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001): beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference. The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02): beta_end (`float`, defaults to `0.02`):
The final `beta` value. The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`): beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`. `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*): trained_betas (`np.ndarray` or `List[float]`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
solver_order (`int`, defaults to 2): solver_order (`int`, defaults to `2`):
The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling. sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, defaults to `epsilon`): prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper). Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
thresholding (`bool`, defaults to `False`): thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion. as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995): dynamic_thresholding_ratio (`float`, defaults to `0.995`):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0): sample_max_value (`float`, defaults to `1.0`):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`. The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
algorithm_type (`str`, defaults to `deis`): algorithm_type (`"deis"`, defaults to `"deis"`):
The algorithm type for the solver. The algorithm type for the solver.
solver_type (`"logrho"`, defaults to `"logrho"`):
Solver type for DEIS.
lower_order_final (`bool`, defaults to `True`): lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps.
use_karras_sigmas (`bool`, *optional*, defaults to `False`): use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -121,11 +123,19 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`): use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
timestep_spacing (`str`, defaults to `"linspace"`): use_flow_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
flow_shift (`float`, *optional*, defaults to `1.0`):
The flow shift parameter for flow-based models.
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0): steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families. An offset added to the inference steps, as required by some model families.
use_dynamic_shifting (`bool`, defaults to `False`):
Whether to use dynamic shifting for the noise schedule.
time_shift_type (`"exponential"`, defaults to `"exponential"`):
The type of time shifting to apply.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -137,29 +147,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000, num_train_timesteps: int = 1000,
beta_start: float = 0.0001, beta_start: float = 0.0001,
beta_end: float = 0.02, beta_end: float = 0.02,
beta_schedule: str = "linear", beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[np.ndarray] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2, solver_order: int = 2,
prediction_type: str = "epsilon", prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
thresholding: bool = False, thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995, dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0, sample_max_value: float = 1.0,
algorithm_type: str = "deis", algorithm_type: Literal["deis"] = "deis",
solver_type: str = "logrho", solver_type: Literal["logrho"] = "logrho",
lower_order_final: bool = True, lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False, use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0, flow_shift: Optional[float] = 1.0,
timestep_spacing: str = "linspace", timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0, steps_offset: int = 0,
use_dynamic_shifting: bool = False, use_dynamic_shifting: bool = False,
time_shift_type: str = "exponential", time_shift_type: Literal["exponential"] = "exponential",
): ) -> None:
if self.config.use_beta_sigmas and not is_scipy_available(): if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.") raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError( raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
) )
@@ -169,7 +188,15 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear": elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model. # this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2": elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule # Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps) self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -211,21 +238,21 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property @property
def step_index(self): def step_index(self) -> Optional[int]:
""" """
The index counter for current timestep. It will increase 1 after each scheduler step. The index counter for current timestep. It will increase 1 after each scheduler step.
""" """
return self._step_index return self._step_index
@property @property
def begin_index(self): def begin_index(self) -> Optional[int]:
""" """
The index for the first timestep. It should be set from pipeline with `set_begin_index` method. The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
""" """
return self._begin_index return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0): def set_begin_index(self, begin_index: int = 0) -> None:
""" """
Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -236,8 +263,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index self._begin_index = begin_index
def set_timesteps( def set_timesteps(
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None self,
): num_inference_steps: int,
device: Union[str, torch.device] = None,
mu: Optional[float] = None,
) -> None:
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -246,6 +276,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model. The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*): device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
mu (`float`, *optional*):
The mu parameter for dynamic shifting. Only used when `use_dynamic_shifting=True` and
`time_shift_type="exponential"`.
""" """
if mu is not None: if mu is not None:
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
@@ -363,7 +396,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return sample return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas): def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
""" """
Convert sigma values to corresponding timestep values through interpolation. Convert sigma values to corresponding timestep values through interpolation.
@@ -400,7 +433,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return t return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma): def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Convert sigma values to alpha_t and sigma_t values. Convert sigma values to alpha_t and sigma_t values.
@@ -422,7 +455,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
""" """
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364). Models](https://huggingface.co/papers/2206.00364).
@@ -648,7 +681,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
) )
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -714,7 +750,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
m0, m1 = model_output_list[-1], model_output_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2]
rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 rho_t, rho_s0, rho_s1 = (
sigma_t / alpha_t,
sigma_s0 / alpha_s0,
sigma_s1 / alpha_s1,
)
if self.config.algorithm_type == "deis": if self.config.algorithm_type == "deis":
@@ -854,7 +894,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return step_index return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep): def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
""" """
Initialize the step_index counter for the scheduler. Initialize the step_index counter for the scheduler.
@@ -884,18 +924,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.Tensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns: Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor. tuple is returned where the first element is the sample tensor.
""" """
if self.num_inference_steps is None: if self.num_inference_steps is None:
raise ValueError( raise ValueError(
@@ -1000,5 +1039,5 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = alpha_t * original_samples + sigma_t * noise noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples return noisy_samples
def __len__(self): def __len__(self) -> int:
return self.config.num_train_timesteps return self.config.num_train_timesteps
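
For readers comparing the `"linear"` and `"scaled_linear"` options of `beta_schedule` touched in the diff above, here is a rough pure-Python sketch of the two formulas. It assumes the same arithmetic as the `torch.linspace` expressions in the diff but uses a hand-written `linspace` helper instead of torch. The function names are illustrative, not the scheduler's own.

```python
from typing import List


def linspace(start: float, stop: float, num: int) -> List[float]:
    """Evenly spaced values from start to stop inclusive (stand-in for torch.linspace)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def linear_betas(beta_start: float = 0.0001, beta_end: float = 0.02, num_train_timesteps: int = 1000) -> List[float]:
    # "linear": betas are spaced evenly between beta_start and beta_end.
    return linspace(beta_start, beta_end, num_train_timesteps)


def scaled_linear_betas(beta_start: float = 0.0001, beta_end: float = 0.02, num_train_timesteps: int = 1000) -> List[float]:
    # "scaled_linear": space the square roots evenly, then square the result.
    return [b ** 2 for b in linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps)]


linear = linear_betas()
scaled = scaled_linear_betas()
```

Both schedules share the same endpoints. In between, the scaled-linear betas sit below the linear ones because squaring an evenly spaced square-root sequence grows more slowly at first.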