Unverified Commit 16ad13b6 authored by Steven Liu, committed by GitHub

[docs] Clean scheduler api (#4204)

* clean scheduler mixin

* up to dpmsolvermultistep

* finish cleaning

* first draft

* fix overview table

* apply feedback

* update reference code
parent da0e2fce
...@@ -71,36 +71,35 @@ def betas_for_alpha_bar(
class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    KDPM2DiscreteScheduler with ancestral sampling is inspired by the DPMSolver2 and Algorithm 2 from the [Elucidating
    the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of the [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as in Stable
            Diffusion.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
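The `beta_schedule` mapping described in the docstring above can be sketched in plain Python. This is a minimal illustration of the standard `linear` and `scaled_linear` formulations, not code from this file; `make_betas` is a hypothetical helper name, and the actual scheduler builds these with `torch.linspace`.

```python
def make_betas(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
               beta_schedule="linear"):
    """Map a [beta_start, beta_end] range to a sequence of per-step betas."""
    n = num_train_timesteps
    if beta_schedule == "linear":
        # Evenly spaced betas between beta_start and beta_end.
        return [beta_start + (beta_end - beta_start) * i / (n - 1) for i in range(n)]
    if beta_schedule == "scaled_linear":
        # Linear in sqrt-space, then squared (the schedule used by Stable Diffusion).
        s0, s1 = beta_start**0.5, beta_end**0.5
        return [(s0 + (s1 - s0) * i / (n - 1)) ** 2 for i in range(n)]
    raise ValueError(f"unknown beta_schedule: {beta_schedule}")

betas = make_betas(beta_schedule="scaled_linear")
```

Both schedules start at `beta_start` and end at `beta_end`; they differ only in how the intermediate values are spaced.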
...@@ -172,12 +171,18 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        timestep: Union[float, torch.FloatTensor],
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        step_index = self.index_for_timestep(timestep)
...@@ -196,13 +201,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
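For the default `"linspace"` timestep spacing, setting the timesteps amounts to spreading `num_inference_steps` points evenly over the training range and reversing them. A rough sketch of that computation, under the assumption of the standard linspace formulation (`linspace_timesteps` is a hypothetical helper; the real `set_timesteps` also builds the matching sigmas and moves tensors to `device`):

```python
def linspace_timesteps(num_inference_steps, num_train_timesteps=1000):
    # Spread the inference steps evenly over [0, num_train_timesteps - 1],
    # then reverse so sampling runs from high noise to low noise.
    n, t = num_inference_steps, num_train_timesteps
    steps = [i * (t - 1) / (n - 1) for i in range(n)]
    return steps[::-1]

timesteps = linspace_timesteps(5)  # [999.0, 749.25, 499.5, 249.75, 0.0]
```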
...@@ -307,17 +312,25 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        step_index = self.index_for_timestep(timestep)
...
...@@ -70,36 +70,35 @@ def betas_for_alpha_bar(
class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    KDPM2DiscreteScheduler is inspired by the DPMSolver2 and Algorithm 2 from the [Elucidating the Design Space of
    Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of the [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as in Stable
            Diffusion.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -171,12 +170,18 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
        timestep: Union[float, torch.FloatTensor],
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        step_index = self.index_for_timestep(timestep)
...@@ -195,13 +200,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
...@@ -295,17 +300,23 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        step_index = self.index_for_timestep(timestep)
...
...@@ -47,34 +47,32 @@ class KarrasVeOutput(BaseOutput):
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    """
    A stochastic scheduler tailored to variance-expanding models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    <Tip>

    For more details on the parameters, see [Appendix E](https://arxiv.org/abs/2206.00364). The grid search values used
    to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in Table 5 of the paper.

    </Tip>

    Args:
        sigma_min (`float`, defaults to 0.02):
            The minimum noise magnitude.
        sigma_max (`float`, defaults to 100):
            The maximum noise magnitude.
        s_noise (`float`, defaults to 1.007):
            The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
            1.011].
        s_churn (`float`, defaults to 80):
            The parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100].
        s_min (`float`, defaults to 0.05):
            The start value of the sigma range to add noise (enable stochasticity). A reasonable range is [0, 10].
        s_max (`float`, defaults to 50):
            The end value of the sigma range to add noise. A reasonable range is [0.2, 80].
    """

    order = 2
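The `sigma_min`/`sigma_max` range above feeds the sigma ramp from the same paper (Eq. 5, with `rho = 7`), which interpolates in `sigma**(1/rho)` space. A sketch of that ramp, as an illustration of how such Karras-style noise levels are computed rather than code from this file (`karras_sigmas` is a hypothetical helper name):

```python
def karras_sigmas(n, sigma_min=0.02, sigma_max=100.0, rho=7.0):
    # Interpolate linearly in sigma**(1/rho) space, which concentrates
    # steps at low noise levels (Karras et al. 2022, Eq. 5).
    max_inv = sigma_max ** (1 / rho)
    min_inv = sigma_min ** (1 / rho)
    ramp = [i / (n - 1) for i in range(n)]
    return [(max_inv + r * (min_inv - max_inv)) ** rho for r in ramp]

sigmas = karras_sigmas(10)  # decreasing from sigma_max toward sigma_min
```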
...@@ -103,22 +101,26 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample
    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
...@@ -136,10 +138,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
        self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
    ) -> Tuple[torch.FloatTensor, float]:
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor `gamma_i ≥ 0` to
        reach a higher noise level `sigma_hat = sigma_i + gamma_i * sigma_i`.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            sigma (`float`):
            generator (`torch.Generator`, *optional*):
                A random number generator.
        """
        if self.config.s_min <= sigma <= self.config.s_max:
            gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
...@@ -162,21 +169,22 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            sigma_hat (`float`):
            sigma_prev (`float`):
            sample_hat (`torch.FloatTensor`):
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
                If `return_dict` is `True`, [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] is returned, otherwise
                a tuple is returned where the first element is the sample tensor.
        """
...@@ -202,16 +210,18 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Corrects the predicted sample based on the `model_output` of the network.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            sigma_hat (`float`): TODO
            sigma_prev (`float`): TODO
            sample_hat (`torch.FloatTensor`): TODO
            sample_prev (`torch.FloatTensor`): TODO
            derivative (`torch.FloatTensor`): TODO
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`.

        Returns:
            prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
...
...@@ -29,14 +29,14 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """
...@@ -91,39 +91,37 @@ def betas_for_alpha_bar(
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    A linear multistep scheduler for discrete beta schedules.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of the [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as in Stable
            Diffusion.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -183,14 +181,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.

Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`float` or `torch.FloatTensor`):
The current timestep in the diffusion chain.

Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
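The earlier version of this docstring noted that the K-LMS scaling divides the sample by `(sigma**2 + 1) ** 0.5`. A scalar sketch of that scaling, for illustration only (not this class's tensor code):

```python
def scale_model_input(sample: float, sigma: float) -> float:
    # K-LMS style scaling: divide by sqrt(sigma^2 + 1) so the model
    # receives a roughly unit-variance input at every noise level.
    return sample / ((sigma**2 + 1) ** 0.5)

# With sigma = 0 the sample passes through unchanged.
scaled = scale_model_input(2.0, 0.0)
```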
...@@ -202,12 +204,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
def get_lms_coefficient(self, order, t, current_order):
"""
Compute the linear multistep coefficient.

Args:
order ():
t ():
current_order ():
"""

def lms_derivative(tau):
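For intuition, a linear multistep coefficient is the integral of a Lagrange basis polynomial over one sigma interval. A self-contained sketch using a simple trapezoidal sum (my own simplification; the library integrates `lms_derivative` numerically with a quadrature routine):

```python
import numpy as np

def lms_coefficient(sigmas, order, t, current_order, resolution=10_000):
    """Integrate the Lagrange basis for sigmas[t - current_order] from sigmas[t] to sigmas[t + 1]."""
    def lms_derivative(tau):
        prod = 1.0
        for k in range(order):
            if current_order == k:
                continue
            prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
        return prod

    # Trapezoidal approximation of the integral over [sigmas[t], sigmas[t + 1]].
    taus = np.linspace(sigmas[t], sigmas[t + 1], resolution)
    vals = np.array([lms_derivative(tau) for tau in taus])
    dtau = taus[1] - taus[0]
    return float(dtau * (vals.sum() - 0.5 * (vals[0] + vals[-1])))

# With order=1 the basis polynomial is the constant 1, so the
# coefficient is simply the (signed) sigma step.
coeff = lms_coefficient([3.0, 2.0, 1.0, 0.0], order=1, t=0, current_order=0)
```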
...@@ -224,13 +226,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
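A sketch of the `"linspace"` timestep spacing this scheduler uses by default (a simplification with assumed values; the real method also interpolates sigmas and handles device placement):

```python
import numpy as np

num_train_timesteps = 1000  # assumed training schedule length
num_inference_steps = 4

# "linspace": evenly spaced values over the full training range,
# reversed so the chain runs from high noise down to step 0.
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1]
# → [999., 666., 333., 0.]
```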
...@@ -322,21 +324,25 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return_dict: bool = True,
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).

Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`float` or `torch.FloatTensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
order (`int`, defaults to 4):
The order of the linear multistep method.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if not self.is_scale_input_called:
...
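Schematically, the multistep update combines the stored derivatives with the coefficients above. A scalar sketch, for illustration only:

```python
def lms_step(sample: float, derivatives: list, coeffs: list) -> float:
    # Linear multistep update: x_prev = x + sum_i coeff_i * derivative_i.
    return sample + sum(c * d for c, d in zip(coeffs, derivatives))

prev_sample = lms_step(1.0, [0.5, -0.25], [2.0, 4.0])
```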
...@@ -71,42 +71,42 @@ def betas_for_alpha_bar(
class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
`PNDMScheduler` uses pseudo numerical methods for diffusion models such as the Runge-Kutta and linear multi-step
method.

This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.

Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
skip_prk_steps (`bool`, defaults to `False`):
Allows the scheduler to skip the Runge-Kutta steps defined in the original paper as being required before
PLMS steps.
set_alpha_to_one (`bool`, defaults to `False`):
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
paper).
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
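A sketch of the default `"leading"` spacing combined with `steps_offset=1` (values assumed for illustration; the real `set_timesteps` additionally builds separate PRK and PLMS timestep lists):

```python
import numpy as np

num_train_timesteps = 1000
num_inference_steps = 4
steps_offset = 1  # e.g. offset=1 as in Stable Diffusion

# "leading": a fixed integer stride starting from step 0, shifted by the offset.
step_ratio = num_train_timesteps // num_inference_steps
timesteps = np.arange(0, num_inference_steps) * step_ratio + steps_offset
# → [1, 251, 501, 751]
```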
...@@ -169,11 +169,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
...@@ -233,22 +235,24 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise), and calls [`~PNDMScheduler.step_prk`]
or [`~PNDMScheduler.step_plms`] depending on the internal variable `counter`.

Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
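The counter-based dispatch can be sketched as follows (simplified, with assumed argument names; the real method forwards all step arguments to the chosen routine):

```python
def choose_step(counter: int, num_prk_timesteps: int, skip_prk_steps: bool) -> str:
    # Runge-Kutta warm-up steps first, then linear multistep (PLMS) steps.
    if counter < num_prk_timesteps and not skip_prk_steps:
        return "step_prk"
    return "step_plms"
```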
...@@ -264,19 +268,24 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the Runge-Kutta method. It performs four forward passes to approximate the solution to the differential
equation.

Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
...@@ -319,19 +328,23 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the linear multistep method. It performs one forward pass multiple times to approximate the solution.

Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
...@@ -384,10 +397,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
current timestep.

Args:
sample (`torch.FloatTensor`):
The input sample.

Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
return sample
...
...@@ -89,32 +89,28 @@ def betas_for_alpha_bar(
class RePaintScheduler(SchedulerMixin, ConfigMixin):
"""
`RePaintScheduler` is a scheduler for DDPM inpainting inside a given mask.

This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.

Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, `squaredcos_cap_v2`, or `sigmoid`.
eta (`float`):
The weight of added noise in a diffusion step. Its value is between 0.0 and 1.0; a value of 0.0
corresponds to the DDIM scheduler and a value of 1.0 to the DDPM scheduler.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
clip_sample (`bool`, defaults to `True`):
Clip the predicted sample between -1 and 1 for numerical stability.
"""
...@@ -171,11 +167,14 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
current timestep.

Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.

Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
return sample
...@@ -186,6 +185,23 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
jump_n_sample: int = 10,
device: Union[str, torch.device] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
jump_length (`int`, defaults to 10):
The number of steps taken forward in time before going backward in time for a single jump (“j” in the
RePaint paper). Take a look at Figures 9 and 10 in the paper.
jump_n_sample (`int`, defaults to 10):
The number of times to make a forward time jump for a given chosen time sample. Take a look at Figures 9
and 10 in the paper.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
self.num_inference_steps = num_inference_steps
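The jump schedule controlled by `jump_length` and `jump_n_sample` can be sketched as follows. This is a simplified, self-contained version of the resampling schedule from the RePaint paper (Figures 9 and 10); the actual method additionally rescales these indices by `num_train_timesteps // num_inference_steps`:

```python
def repaint_schedule(num_inference_steps: int, jump_length: int, jump_n_sample: int) -> list:
    # Every jump_length steps, allow jump_n_sample - 1 forward re-noising jumps.
    jumps = {}
    for j in range(0, num_inference_steps - jump_length, jump_length):
        jumps[j] = jump_n_sample - 1

    t = num_inference_steps
    ts = []
    while t >= 1:
        t -= 1
        ts.append(t)  # regular backward (denoising) step
        if jumps.get(t, 0) > 0:
            jumps[t] -= 1
            for _ in range(jump_length):
                t += 1
                ts.append(t)  # forward (re-noising) jump
    return ts

# Small example: the chain walks down, jumps back up, and resamples.
schedule = repaint_schedule(num_inference_steps=6, jump_length=2, jump_n_sample=2)
```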
...@@ -239,27 +255,29 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
return_dict: bool = True,
) -> Union[RePaintSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).

Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
original_image (`torch.FloatTensor`):
The original image to inpaint on.
mask (`torch.FloatTensor`):
The mask where a value of 0.0 indicates which part of the original image to inpaint.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_repaint.RePaintSchedulerOutput`] or `tuple`.

Returns:
[`~schedulers.scheduling_repaint.RePaintSchedulerOutput`] or `tuple`:
If `return_dict` is `True`, [`~schedulers.scheduling_repaint.RePaintSchedulerOutput`] is returned,
otherwise a tuple is returned where the first element is the sample tensor.
"""
t = timestep
...
...@@ -28,14 +28,14 @@ from .scheduling_utils import SchedulerMixin, SchedulerOutput
@dataclass
class SdeVeOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.

Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Mean averaged `prev_sample` over previous timesteps.
"""

prev_sample: torch.FloatTensor
...@@ -44,26 +44,25 @@ class SdeVeOutput(BaseOutput):
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
"""
`ScoreSdeVeScheduler` is a variance exploding stochastic differential equation (SDE) scheduler.

This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.

Args:
num_train_timesteps (`int`, defaults to 2000):
The number of diffusion steps to train the model.
snr (`float`, defaults to 0.15):
A coefficient weighting the step from the `model_output` sample (from the network) to the random noise.
sigma_min (`float`, defaults to 0.01):
The initial noise scale for the sigma sequence in the sampling procedure. The minimum sigma should mirror
the distribution of the data.
sigma_max (`float`, defaults to 1348.0):
The maximum value used for the range of continuous timesteps passed into the model.
sampling_eps (`float`, defaults to 1e-5):
The end value of sampling where timesteps decrease progressively from 1 to epsilon.
correct_steps (`int`, defaults to 1):
The number of correction steps performed on a produced sample.
"""

order = 1
...@@ -92,11 +91,14 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
current timestep.

Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.

Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
return sample
...@@ -104,13 +106,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
):
"""
Sets the continuous timesteps used for the diffusion chain (to be run before inference).

Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
sampling_eps (`float`, *optional*):
The final timestep value (overrides value given during scheduler instantiation).
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
...@@ -121,19 +125,18 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
):
"""
Sets the noise scales used for the diffusion chain (to be run before inference). The sigmas control the weight
of the `drift` and `diffusion` components of the sample update.

Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, *optional*):
The initial noise scale value (overrides value given during scheduler instantiation).
sigma_max (`float`, *optional*):
The final noise scale value (overrides value given during scheduler instantiation).
sampling_eps (`float`, *optional*):
The final timestep value (overrides value given during scheduler instantiation).
"""
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
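The variance-exploding sigma sequence interpolates geometrically between `sigma_max` and `sigma_min`. A plain-Python sketch of that shape (my own illustration of `sigma_min * (sigma_max / sigma_min) ** t` with `t` decreasing from 1 toward `sampling_eps`, not this class's tensor code):

```python
def ve_sigmas(sigma_min: float, sigma_max: float, num_inference_steps: int, sampling_eps: float) -> list:
    # t decreases linearly from 1 to sampling_eps; sigma follows a geometric interpolation.
    ts = [1 + i * (sampling_eps - 1) / (num_inference_steps - 1) for i in range(num_inference_steps)]
    return [sigma_min * (sigma_max / sigma_min) ** t for t in ts]

sigmas = ve_sigmas(0.01, 100.0, 5, 1e-5)
# sigmas[0] is sigma_max (t = 1); sigmas[-1] is approximately sigma_min (t ≈ 0).
```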
...@@ -162,20 +165,25 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return_dict: bool = True, return_dict: bool = True,
) -> Union[SdeVeOutput, Tuple]: ) -> Union[SdeVeOutput, Tuple]:
""" """
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model. model_output (`torch.FloatTensor`):
timestep (`int`): current discrete timestep in the diffusion chain. The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process. A current instance of a sample created by the diffusion process.
generator: random number generator. generator (`torch.Generator`, *optional*):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`.
Returns: Returns:
[`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`:
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. If `return_dict` is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple
is returned where the first element is the sample tensor.
""" """
if self.timesteps is None: if self.timesteps is None:
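Schematically, a single reverse-SDE prediction step combines a deterministic drift derived from the score with fresh noise scaled by the diffusion coefficient. A scalar toy version of that update shape (a sketch under the assumption that the diffusion coefficient comes from the gap between adjacent noise scales; not the library's implementation):

```python
import random

def step_pred_scalar(score: float, sample: float, sigma_t: float, sigma_prev: float):
    # Diffusion coefficient from the gap between adjacent noise scales.
    diffusion = (sigma_t ** 2 - sigma_prev ** 2) ** 0.5
    drift = -diffusion ** 2 * score       # deterministic drift term
    prev_sample_mean = sample - drift     # mean of the reverse update
    noise = random.gauss(0.0, 1.0)        # stochastic part of the SDE step
    return prev_sample_mean + diffusion * noise, prev_sample_mean
```

Returning both the noisy sample and its mean mirrors the `SdeVeOutput` fields `prev_sample` and `prev_sample_mean`.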
...@@ -224,19 +232,23 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -224,19 +232,23 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly Correct the predicted sample based on the `model_output` of the network. This is often run repeatedly after
after making the prediction for the previous timestep. making the prediction for the previous timestep.
Args: Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model. model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process. A current instance of a sample created by the diffusion process.
generator: random number generator. generator (`torch.Generator`, *optional*):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`.
Returns: Returns:
[`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`:
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. If return_dict is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple
is returned where the first element is the sample tensor.
""" """
if self.timesteps is None: if self.timesteps is None:
......
...@@ -26,17 +26,18 @@ from .scheduling_utils import SchedulerMixin ...@@ -26,17 +26,18 @@ from .scheduling_utils import SchedulerMixin
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
""" """
The variance preserving stochastic differential equation (SDE) scheduler. `ScoreSdeVpScheduler` is a variance preserving stochastic differential equation (SDE) scheduler.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. methods the library implements for all schedulers such as loading and saving.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions. Args:
num_train_timesteps (`int`, defaults to 2000):
For more information, see the original paper: https://arxiv.org/abs/2011.13456 The number of diffusion steps to train the model.
beta_min (`int`, defaults to 0.1):
UNDER CONSTRUCTION beta_max (`int`, defaults to 20):
sampling_eps (`int`, defaults to 1e-3):
The end value of sampling where timesteps decrease progressively from 1 to epsilon.
""" """
order = 1 order = 1
...@@ -48,9 +49,29 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): ...@@ -48,9 +49,29 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = None self.timesteps = None
def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
"""
Sets the continuous timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device) self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
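The continuous timesteps here are just an evenly spaced grid from 1 down to `sampling_eps`; an equivalent dependency-free sketch of what `torch.linspace` produces:

```python
def vp_timesteps(num_inference_steps: int, sampling_eps: float = 1e-3) -> list[float]:
    # Evenly spaced values from 1.0 down to sampling_eps, inclusive at both ends.
    step = (1.0 - sampling_eps) / (num_inference_steps - 1)
    return [1.0 - i * step for i in range(num_inference_steps)]
```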
def step_pred(self, score, x, t, generator=None): def step_pred(self, score, x, t, generator=None):
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
score (`torch.FloatTensor`):
The direct output from the learned diffusion model (most often the predicted noise).
x (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
t (`torch.FloatTensor`):
The current continuous timestep in the diffusion chain.
generator (`torch.Generator`, *optional*):
A random number generator.
"""
if self.timesteps is None: if self.timesteps is None:
raise ValueError( raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
......
...@@ -28,14 +28,14 @@ from .scheduling_utils import SchedulerMixin ...@@ -28,14 +28,14 @@ from .scheduling_utils import SchedulerMixin
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UnCLIP # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UnCLIP
class UnCLIPSchedulerOutput(BaseOutput): class UnCLIPSchedulerOutput(BaseOutput):
""" """
Output class for the scheduler's step function output. Output class for the scheduler's `step` function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep. The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance. `pred_original_sample` can be used to preview progress or for guidance.
""" """
......
...@@ -56,78 +56,62 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): ...@@ -56,78 +56,62 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. UniPC is
designed to be model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional sampling. It can
also be applied to both noise prediction model and data prediction model. The corrector UniC can be also applied
after any off-the-shelf solvers to increase the order of accuracy.
For more details, see the original paper: https://arxiv.org/abs/2302.04867 This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We recommend
to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note
that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args: Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model. num_train_timesteps (`int`, defaults to 1000):
beta_start (`float`): the starting `beta` value of inference. The number of diffusion steps to train the model.
beta_end (`float`): the final `beta` value. beta_start (`float`, defaults to 0.0001):
beta_schedule (`str`): The starting `beta` value of inference.
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`. `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional): trained_betas (`np.ndarray`, *optional*):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
solver_order (`int`, default `2`): solver_order (`int`, default `2`):
the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
accuracy is `solver_order + 1` due to the UniC. We recommend to use `solver_order=2` for guided sampling, due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
and `solver_order=3` for unconditional sampling. unconditional sampling.
prediction_type (`str`, default `epsilon`, optional): prediction_type (`str`, defaults to `epsilon`, *optional*):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
https://imagen.research.google/video/paper.pdf) Video](https://imagen.research.google/video/paper.pdf) paper).
thresholding (`bool`, default `False`): thresholding (`bool`, defaults to `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the as Stable Diffusion.
dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models dynamic_thresholding_ratio (`float`, defaults to 0.995):
(such as stable-diffusion). The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
dynamic_thresholding_ratio (`float`, default `0.995`): sample_max_value (`float`, defaults to 1.0):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
(https://arxiv.org/abs/2205.11487). predict_x0 (`bool`, defaults to `True`):
sample_max_value (`float`, default `1.0`): Whether to use the updating algorithm on the predicted x0.
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
predict_x0 (`bool`, default `True`):
whether to use the updating algrithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details
solver_type (`str`, default `bh2`): solver_type (`str`, default `bh2`):
the solver type of UniPC. We recommend use `bh1` for unconditional sampling when steps < 10, and use `bh2` Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
otherwise. otherwise.
lower_order_final (`bool`, default `True`): lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
disable_corrector (`list`, default `[]`): disable_corrector (`list`, default `[]`):
decide which step to disable the corrector. For large guidance scale, the misalignment between the Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
`epsilon_theta(x_t, c)`and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be mitigated and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
by disable the corrector at the first few steps (e.g., disable_corrector=[0]) usually disabled during the first few steps.
solver_p (`SchedulerMixin`, default `None`): solver_p (`SchedulerMixin`, default `None`):
can be any other scheduler. If specified, the algorithm will become solver_p + UniC. Any other scheduler; if specified, the algorithm becomes `solver_p + UniC`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`): use_karras_sigmas (`bool`, *optional*, defaults to `False`):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence the sigmas are determined according to a sequence of noise levels {σi}.
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. timestep_spacing (`str`, defaults to `"linspace"`):
timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, defaults to 0):
steps_offset (`int`, default `0`): An offset added to the inference steps. You can use a combination of `offset=1` and
an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in Diffusion.
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
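The `thresholding`, `dynamic_thresholding_ratio`, and `sample_max_value` options implement Imagen-style dynamic thresholding: clamp the predicted x0 at a high quantile of its absolute values, then rescale. A rough scalar-list sketch (a simplified reading of the method, not the library's tensor implementation):

```python
def dynamic_threshold(sample: list[float], ratio: float = 0.995, max_value: float = 1.0) -> list[float]:
    # Threshold s = the `ratio` quantile of |sample|, floored at max_value.
    abs_vals = sorted(abs(v) for v in sample)
    s = abs_vals[min(int(ratio * len(abs_vals)), len(abs_vals) - 1)]
    s = max(s, max_value)
    # Clip to [-s, s] and rescale so extremes map back to [-max_value, max_value].
    return [max(-s, min(s, v)) / s * max_value for v in sample]
```

Because the rescaling assumes pixel values live in a bounded range, this is why the docstring warns it is unsuitable for latent-space models.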
...@@ -200,13 +184,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -200,13 +184,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
""" """
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args: Args:
num_inference_steps (`int`): num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model. The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional): device (`str` or `torch.device`, *optional*):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
""" """
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace": if self.config.timestep_spacing == "linspace":
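The three `timestep_spacing` strategies referenced by the comment can be sketched in plain Python (a rough illustration of the Table 2 strategies; the library's exact rounding and offsets may differ):

```python
def spaced_timesteps(num_train: int, num_inference: int, spacing: str = "linspace") -> list[int]:
    # Return discrete timesteps in descending order for sampling.
    if spacing == "linspace":
        step = (num_train - 1) / (num_inference - 1)
        return [round(i * step) for i in range(num_inference)][::-1]
    if spacing == "leading":
        stride = num_train // num_inference
        return [i * stride for i in range(num_inference)][::-1]
    if spacing == "trailing":
        stride = num_train / num_inference
        return [round(num_train - i * stride) - 1 for i in range(num_inference)]
    raise ValueError(f"unknown spacing: {spacing}")
```

Only "linspace" and "trailing" end at the last training timestep, which is part of why the spacing choice interacts with the flawed-schedule fixes discussed in the paper.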
...@@ -298,16 +282,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -298,16 +282,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
r""" r"""
Convert the model output to the corresponding type that the algorithm PC needs. Convert the model output to the corresponding type the UniPC algorithm needs.
Args: Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model. model_output (`torch.FloatTensor`):
timestep (`int`): current discrete timestep in the diffusion chain. The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process. A current instance of a sample created by the diffusion process.
Returns: Returns:
`torch.FloatTensor`: the converted model output. `torch.FloatTensor`:
The converted model output.
""" """
if self.predict_x0: if self.predict_x0:
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
...@@ -357,14 +344,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -357,14 +344,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.FloatTensor`):
direct outputs from learned diffusion model at the current timestep. The direct output from the learned diffusion model at the current timestep.
prev_timestep (`int`): previous discrete timestep in the diffusion chain. prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process. A current instance of a sample created by the diffusion process.
order (`int`): the order of UniP at this step, also the p in UniPC-p. order (`int`):
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
Returns: Returns:
`torch.FloatTensor`: the sample tensor at the previous timestep. `torch.FloatTensor`:
The sample tensor at the previous timestep.
""" """
timestep_list = self.timestep_list timestep_list = self.timestep_list
model_output_list = self.model_outputs model_output_list = self.model_outputs
...@@ -462,15 +452,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -462,15 +452,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
One step for the UniC (B(h) version). One step for the UniC (B(h) version).
Args: Args:
this_model_output (`torch.FloatTensor`): the model outputs at `x_t` this_model_output (`torch.FloatTensor`):
this_timestep (`int`): the current timestep `t` The model outputs at `x_t`.
last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}` this_timestep (`int`):
this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}` The current timestep `t`.
order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy last_sample (`torch.FloatTensor`):
should be order + 1 The generated sample before the last predictor `x_{t-1}`.
this_sample (`torch.FloatTensor`):
The generated sample after the last predictor `x_{t}`.
order (`int`):
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
Returns: Returns:
`torch.FloatTensor`: the corrected sample tensor at the current timestep. `torch.FloatTensor`:
The corrected sample tensor at the current timestep.
""" """
timestep_list = self.timestep_list timestep_list = self.timestep_list
model_output_list = self.model_outputs model_output_list = self.model_outputs
...@@ -564,18 +559,23 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -564,18 +559,23 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
Step function propagating the sample with the multistep UniPC. Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep UniPC.
Args: Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model. model_output (`torch.FloatTensor`):
timestep (`int`): current discrete timestep in the diffusion chain. The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns: Returns:
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
""" """
...@@ -646,10 +646,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -646,10 +646,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): input sample sample (`torch.FloatTensor`):
The input sample.
Returns: Returns:
`torch.FloatTensor`: scaled input sample `torch.FloatTensor`:
A scaled input sample.
""" """
return sample return sample
......
...@@ -49,11 +49,11 @@ class KarrasDiffusionSchedulers(Enum): ...@@ -49,11 +49,11 @@ class KarrasDiffusionSchedulers(Enum):
@dataclass @dataclass
class SchedulerOutput(BaseOutput): class SchedulerOutput(BaseOutput):
""" """
Base class for the scheduler's step function output. Base class for the output of a scheduler's `step` function.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
""" """
...@@ -62,11 +62,17 @@ class SchedulerOutput(BaseOutput): ...@@ -62,11 +62,17 @@ class SchedulerOutput(BaseOutput):
class SchedulerMixin: class SchedulerMixin:
""" """
Mixin containing common functions for the schedulers. Base class for all schedulers.
[`SchedulerMixin`] contains common functions shared by all schedulers such as general loading and saving
functionalities.
[`ConfigMixin`] takes care of storing the configuration attributes (like `num_train_timesteps`) that are passed to
the scheduler's `__init__` function, and the attributes can be accessed by `scheduler.config.num_train_timesteps`.
Class attributes: Class attributes:
- **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that - **_compatibles** (`List[str]`) -- A list of scheduler classes that are compatible with the parent scheduler
`from_config` can be used from a class different than the one used to save the config (should be overridden class. Use [`~ConfigMixin.from_config`] to load a different compatible scheduler class (should be overridden
by parent class). by parent class).
""" """
...@@ -83,56 +89,50 @@ class SchedulerMixin: ...@@ -83,56 +89,50 @@ class SchedulerMixin:
**kwargs, **kwargs,
): ):
r""" r"""
Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo. Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository.
Parameters: Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either: Can be either:
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
organization name, like `google/ddpm-celebahq-256`. the Hub.
- A path to a *directory* containing the scheduler configurations saved using - A path to a *directory* (for example `./my_model_directory`) containing the scheduler
[`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`. configuration saved with [`~SchedulerMixin.save_pretrained`].
subfolder (`str`, *optional*): subfolder (`str`, *optional*):
In case the relevant files are located inside a subfolder of the model repo (either remote in The subfolder location of a model file within a larger model repository on the Hub or locally.
huggingface.co or downloaded locally), you can specify the folder name here.
return_unused_kwargs (`bool`, *optional*, defaults to `False`): return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not. Whether kwargs that are not consumed by the Python class should be returned or not.
cache_dir (`Union[str, os.PathLike]`, *optional*): cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
standard cache should not be used. is not used.
force_download (`bool`, *optional*, defaults to `False`): force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist. cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`): resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
file exists. incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*): proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info (`bool`, *optional*, defaults to `False`): output_loading_info (`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model). Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
when running `transformers-cli login` (stored in `~/.huggingface`). `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any allowed by Git.
identifier allowed by git.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
<Tip> <Tip>
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
use this method in a firewalled environment. `huggingface-cli login`. You can also activate the special
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
firewalled environment.
</Tip> </Tip>
...@@ -148,12 +148,12 @@ class SchedulerMixin: ...@@ -148,12 +148,12 @@ class SchedulerMixin:
    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to a directory so that it can be reloaded using the
        [`~SchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory to save a configuration JSON file to. Will be created if it doesn't exist.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
...
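The save/reload round trip that `save_pretrained` and `from_pretrained` provide can be sketched without the library. This is an illustrative stand-in, not the diffusers implementation: `ToyScheduler` is a hypothetical class, though the `scheduler_config.json` filename mirrors the one diffusers writes.

```python
import json
import os
import tempfile


class ToyScheduler:
    """Illustrative stand-in for a scheduler with a JSON-serializable config."""

    def __init__(self, num_train_timesteps=1000, beta_start=0.0001):
        self.config = {"num_train_timesteps": num_train_timesteps, "beta_start": beta_start}

    def save_pretrained(self, save_directory):
        # Create the directory if it doesn't exist, then write the config as JSON.
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "scheduler_config.json"), "w") as f:
            json.dump(self.config, f)

    @classmethod
    def from_pretrained(cls, save_directory):
        # Read the config back and rebuild the scheduler from it.
        with open(os.path.join(save_directory, "scheduler_config.json")) as f:
            return cls(**json.load(f))


save_dir = tempfile.mkdtemp()
scheduler = ToyScheduler(num_train_timesteps=100)
scheduler.save_pretrained(save_dir)
reloaded = ToyScheduler.from_pretrained(save_dir)
print(reloaded.config["num_train_timesteps"])  # 100
```

The round trip preserves every config attribute, which is why `scheduler.config.num_train_timesteps` is available again after reloading.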
@@ -105,36 +105,24 @@ def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamm
class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
    """
    A scheduler for vector quantized diffusion.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
    generic methods the library implements for all schedulers such as loading and saving.

    Args:
        num_vec_classes (`int`):
            The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked
            latent pixel.
        num_train_timesteps (`int`, defaults to 100):
            The number of diffusion steps to train the model.
        alpha_cum_start (`float`, defaults to 0.99999):
            The starting cumulative alpha value.
        alpha_cum_end (`float`, defaults to 0.00009):
            The ending cumulative alpha value.
        gamma_cum_start (`float`, defaults to 0.00009):
            The starting cumulative gamma value.
        gamma_cum_end (`float`, defaults to 0.99999):
            The ending cumulative gamma value.
    """
@@ -189,14 +177,14 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps and diffusion process parameters (alpha, beta, gamma) should be
                moved.
        """
        self.num_inference_steps = num_inference_steps
        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
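The two lines above are the core of `set_timesteps`: the inference timesteps run backwards from `num_inference_steps - 1` down to `0`, so sampling walks the diffusion chain from most noise to least. A standalone NumPy check of that indexing:

```python
import numpy as np

# Build the ascending range, then reverse it; .copy() makes the reversed
# view contiguous so it can be converted to a tensor later.
num_inference_steps = 5
timesteps = np.arange(0, num_inference_steps)[::-1].copy()
print(timesteps)  # [4 3 2 1 0]
```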
@@ -218,30 +206,27 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
        return_dict: bool = True,
    ) -> Union[VQDiffusionSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by the reverse transition distribution. See
        [`~VQDiffusionScheduler.q_posterior`] for more details about how the distribution is computed.

        Args:
            log_p_x_0 (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
                The log probabilities for the predicted classes of the initial latent pixels. Does not include a
                prediction for the masked class as the initial unnoised image cannot be masked.
            t (`torch.long`):
                The timestep that determines which transition matrices are used.
            x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
                The classes of each latent pixel at time `t`.
            generator (`torch.Generator`, *optional*):
                A random number generator for the noise applied to `p(x_{t-1} | x_t)` before it is sampled from.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] or
                `tuple`.

        Returns:
            [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] is
                returned, otherwise a `tuple` is returned where the first element is the sample tensor.
        """
        if timestep == 0:
            log_p_x_t_min_1 = model_output
@@ -259,32 +244,24 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
    def q_posterior(self, log_p_x_0, x_t, t):
        """
        Calculates the log probabilities for the predicted classes of the image at timestep `t-1`:

        ```
        p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) )
        ```

        Args:
            log_p_x_0 (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
                The log probabilities for the predicted classes of the initial latent pixels. Does not include a
                prediction for the masked class as the initial unnoised image cannot be masked.
            x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
                The classes of each latent pixel at time `t`.
            t (`torch.Long`):
                The timestep that determines which transition matrix is used.

        Returns:
            `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
                The log probabilities for the predicted classes of the image at timestep `t-1`.
        """
        log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
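The posterior formula in the docstring above can be sanity-checked numerically on a toy discrete diffusion. This is a sketch under arbitrary assumptions (three classes, random row-stochastic transition matrices, a made-up `p(x_0)`), not the scheduler's actual alpha/beta/gamma matrices; the key property is that the weighted sum over `x_0` of forward probabilities yields a valid distribution over `x_{t-1}`.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 3  # number of classes in this toy example


def random_stochastic(n):
    # Random matrix whose rows each sum to 1 (a valid transition matrix).
    m = rng.random((n, n))
    return m / m.sum(axis=1, keepdims=True)


Q = random_stochastic(K)           # q(x_t = j | x_{t-1} = i) = Q[i, j]
Q_bar_prev = random_stochastic(K)  # q(x_{t-1} = i | x_0 = k) = Q_bar_prev[k, i]
Q_bar = Q_bar_prev @ Q             # cumulative 0 -> t, consistent by construction

p_x0 = np.array([0.5, 0.3, 0.2])   # predicted p(x_0) for one latent pixel
j = 1                              # observed class of x_t

# p(x_{t-1}=i | x_t=j) = sum_k q(x_t=j|x_{t-1}=i) * q(x_{t-1}=i|x_0=k) * p(x_0=k) / q(x_t=j|x_0=k)
posterior = np.array(
    [
        sum(Q[i, j] * Q_bar_prev[k, i] * p_x0[k] / Q_bar[k, j] for k in range(K))
        for i in range(K)
    ]
)
print(posterior.sum())
```

Because `Q_bar = Q_bar_prev @ Q` (Chapman-Kolmogorov), the posterior sums to 1 exactly, mirroring why the scheduler can work entirely in terms of forward probabilities.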
@@ -380,25 +357,19 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
        self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
    ):
        """
        Calculates the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
        latent pixel in `x_t`.

        Args:
            t (`torch.Long`):
                The timestep that determines which transition matrix is used.
            x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
                The classes of each latent pixel at time `t`.
            log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
                The log one-hot vectors of `x_t`.
            cumulative (`bool`):
                If cumulative is `False`, the single step transition matrix `t-1`->`t` is used. If cumulative is
                `True`, the cumulative transition matrix `0`->`t` is used.

        Returns:
            `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`:
...
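The `log_onehot_x_t` argument above comes from a helper that turns class indices into log one-hot vectors. A NumPy sketch of that conversion (the real `index_to_log_onehot` in `scheduling_vq_diffusion.py` is torch-based; the epsilon standing in for `log(0)` is an assumption of this sketch):

```python
import numpy as np


def index_to_log_onehot(x, num_classes, eps=1e-30):
    # One-hot encode the class indices, clip zeros to a tiny epsilon so the
    # log is finite, then move the class axis to position 1.
    onehot = np.eye(num_classes)[x]                 # (batch, pixels, classes)
    log_onehot = np.log(np.clip(onehot, eps, 1.0))  # log(1) = 0, log(eps) ~ -69
    return np.transpose(log_onehot, (0, 2, 1))      # (batch, classes, pixels)


x_t = np.array([[2, 0, 1]])  # one sample, three latent pixels
log_onehot = index_to_log_onehot(x_t, num_classes=4)
print(log_onehot.shape)  # (1, 4, 3)
```

The entry for each pixel's own class is `log(1) = 0`, which is what lets the scheduler select transition-matrix rows by adding log one-hot vectors instead of indexing.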
@@ -31,14 +31,14 @@ import check_copies # noqa: E402
# This is the reference code that will be used in the tests.
# If DDPMSchedulerOutput is changed in scheduling_ddpm.py, this code needs to be manually updated.
REFERENCE_CODE = """ \"""
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    \"""
...
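The reference docstring above describes a two-field output class. A minimal sketch of such a class (hypothetical; `ToySchedulerOutput` and the NumPy arrays stand in for the real torch-based `DDPMSchedulerOutput`):

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np  # stand-in for torch tensors in this sketch


@dataclass
class ToySchedulerOutput:
    """Output class for a scheduler's `step` function (illustrative)."""

    # Computed sample (x_{t-1}) of the previous timestep; fed back into the
    # model as the next input in the denoising loop.
    prev_sample: np.ndarray
    # Predicted denoised sample (x_0); optional, useful for previews or guidance.
    pred_original_sample: Optional[np.ndarray] = None


out = ToySchedulerOutput(prev_sample=np.zeros((1, 3, 8, 8)))
print(out.pred_original_sample is None)  # True
```

Keeping the output in a dataclass rather than a bare tuple is what makes the `return_dict=True` path self-documenting: callers access `out.prev_sample` by name.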