Unverified Commit a88a7b4f authored by David El Malih's avatar David El Malih Committed by GitHub
Browse files

Improve docstrings and type hints in scheduling_dpmsolver_multistep.py (#12710)

* Improve docstrings and type hints in multiple diffusion schedulers

* docs: update Imagen Video paper link to Hugging Face Papers.
parent c8656ed7
......@@ -429,7 +429,22 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
......@@ -452,6 +467,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
......
......@@ -401,6 +401,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
......@@ -808,7 +819,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError("only support log-rho multistep deis now")
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
......@@ -831,6 +857,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
......@@ -927,6 +957,21 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.
Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -127,18 +127,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample`
directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
......@@ -147,15 +146,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `dpmsolver++`):
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
paper, and the `dpmsolver++` type implements the algorithms in the
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the
[DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the
algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use
`dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
Solver type for the second-order solver. The solver type slightly affects the sample quality, especially
for a small number of steps. It is recommended to use `midpoint` solvers.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
......@@ -179,16 +177,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
flow_shift (`float`, *optional*, defaults to 1.0):
The shift value for the timestep schedule for flow matching.
final_sigmas_type (`str`, defaults to `"zero"`):
final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
variance_type (`str`, *optional*):
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
contains the predicted Gaussian variance.
timestep_spacing (`str`, defaults to `"linspace"`):
variance_type (`"learned"` or `"learned_range"`, *optional*):
Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
output contains the predicted Gaussian variance.
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
......@@ -197,6 +195,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
use_dynamic_shifting (`bool`, defaults to `False`):
Whether to use dynamic shifting for the timestep schedule.
time_shift_type (`"exponential"`, defaults to `"exponential"`):
The type of time shift to apply when using dynamic shifting.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
......@@ -208,15 +210,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
solver_type: Literal["midpoint", "heun"] = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
......@@ -225,14 +227,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
use_lu_lambdas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
variance_type: Optional[Literal["learned", "learned_range"]] = None,
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_dynamic_shifting: bool = False,
time_shift_type: str = "exponential",
time_shift_type: Literal["exponential"] = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
......@@ -331,19 +333,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
mu: Optional[float] = None,
timesteps: Optional[List[int]] = None,
):
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
mu (`float`, *optional*):
The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and
`time_shift_type="exponential"`.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
......@@ -503,7 +508,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
"""
Convert sigma values to corresponding timestep values through interpolation.
......@@ -539,7 +544,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape)
return t
def _sigma_to_alpha_sigma_t(self, sigma):
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
......@@ -588,8 +604,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Lu et al. (2022)."""
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model
Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022).
Args:
in_lambdas (`torch.Tensor`):
The input lambda values to be converted.
num_inference_steps (`int`):
The number of inference steps to generate the noise schedule for.
Returns:
`torch.Tensor`:
The converted lambda values following the Lu noise schedule.
"""
lambda_min: float = in_lambdas[-1].item()
lambda_max: float = in_lambdas[0].item()
......@@ -1069,7 +1098,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
)
return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
......@@ -1088,9 +1132,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return step_index
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
......@@ -1105,7 +1153,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
......@@ -1115,22 +1163,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The direct output from the learned diffusion model.
timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
variance_noise (`torch.Tensor`, *optional*):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`LEdits++`].
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
......@@ -1210,6 +1258,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.
Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -413,6 +413,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
......
......@@ -491,6 +491,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
......@@ -1079,7 +1090,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
raise ValueError(f"Order must be 1, 2, 3, got {order}")
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
......@@ -1102,6 +1128,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
......@@ -1204,6 +1234,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.
Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -578,7 +578,22 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
......@@ -601,6 +616,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
......
......@@ -423,6 +423,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
......@@ -1103,7 +1114,22 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
......@@ -1126,6 +1152,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
......
......@@ -513,6 +513,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
......@@ -984,7 +995,22 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
......@@ -1007,6 +1033,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
......@@ -1119,6 +1149,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.
Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment