Unverified commit 15370f84 authored by David El Malih, committed by GitHub

Improve docstrings and type hints in scheduling_pndm.py (#12676)

* Enhance docstrings and type hints in PNDMScheduler class

- Updated parameter descriptions to include default values and specific types using Literal for better clarity.
- Improved docstring formatting and consistency across methods, including detailed explanations for the `_get_prev_sample` method.
- Added type hints for method return types to enhance code readability and maintainability.

* Refactor docstring in PNDMScheduler class to enhance clarity

- Simplified the explanation of the method for computing the previous sample from the current sample.
- Updated the reference to the PNDM paper for better accessibility.
- Removed redundant notation explanations to streamline the documentation.
parent a96b1453
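The pattern this commit applies throughout the file is worth seeing in isolation. A minimal sketch (the `make_schedule` function below is illustrative, not part of the diff): a `Literal` annotation lets static type checkers reject strings outside the documented set, which a plain `str` annotation cannot.

```python
from typing import Literal

# Hypothetical example of the pattern used in this commit: narrowing a
# free-form `str` parameter to the exact set of accepted values.
def make_schedule(
    beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
) -> str:
    # A type checker (mypy/pyright) now flags make_schedule("liner")
    # at analysis time instead of letting it fail at runtime.
    return beta_schedule
```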
@@ -79,15 +79,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         skip_prk_steps (`bool`, defaults to `False`):
@@ -97,14 +96,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
             there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
             otherwise it uses the alpha value at step 0.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
-            or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
-            paper).
-        timestep_spacing (`str`, defaults to `"leading"`):
+            or `v_prediction` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
+        steps_offset (`int`, defaults to `0`):
             An offset added to the inference steps, as required by some model families.
 
     """
@@ -117,12 +115,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
-        prediction_type: str = "epsilon",
-        timestep_spacing: str = "leading",
+        prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
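A usage sketch of the constructor with the new annotations (the import path is the standard public one in diffusers; the chosen values are illustrative). Any string outside the documented literals would now be flagged by a type checker before the runtime `ValueError` fires:

```python
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",  # must be one of the three documented literals
    prediction_type="epsilon",
    timestep_spacing="leading",
)
```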
@@ -164,7 +162,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.plms_timesteps = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
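Continuing the sketch above, `set_timesteps` is called once before the denoising loop; the signature now spells out that `device` is `Optional` and that nothing is returned:

```python
import torch

# Prepare 50 inference timesteps; device may be omitted (None) or given explicitly.
scheduler.set_timesteps(num_inference_steps=50, device=torch.device("cpu"))

# Roughly num_inference_steps entries; PNDM's Runge-Kutta warmup
# (when skip_prk_steps=False) can add a few extra.
print(len(scheduler.timesteps))
```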
@@ -243,7 +241,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
 
         Returns:
@@ -276,14 +274,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
-
         """
         if self.num_inference_steps is None:
             raise ValueError(
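The `return_dict` convention documented in these hunks belongs to the PRK/PLMS stepping methods, and the scheduler's public `step` entry point follows it as well. A sketch of both call styles, continuing from the scheduler above (tensor shapes are illustrative):

```python
import torch

sample = torch.randn(1, 4, 64, 64)        # current latents (shape illustrative)
model_output = torch.randn(1, 4, 64, 64)  # predicted noise from the model
t = scheduler.timesteps[0]

# return_dict=True (the default): a SchedulerOutput dataclass with a named field.
prev_sample = scheduler.step(model_output, t, sample).prev_sample

# return_dict=False: a plain tuple whose first element is the same tensor.
prev_sample = scheduler.step(model_output, t, sample, return_dict=False)[0]
```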
@@ -335,14 +332,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
-
         """
         if self.num_inference_steps is None:
             raise ValueError(
@@ -403,19 +399,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         """
         return sample
 
-    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
-        # See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778
-        # this function computes x_(t−δ) using the formula of (9)
-        # Note that x_t needs to be added to both sides of the equation
-
-        # Notation (<variable name> -> <name in paper>)
-        # alpha_prod_t -> α_t
-        # alpha_prod_t_prev -> α_(t−δ)
-        # beta_prod_t -> (1 - α_t)
-        # beta_prod_t_prev -> (1 - α_(t−δ))
-        # sample -> x_t
-        # model_output -> e_θ(x_t, t)
-        # prev_sample -> x_(t−δ)
+    def _get_prev_sample(
+        self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the [PNDM
+        paper](https://huggingface.co/papers/2202.09778).
+
+        Args:
+            sample (`torch.Tensor`):
+                The current sample x_t.
+            timestep (`int`):
+                The current timestep t.
+            prev_timestep (`int`):
+                The previous timestep (t-δ).
+            model_output (`torch.Tensor`):
+                The model output e_θ(x_t, t).
+
+        Returns:
+            `torch.Tensor`:
+                The previous sample x_(t-δ).
+        """
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
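For readers following the new docstring, here is a self-contained paraphrase of what formula (9) computes from these alpha products. The coefficient names are mine and the update mirrors my reading of the surrounding source, so treat it as a sketch rather than the verbatim continuation of the diff:

```python
import torch

# Toy alpha-bar schedule standing in for self.alphas_cumprod.
betas = torch.linspace(0.0001, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

t, prev_t = 500, 480
alpha_prod_t = alphas_cumprod[t]            # alpha_bar_t
alpha_prod_t_prev = alphas_cumprod[prev_t]  # alpha_bar_(t-delta)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

x_t = torch.randn(1, 4, 8, 8)  # current sample x_t
eps = torch.randn(1, 4, 8, 8)  # model output e_theta(x_t, t)

# Formula (9), rearranged so that x_(t-delta) is isolated on the left-hand side.
sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
denom = alpha_prod_t * beta_prod_t_prev**0.5 + (alpha_prod_t * beta_prod_t * alpha_prod_t_prev) ** 0.5
x_prev = sample_coeff * x_t - (alpha_prod_t_prev - alpha_prod_t) * eps / denom
```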
@@ -489,5 +493,5 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
        return self.config.num_train_timesteps
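The `noisy_samples` context line above is the standard forward-process draw q(x_t | x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, and `__len__` now advertises its `int` return. A small self-contained sketch of the same arithmetic:

```python
import torch

alphas_cumprod = torch.cumprod(1.0 - torch.linspace(0.0001, 0.02, 1000), dim=0)
x0 = torch.randn(2, 3)          # clean samples
noise = torch.randn_like(x0)    # Gaussian noise
t = torch.tensor([100, 700])    # one timestep per batch element

# q(x_t | x_0): noisy = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
sqrt_alpha_prod = alphas_cumprod[t].sqrt().unsqueeze(-1)
sqrt_one_minus = (1 - alphas_cumprod[t]).sqrt().unsqueeze(-1)
noisy = sqrt_alpha_prod * x0 + sqrt_one_minus * noise

# And len(scheduler) == scheduler.config.num_train_timesteps (1000 by default).
```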