Unverified Commit e6110f68 authored by Nathan Lambert, committed by GitHub

[docs sprint] schedulers docs, will update (#376)



* init schedulers docs

* add some docstrings, fix sidebar formatting

* add docstrings

* [Type hint] PNDM schedulers (#335)

* [Type hint] PNDM Schedulers

* ran make style

* updated timesteps type hint

* apply suggestions from code review

* ran make style

* removed unused import

* [Type hint] scheduling ddim (#343)

* [Type hint] scheduling ddim

* apply suggestions from code review

apply suggestions to also return the return type
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* update class docstrings

* add docstrings

* missed merge edit

* add general docs page

* modify headings for right sidebar
Co-authored-by: Partho <parthodas6176@gmail.com>
Co-authored-by: Santiago Víquez <santi.viquez@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent cee3aa0d
...@@ -10,19 +10,95 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Schedulers
Diffusers contains multiple pre-built schedule functions for the diffusion process.
## What is a scheduler?
The schedule functions, denoted *Schedulers* in the library, take in the output of a trained model, a sample that the diffusion process is iterating on, and a timestep, and return a denoised sample.
- Schedulers define the methodology for iteratively adding noise to an image or for updating a sample based on model outputs.
- During training, the manner in which noise is added defines the algorithmic process used to train the diffusion model.
- During inference, the scheduler defines how to update a sample based on the output of a pretrained model; a minimal sketch of this loop is shown after the list.
- Schedulers are often defined by a *noise schedule* and an *update rule* for solving the underlying differential equation.
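To make these pieces concrete, here is a minimal sketch of such a denoising loop. It assumes the `step`/`set_timesteps` interface described in the API section further down this page; the trained model is replaced by a random stand-in so the snippet stays self-contained.

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
scheduler.set_timesteps(50)  # configure the noise schedule for 50 inference steps

sample = torch.randn(1, 3, 32, 32)  # inference starts from pure noise
for t in scheduler.timesteps:
    # stand-in for a trained model's predicted noise residual
    model_output = torch.randn_like(sample)
    # the scheduler's update rule turns the prediction into a (slightly) denoised sample
    sample = scheduler.step(model_output, t, sample).prev_sample
```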
### Discrete versus continuous schedulers
All schedulers take in a timestep to predict the updated version of the sample being diffused.
The timestep indicates where in the diffusion process the step occurs: data is generated by iterating forward in time, and inference is executed by propagating backwards through the timesteps.
Different algorithms use timesteps that are either discrete (accepting `int` inputs), such as the [`DDPMScheduler`] or [`PNDMScheduler`], or continuous (accepting `float` inputs), such as the score-based schedulers [`ScoreSdeVeScheduler`] or [`ScoreSdeVpScheduler`].
## Designing re-usable schedulers
The core design principle behind the schedule functions is to be model-, system-, and framework-independent.
This allows for rapid experimentation and cleaner abstractions in the code, where the model prediction is separated from the sample update.
To this end, the design of schedulers is such that:
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality (see the sketch following this list).
- Schedulers currently default to PyTorch, but are designed to be framework-independent (partial NumPy support currently exists).
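As an illustration of this interchangeability, the hypothetical helper below runs the same denoising loop with whichever scheduler it is given; only the scheduler object changes, not the loop (the model prediction is again faked).

```python
import torch
from diffusers import DDIMScheduler, DDPMScheduler

def generate(scheduler, num_inference_steps=50):
    """Run one generic denoising loop with any scheduler exposing set_timesteps/step."""
    scheduler.set_timesteps(num_inference_steps)
    sample = torch.randn(1, 3, 32, 32)
    for t in scheduler.timesteps:
        model_output = torch.randn_like(sample)  # placeholder for a real model prediction
        sample = scheduler.step(model_output, t, sample).prev_sample
    return sample

# fewer deterministic steps vs. many stochastic steps: same loop, different trade-off
image_ddim = generate(DDIMScheduler(), num_inference_steps=50)
image_ddpm = generate(DDPMScheduler(), num_inference_steps=1000)
```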
## API
The core API for any new scheduler must follow a limited structure.
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
- Schedulers should be framework-agnostic, but provide simple functionality to convert the scheduler into a specific framework, such as PyTorch, with a `set_format(...)` method.
A minimal skeleton following this structure is sketched below.
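The skeleton below is hypothetical: `MyScheduler` and its toy update rule do not exist in the library and only illustrate the expected shape of a new scheduler; the import paths are assumed to match the current library layout.

```python
import numpy as np

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput


class MyScheduler(SchedulerMixin, ConfigMixin):
    """Hypothetical scheduler skeleton: a noise schedule plus an update rule."""

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, tensor_format: str = "pt"):
        # setable values
        self.num_inference_steps = None
        self.timesteps = None
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps: int):
        # configure the schedule for a specific inference task
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.arange(0, num_inference_steps)[::-1].copy()

    def step(self, model_output, timestep, sample, return_dict: bool = True):
        # toy update rule: nudge the sample along the model prediction
        prev_sample = sample - model_output / self.num_inference_steps
        if not return_dict:
            return (prev_sample,)
        return SchedulerOutput(prev_sample=prev_sample)
```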
### Core
The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
#### SchedulerMixin
[[autodoc]] SchedulerMixin
#### SchedulerOutput
The class [`SchedulerOutput`] contains the outputs from any scheduler's `step(...)` call.
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
### Existing Schedulers
#### Denoising diffusion implicit models (DDIM)
Original paper can be found [here](https://arxiv.org/abs/2010.02502).
[[autodoc]] schedulers.scheduling_ddim.DDIMScheduler
#### Denoising diffusion probabilistic models (DDPM)
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
[[autodoc]] schedulers.scheduling_ddpm.DDPMScheduler
#### Variance exploding, stochastic sampling from Karras et al.
Original paper can be found [here](https://arxiv.org/abs/2206.00364).
[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeScheduler
#### Linear multistep scheduler for discrete beta schedules
Original implementation can be found [here](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).
[[autodoc]] schedulers.scheduling_lms_discrete.LMSDiscreteScheduler
#### Pseudo numerical methods for diffusion models (PNDM)
Original paper can be found [here](https://arxiv.org/abs/2202.09778).
[[autodoc]] schedulers.scheduling_pndm.PNDMScheduler
#### Variance exploding stochastic differential equation (SDE) scheduler
Original paper can be found [here](https://arxiv.org/abs/2011.13456).
[[autodoc]] schedulers.scheduling_sde_ve.ScoreSdeVeScheduler
#### Variance preserving stochastic differential equation (SDE) scheduler
Original paper can be found [here](https://arxiv.org/abs/2011.13456).
<Tip warning={true}>
Score SDE-VP is under construction.
</Tip>
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
...@@ -30,11 +30,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
...@@ -49,6 +55,29 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional): TODO
timestep_values (`np.ndarray`, optional): TODO
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
if alpha for final step is 1 or the final alpha of the "non-previous" one.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@register_to_config
def __init__(
self,
...@@ -62,7 +91,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
set_alpha_to_one: bool = True,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "scaled_linear":
...@@ -101,6 +131,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance
def set_timesteps(self, num_inference_steps: int, offset: int = 0):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`): TODO
"""
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
...@@ -118,7 +156,24 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
`SchedulerOutput`: updated sample in the diffusion chain.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
......
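As a rough usage sketch of the interface documented above (the predicted noise is faked so the snippet stays self-contained; with a real model it would come from a forward pass):

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 3, 64, 64)
for t in scheduler.timesteps:
    noise_pred = torch.randn_like(sample)  # stand-in for the model's predicted noise
    # eta=0.0 gives deterministic DDIM sampling; larger eta re-introduces stochasticity
    sample = scheduler.step(noise_pred, t, sample, eta=0.0).prev_sample
```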
...@@ -29,11 +29,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
...@@ -48,6 +54,29 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
Langevin dynamics sampling.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional): TODO
variance_type (`str`):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@register_to_config
def __init__(
self,
...@@ -88,6 +117,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.variance_type = variance_type
def set_timesteps(self, num_inference_steps: int):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(
...@@ -137,7 +173,25 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
`SchedulerOutput`: updated sample in the diffusion chain.
"""
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
......
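A comparable sketch for the DDPM scheduler; since its update is stochastic, a `generator` can be passed for reproducibility (the model output is faked as before):

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, variance_type="fixed_small", tensor_format="pt")
scheduler.set_timesteps(1000)

generator = torch.manual_seed(0)  # fix the seed for the noise added at each step
sample = torch.randn(1, 3, 32, 32, generator=generator)
for t in scheduler.timesteps:
    noise_pred = torch.randn_like(sample)  # stand-in for the model's predicted noise (epsilon)
    sample = scheduler.step(noise_pred, t, sample, generator=generator).prev_sample
```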
...@@ -49,6 +49,24 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
differential equations." https://arxiv.org/abs/2011.13456
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
Args:
sigma_min (`float`): minimum noise magnitude
sigma_max (`float`): maximum noise magnitude
s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
A reasonable range is [1.000, 1.011].
s_churn (`float`): the parameter controlling the overall amount of stochasticity.
A reasonable range is [0, 100].
s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
A reasonable range is [0, 10].
s_max (`float`): the end value of the sigma range where we add noise.
A reasonable range is [0.2, 80].
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
""" """
@register_to_config @register_to_config
...@@ -62,23 +80,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -62,23 +80,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
s_max: float = 50, s_max: float = 50,
tensor_format: str = "pt", tensor_format: str = "pt",
): ):
"""
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
Args:
sigma_min (`float`): minimum noise magnitude
sigma_max (`float`): maximum noise magnitude
s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
A reasonable range is [1.000, 1.011].
s_churn (`float`): the parameter controlling the overall amount of stochasticity.
A reasonable range is [0, 100].
s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
A reasonable range is [0, 10].
s_max (`float`): the end value of the sigma range where we add noise.
A reasonable range is [0.2, 80].
"""
# setable values
self.num_inference_steps = None
self.timesteps = None
...@@ -88,6 +89,14 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps: int):
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [
...@@ -104,6 +113,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
TODO Args:
"""
if self.s_min <= sigma <= self.s_max:
gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
...@@ -125,6 +136,21 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat: Union[torch.FloatTensor, np.ndarray],
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
"""
pred_original_sample = sample_hat + sigma_hat * model_output
derivative = (sample_hat - pred_original_sample) / sigma_hat
...@@ -145,7 +171,22 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
derivative: Union[torch.FloatTensor, np.ndarray],
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. TODO complete description
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
"""
pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
......
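The Karras VE scheduler uses a different calling pattern from the other schedulers: the sample is first perturbed with `add_noise_to_input`, then updated with `step`, and (while the noise level is non-zero) refined with the second-order `step_correct`. The sketch below only approximates that orchestration, with the model prediction faked:

```python
import torch
from diffusers import KarrasVeScheduler

scheduler = KarrasVeScheduler(tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=50)

generator = torch.manual_seed(0)
sample = torch.randn(1, 3, 64, 64) * scheduler.config.sigma_max  # start at the highest noise level

for t in scheduler.timesteps:
    sigma = scheduler.schedule[t]
    sigma_prev = scheduler.schedule[t - 1] if t > 0 else 0.0

    # 1. "churn": temporarily raise the noise level from sigma to sigma_hat
    sample_hat, sigma_hat = scheduler.add_noise_to_input(sample, sigma, generator=generator)

    # 2. Euler step from sigma_hat down to sigma_prev
    model_output = torch.randn_like(sample)  # stand-in for the denoising model's prediction
    output = scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)

    # 3. second-order correction using the derivative returned by `step`
    if sigma_prev != 0:
        model_output = torch.randn_like(sample)
        output = scheduler.step_correct(
            model_output, sigma_hat, sigma_prev, sample_hat, output.prev_sample, output.derivative
        )
    sample = output.prev_sample
```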
...@@ -24,6 +24,26 @@ from .scheduling_utils import SchedulerMixin, SchedulerOutput
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional): TODO
timestep_values (`np.ndarray`, optional): TODO
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@register_to_config
def __init__(
self,
...@@ -35,12 +55,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
timestep_values: Optional[np.ndarray] = None,
tensor_format: str = "pt",
):
""" if trained_betas is not None:
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by self.betas = np.asarray(trained_betas)
Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
"""
if beta_schedule == "linear": if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "scaled_linear": elif beta_schedule == "scaled_linear":
...@@ -64,7 +80,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
def get_lms_coefficient(self, order, t, current_order):
"""
Compute a linear multistep coefficient.
Args:
order (TODO):
t (TODO):
current_order (TODO):
""" """
def lms_derivative(tau):
...@@ -80,6 +101,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return integrated_coeff
def set_timesteps(self, num_inference_steps: int):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
self.num_inference_steps = num_inference_steps
self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
...@@ -102,6 +130,22 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
order: int = 4,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
"""
sigma = self.sigmas[timestep]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
......
...@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
...@@ -29,11 +29,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
...@@ -48,6 +54,27 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
namely the Runge-Kutta method and a linear multi-step method.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional): TODO
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
skip_prk_steps (`bool`):
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
before plms steps; defaults to `False`.
"""
@register_to_config
def __init__(
self,
...@@ -55,10 +82,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
tensor_format: str = "pt",
skip_prk_steps: bool = False,
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "scaled_linear":
...@@ -98,6 +127,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`): TODO
"""
self.num_inference_steps = num_inference_steps
self._timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
...@@ -135,7 +172,23 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
sample: Union[torch.FloatTensor, np.ndarray],
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
`SchedulerOutput`: updated sample in the diffusion chain.
"""
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
else:
...@@ -151,6 +204,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
""" """
if self.num_inference_steps is None: if self.num_inference_steps is None:
raise ValueError( raise ValueError(
...@@ -194,6 +258,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -194,6 +258,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
""" """
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution. times to approximate the solution.
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
""" """
if self.num_inference_steps is None: if self.num_inference_steps is None:
raise ValueError( raise ValueError(
......
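A rough usage sketch; note that calling code only ever calls `step()`, which dispatches to `step_prk()` or `step_plms()` based on its internal counter (the model output is faked as before, and `skip_prk_steps=True` is just one possible configuration):

```python
import torch
from diffusers import PNDMScheduler

# skip_prk_steps=True skips the Runge-Kutta warm-up and uses only linear multi-step (PLMS) updates
scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True, tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 3, 64, 64)
for t in scheduler.timesteps:
    noise_pred = torch.randn_like(sample)  # stand-in for the model's predicted noise
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```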
...@@ -47,12 +47,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
"""
The variance exploding stochastic differential equation (SDE) scheduler.
For more information, see the original paper: https://arxiv.org/abs/2011.13456
Args:
snr (`float`):
coefficient weighting the step from the model_output sample (from the network) to the random noise.
sigma_min (`float`):
initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
distribution of the data.
sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
epsilon.
correct_steps (`int`): number of correction steps performed on a produced sample.
tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
"""
@register_to_config
...@@ -66,11 +73,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
correct_steps=1,
tensor_format="pt",
):
# setable values
self.timesteps = None
self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
...@@ -79,6 +82,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps, sampling_eps=None):
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
"""
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
...@@ -89,6 +101,20 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None):
"""
Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
The sigmas control the weight of the `drift` and `diffusion` components of sample update.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, optional):
initial noise scale value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
"""
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
...@@ -140,7 +166,20 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
**kwargs,
) -> Union[SdeVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
""" """
if "seed" in kwargs and kwargs["seed"] is not None: if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(kwargs["seed"]) self.set_seed(kwargs["seed"])
...@@ -186,6 +225,17 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -186,6 +225,17 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
""" """
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep. after making the prediction for the previous timestep.
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
""" """
if "seed" in kwargs and kwargs["seed"] is not None: if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(kwargs["seed"]) self.set_seed(kwargs["seed"])
......
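A rough sketch of the predictor-corrector loop this scheduler is meant for; the method names follow the docstrings above, the score model output is faked, and the orchestration only approximates the actual score_sde_ve pipeline:

```python
import torch
from diffusers import ScoreSdeVeScheduler

scheduler = ScoreSdeVeScheduler(tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=100)
scheduler.set_sigmas(num_inference_steps=100)

generator = torch.manual_seed(0)
sample = torch.randn(1, 3, 64, 64) * scheduler.config.sigma_max

for t in scheduler.timesteps:
    # corrector: Langevin-like correction steps at the current noise level
    for _ in range(scheduler.config.correct_steps):
        score = torch.randn_like(sample)  # stand-in for the score model's output
        sample = scheduler.step_correct(score, sample, generator=generator).prev_sample

    # predictor: one reverse-SDE step towards the next (lower) noise level
    score = torch.randn_like(sample)
    sample = scheduler.step_pred(score, t, sample, generator=generator).prev_sample
```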
...@@ -24,6 +24,15 @@ from .scheduling_utils import SchedulerMixin
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
"""
The variance preserving stochastic differential equation (SDE) scheduler.
For more information, see the original paper: https://arxiv.org/abs/2011.13456
UNDER CONSTRUCTION
"""
@register_to_config
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
......
...@@ -38,6 +38,9 @@ class SchedulerOutput(BaseOutput):
class SchedulerMixin:
"""
Mixin containing common functions for the schedulers.
"""
config_name = SCHEDULER_CONFIG_NAME
ignore_for_config = ["tensor_format"]
......