Unverified commit a88a7b4f, authored by David El Malih and committed by GitHub

Improve docstrings and type hints in scheduling_dpmsolver_multistep.py (#12710)

* Improve docstrings and type hints in multiple diffusion schedulers

* docs: update Imagen Video paper link to Hugging Face Papers.
parent c8656ed7
@@ -429,7 +429,22 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return x_t

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
@@ -452,6 +467,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
        """
        if self.begin_index is None:
...
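Reviewer note: the body below the new `index_for_timestep` docstring is unchanged, and the lookup it documents is short enough to sketch in isolation. The following is a minimal, self-contained rendering of that logic (the preference for the second match on duplicated timesteps follows the comment in the diffusers source); treat it as an illustration, not the canonical implementation:

import torch

def index_for_timestep_sketch(timestep, schedule_timesteps: torch.Tensor) -> int:
    # Find where `timestep` occurs in the 1-D schedule tensor.
    index_candidates = (schedule_timesteps == timestep).nonzero()
    if len(index_candidates) == 0:
        # Not found: fall back to the last index of the schedule.
        return len(schedule_timesteps) - 1
    if len(index_candidates) > 1:
        # Custom schedules can repeat a timestep; taking the second match
        # avoids skipping a sigma index on the first denoising step.
        return index_candidates[1].item()
    return index_candidates[0].item()

# Example: with a 4-step schedule, timestep 499 maps to index 2.
print(index_for_timestep_sketch(499, torch.tensor([999, 749, 499, 249])))  # 2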
@@ -401,6 +401,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
+        """
+        Convert sigma values to alpha_t and sigma_t values.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The sigma value(s) to convert.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing (alpha_t, sigma_t) values.
+        """
         if self.config.use_flow_sigmas:
             alpha_t = 1 - sigma
             sigma_t = sigma
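Reviewer note: only the flow-matching branch of `_sigma_to_alpha_sigma_t` is visible in this hunk. For readers of the new docstring, here is a hedged sketch of the whole conversion; the `else` branch assumes the standard variance-preserving parameterization (sigma = sigma_t / alpha_t with alpha_t**2 + sigma_t**2 = 1), which is how the DPM-Solver papers define sigma:

import torch

def sigma_to_alpha_sigma_t_sketch(sigma: torch.Tensor, use_flow_sigmas: bool = False):
    if use_flow_sigmas:
        # Flow matching: x_t = (1 - sigma) * x_0 + sigma * noise.
        alpha_t = 1 - sigma
        sigma_t = sigma
    else:
        # Variance preserving: solve sigma = sigma_t / alpha_t under
        # alpha_t**2 + sigma_t**2 = 1.
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        sigma_t = sigma * alpha_t
    return alpha_t, sigma_t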
@@ -808,7 +819,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             raise NotImplementedError("only support log-rho multistep deis now")

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
@@ -831,6 +857,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
        """
        if self.begin_index is None:
@@ -927,6 +957,21 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
+        """
+        Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples without noise.
+            noise (`torch.Tensor`):
+                The noise to add to the samples.
+            timesteps (`torch.IntTensor`):
+                The timesteps at which to add noise to the samples.
+
+        Returns:
+            `torch.Tensor`:
+                The noisy samples.
+        """
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
...
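Reviewer note: the elided body of `add_noise` implements the forward-noising rule the new docstring describes: gather the sigma for each timestep, convert it to (alpha_t, sigma_t), and mix the sample as alpha_t * x_0 + sigma_t * noise. A small sketch under the same variance-preserving assumption as above, with per-sample sigmas passed in directly to sidestep the index lookup:

import torch

def add_noise_sketch(original_samples: torch.Tensor, noise: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # Broadcast the per-sample sigma over the remaining dimensions.
    while sigma.ndim < original_samples.ndim:
        sigma = sigma.unsqueeze(-1)
    # Variance-preserving conversion (see the sigma_to_alpha_sigma_t sketch above).
    alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
    sigma_t = sigma * alpha_t
    return alpha_t * original_samples + sigma_t * noise

x0 = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(x0)
noisy = add_noise_sketch(x0, eps, sigma=torch.tensor([0.5, 2.0]))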
@@ -127,18 +127,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, defaults to 2):
            The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
-            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
-            Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
+        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
+            Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample`
+            directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen
+            Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
@@ -147,15 +146,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
-        algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
-            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
-            paper, and the `dpmsolver++` type implements the algorithms in the
-            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
-            `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
-        solver_type (`str`, defaults to `midpoint`):
-            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
-            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+        algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
+            Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the
+            [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the
+            algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use
+            `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+        solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
+            Solver type for the second-order solver. The solver type slightly affects the sample quality, especially
+            for a small number of steps. It is recommended to use `midpoint` solvers.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
@@ -179,16 +177,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
        flow_shift (`float`, *optional*, defaults to 1.0):
            The shift value for the timestep schedule for flow matching.
-        final_sigmas_type (`str`, defaults to `"zero"`):
+        final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
-            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
-        variance_type (`str`, *optional*):
-            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
-            contains the predicted Gaussian variance.
-        timestep_spacing (`str`, defaults to `"linspace"`):
+        variance_type (`"learned"` or `"learned_range"`, *optional*):
+            Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
+            output contains the predicted Gaussian variance.
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
@@ -197,6 +195,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+        use_dynamic_shifting (`bool`, defaults to `False`):
+            Whether to use dynamic shifting for the timestep schedule.
+        time_shift_type (`"exponential"`, defaults to `"exponential"`):
+            The type of time shift to apply when using dynamic shifting.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
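Reviewer note: with the config options now spelled out as literals, a typical guided-sampling configuration reads directly off the docstring. This snippet only exercises the documented public constructor with its recommended values:

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",  # one of "linear", "scaled_linear", "squaredcos_cap_v2"
    prediction_type="epsilon",      # the model predicts the diffusion noise
    algorithm_type="dpmsolver++",   # recommended for guided sampling
    solver_order=2,                 # recommended for guided sampling
    solver_type="midpoint",
    final_sigmas_type="zero",
)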
@@ -208,15 +210,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
-        prediction_type: str = "epsilon",
+        prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
-        algorithm_type: str = "dpmsolver++",
-        solver_type: str = "midpoint",
+        algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
+        solver_type: Literal["midpoint", "heun"] = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
        use_karras_sigmas: Optional[bool] = False,
@@ -225,14 +227,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        use_lu_lambdas: Optional[bool] = False,
        use_flow_sigmas: Optional[bool] = False,
        flow_shift: Optional[float] = 1.0,
-        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
+        final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
        lambda_min_clipped: float = -float("inf"),
-        variance_type: Optional[str] = None,
-        timestep_spacing: str = "linspace",
+        variance_type: Optional[Literal["learned", "learned_range"]] = None,
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        use_dynamic_shifting: bool = False,
-        time_shift_type: str = "exponential",
+        time_shift_type: Literal["exponential"] = "exponential",
    ):
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
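Reviewer note: `thresholding`, `dynamic_thresholding_ratio`, and `sample_max_value` configure the "dynamic thresholding" of the Imagen paper: clamp the predicted x_0 to a per-sample quantile of its absolute values, then rescale. A minimal sketch of the method follows; it is a stand-in illustration, not the scheduler's private `_threshold_sample` helper verbatim:

import torch

def dynamic_threshold_sketch(sample: torch.Tensor, ratio: float = 0.995, max_value: float = 1.0) -> torch.Tensor:
    # Flatten everything except the batch dimension.
    batch_size = sample.shape[0]
    flat = sample.reshape(batch_size, -1)
    # Per-sample threshold at the given quantile of |x0|, kept in [1, max_value].
    s = torch.quantile(flat.abs(), ratio, dim=1)
    s = s.clamp(min=1.0, max=max_value).unsqueeze(1)
    # Clamp to [-s, s] and rescale back toward [-1, 1].
    flat = flat.clamp(-s, s) / s
    return flat.reshape(sample.shape)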
@@ -331,19 +333,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    def set_timesteps(
        self,
-        num_inference_steps: int = None,
-        device: Union[str, torch.device] = None,
+        num_inference_steps: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = None,
        mu: Optional[float] = None,
        timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            mu (`float`, *optional*):
+                The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and
+                `time_shift_type="exponential"`.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
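Reviewer note: a short usage example for the updated signature. `set_timesteps` is called once before the denoising loop, and `mu` is only needed when the scheduler was built with `use_dynamic_shifting=True`:

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.timesteps)    # 25 discrete timesteps, in descending order
print(len(scheduler.sigmas))  # typically 26: one sigma per step plus the final sigma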
@@ -503,7 +508,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
        """
        Convert sigma values to corresponding timestep values through interpolation.

@@ -539,7 +544,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        t = t.reshape(sigma.shape)
        return t
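Reviewer note: `_sigma_to_t` inverts the training sigma table by piecewise-linear interpolation in log-sigma space. The sketch below reproduces the idea with `np.interp` rather than the scheduler's manual two-point interpolation, assuming `log_sigmas` is increasing with the timestep index (true for the standard beta schedules):

import numpy as np

def sigma_to_t_sketch(sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
    # log_sigmas holds log(sigma) of the training schedule, indexed by timestep.
    # Interpolating the index against log(sigma) yields a fractional timestep.
    t = np.interp(np.log(sigma), log_sigmas, np.arange(len(log_sigmas)))
    return t.reshape(np.shape(sigma))

log_sigmas = np.log(np.linspace(0.01, 80.0, 1000))  # toy training schedule
print(sigma_to_t_sketch(np.array([0.5, 10.0]), log_sigmas))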
-    def _sigma_to_alpha_sigma_t(self, sigma):
+    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Convert sigma values to alpha_t and sigma_t values.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The sigma value(s) to convert.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing (alpha_t, sigma_t) values.
+        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
@@ -588,8 +604,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

-    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
-        """Constructs the noise schedule of Lu et al. (2022)."""
+    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """
+        Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model
+        Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022).
+
+        Args:
+            in_lambdas (`torch.Tensor`):
+                The input lambda values to be converted.
+            num_inference_steps (`int`):
+                The number of inference steps to generate the noise schedule for.
+
+        Returns:
+            `torch.Tensor`:
+                The converted lambda values following the Lu noise schedule.
+        """
        lambda_min: float = in_lambdas[-1].item()
        lambda_max: float = in_lambdas[0].item()
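Reviewer note: per the new docstring, the Lu schedule spaces the half-log-SNR values lambda(t) = log(alpha_t / sigma_t) between `lambda_max` and `lambda_min`. A sketch under that reading; the rho-power form mirrors the Karras-style interpolation visible just above this hunk and collapses to plain linear spacing at `rho = 1.0`, the value used in the paper:

import torch

def convert_to_lu_sketch(in_lambdas: torch.Tensor, num_inference_steps: int, rho: float = 1.0) -> torch.Tensor:
    # in_lambdas is descending: lambda_max = in_lambdas[0], lambda_min = in_lambdas[-1].
    lambda_min = in_lambdas[-1].item()
    lambda_max = in_lambdas[0].item()
    ramp = torch.linspace(0, 1, num_inference_steps)
    min_inv_rho = lambda_min ** (1 / rho)
    max_inv_rho = lambda_max ** (1 / rho)
    # With rho = 1 this is plain linear spacing from lambda_max down to lambda_min.
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho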
@@ -1069,7 +1098,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        )
        return x_t

-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
@@ -1088,9 +1132,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        return step_index

-    def _init_step_index(self, timestep):
+    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
        """
        Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
        """
        if self.begin_index is None:
@@ -1105,7 +1153,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
-        generator=None,
+        generator: Optional[torch.Generator] = None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:

@@ -1115,22 +1163,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        Args:
            model_output (`torch.Tensor`):
-                The direct output from learned diffusion model.
-            timestep (`int`):
+                The direct output from the learned diffusion model.
+            timestep (`int` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
-            variance_noise (`torch.Tensor`):
+            variance_noise (`torch.Tensor`, *optional*):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+                If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
@@ -1210,6 +1258,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
+        """
+        Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples without noise.
+            noise (`torch.Tensor`):
+                The noise to add to the samples.
+            timesteps (`torch.IntTensor`):
+                The timesteps at which to add noise to the samples.
+
+        Returns:
+            `torch.Tensor`:
+                The noisy samples.
+        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
...
@@ -413,6 +413,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
+        """
+        Convert sigma values to alpha_t and sigma_t values.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The sigma value(s) to convert.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing (alpha_t, sigma_t) values.
+        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
...
@@ -491,6 +491,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
+        """
+        Convert sigma values to alpha_t and sigma_t values.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The sigma value(s) to convert.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing (alpha_t, sigma_t) values.
+        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
@@ -1079,7 +1090,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
            raise ValueError(f"Order must be 1, 2, 3, got {order}")

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
@@ -1102,6 +1128,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
        """
        if self.begin_index is None:
@@ -1204,6 +1234,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
+        """
+        Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples without noise.
+            noise (`torch.Tensor`):
+                The noise to add to the samples.
+            timesteps (`torch.IntTensor`):
+                The timesteps at which to add noise to the samples.
+
+        Returns:
+            `torch.Tensor`:
+                The noisy samples.
+        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
...
@@ -578,7 +578,22 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        return x_t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
@@ -601,6 +616,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
        """
        if self.begin_index is None:
...
@@ -423,6 +423,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
+        """
+        Convert sigma values to alpha_t and sigma_t values.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The sigma value(s) to convert.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing (alpha_t, sigma_t) values.
+        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
@@ -1103,7 +1114,22 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
        return x_t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
@@ -1126,6 +1152,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
        """
        if self.begin_index is None:
...
@@ -513,6 +513,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
+        """
+        Convert sigma values to alpha_t and sigma_t values.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The sigma value(s) to convert.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing (alpha_t, sigma_t) values.
+        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
@@ -984,7 +995,22 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
        return x_t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
@@ -1007,6 +1033,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
        """
        if self.begin_index is None:
@@ -1119,6 +1149,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
+        """
+        Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples without noise.
+            noise (`torch.Tensor`):
+                The noise to add to the samples.
+            timesteps (`torch.IntTensor`):
+                The timesteps at which to add noise to the samples.
+
+        Returns:
+            `torch.Tensor`:
+                The noisy samples.
+        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
...