Unverified Commit 2842c14c authored by David El Malih's avatar David El Malih Committed by GitHub
Browse files

Improve docstrings and type hints in scheduling_unipc_multistep.py (#12767)

refactor: add type hints and update docstrings for UniPCMultistepScheduler parameters and methods.
parent c3186860
...@@ -77,7 +77,7 @@ def betas_for_alpha_bar( ...@@ -77,7 +77,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas): def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
""" """
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
...@@ -127,19 +127,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -127,19 +127,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference. The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02): beta_end (`float`, defaults to 0.02):
The final `beta` value. The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`): beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`. `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*): trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
solver_order (`int`, default `2`): solver_order (`int`, defaults to `2`):
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
unconditional sampling. unconditional sampling.
prediction_type (`str`, defaults to `epsilon`, *optional*): prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper). Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
thresholding (`bool`, defaults to `False`): thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion. as Stable Diffusion.
...@@ -149,7 +149,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -149,7 +149,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
predict_x0 (`bool`, defaults to `True`): predict_x0 (`bool`, defaults to `True`):
Whether to use the updating algorithm on the predicted x0. Whether to use the updating algorithm on the predicted x0.
solver_type (`str`, default `bh2`): solver_type (`"bh1"` or `"bh2"`, defaults to `"bh2"`):
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
otherwise. otherwise.
lower_order_final (`bool`, default `True`): lower_order_final (`bool`, default `True`):
...@@ -171,12 +171,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -171,12 +171,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
use_flow_sigmas (`bool`, *optional*, defaults to `False`): use_flow_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`): timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0): steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families. An offset added to the inference steps, as required by some model families.
final_sigmas_type (`str`, defaults to `"zero"`): final_sigmas_type (`"zero"` or `"sigma_min"`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
rescale_betas_zero_snr (`bool`, defaults to `False`): rescale_betas_zero_snr (`bool`, defaults to `False`):
...@@ -194,30 +194,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -194,30 +194,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000, num_train_timesteps: int = 1000,
beta_start: float = 0.0001, beta_start: float = 0.0001,
beta_end: float = 0.02, beta_end: float = 0.02,
beta_schedule: str = "linear", beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2, solver_order: int = 2,
prediction_type: str = "epsilon", prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
thresholding: bool = False, thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995, dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0, sample_max_value: float = 1.0,
predict_x0: bool = True, predict_x0: bool = True,
solver_type: str = "bh2", solver_type: Literal["bh1", "bh2"] = "bh2",
lower_order_final: bool = True, lower_order_final: bool = True,
disable_corrector: List[int] = [], disable_corrector: List[int] = [],
solver_p: SchedulerMixin = None, solver_p: Optional[SchedulerMixin] = None,
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False, use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0, flow_shift: Optional[float] = 1.0,
timestep_spacing: str = "linspace", timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0, steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
rescale_betas_zero_snr: bool = False, rescale_betas_zero_snr: bool = False,
use_dynamic_shifting: bool = False, use_dynamic_shifting: bool = False,
time_shift_type: str = "exponential", time_shift_type: Literal["exponential"] = "exponential",
): ) -> None:
if self.config.use_beta_sigmas and not is_scipy_available(): if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.") raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
...@@ -279,21 +279,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -279,21 +279,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property @property
def step_index(self): def step_index(self) -> Optional[int]:
""" """
The index counter for current timestep. It will increase 1 after each scheduler step. The index counter for current timestep. It will increase 1 after each scheduler step.
""" """
return self._step_index return self._step_index
@property @property
def begin_index(self): def begin_index(self) -> Optional[int]:
""" """
The index for the first timestep. It should be set from pipeline with `set_begin_index` method. The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
""" """
return self._begin_index return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0): def set_begin_index(self, begin_index: int = 0) -> None:
""" """
Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
...@@ -304,8 +304,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -304,8 +304,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index self._begin_index = begin_index
def set_timesteps( def set_timesteps(
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
): ) -> None:
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...@@ -314,6 +314,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -314,6 +314,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model. The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*): device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
mu (`float`, *optional*):
Optional mu parameter for dynamic shifting when using exponential time shift type.
""" """
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if mu is not None: if mu is not None:
...@@ -475,7 +477,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -475,7 +477,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return sample return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas): def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
""" """
Convert sigma values to corresponding timestep values through interpolation. Convert sigma values to corresponding timestep values through interpolation.
...@@ -512,7 +514,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -512,7 +514,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return t return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma): def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Convert sigma values to alpha_t and sigma_t values. Convert sigma values to alpha_t and sigma_t values.
...@@ -534,7 +536,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -534,7 +536,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
""" """
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364). Models](https://huggingface.co/papers/2206.00364).
...@@ -1030,7 +1032,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -1030,7 +1032,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return step_index return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep): def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
""" """
Initialize the step_index counter for the scheduler. Initialize the step_index counter for the scheduler.
...@@ -1060,11 +1062,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -1060,11 +1062,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.Tensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns: Returns:
...@@ -1192,5 +1194,5 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -1192,5 +1194,5 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = alpha_t * original_samples + sigma_t * noise noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples return noisy_samples
def __len__(self): def __len__(self) -> int:
return self.config.num_train_timesteps return self.config.num_train_timesteps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment