Unverified Commit d769d8a1 authored by David El Malih, committed by GitHub

Improve docstrings and type hints in scheduling_heun_discrete.py (#12726)

refactor: improve type hints for `beta_schedule`, `prediction_type`, and `timestep_spacing` parameters, and add return type hints to several methods.
parent c25582d5
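
As a rough illustration of what the new `Literal` hints buy (the helper below is a hypothetical stand-in, not code from this commit), a static type checker such as mypy or pyright can now flag misspelled option strings before they reach the scheduler at runtime:

```python
from typing import Literal

# Hypothetical helper mirroring the option set this commit documents for
# `beta_schedule`; it only sketches the pattern and is not part of
# scheduling_heun_discrete.py.
def pick_beta_schedule(
    beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "exp"] = "linear",
) -> str:
    return beta_schedule

pick_beta_schedule("scaled_linear")  # accepted
pick_beta_schedule("scaled-linear")  # flagged by a type checker as an invalid literal
```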
```diff
@@ -107,12 +107,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             The starting `beta` value of inference.
         beta_end (`float`, defaults to 0.02):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
+        beta_schedule (`"linear"`, `"scaled_linear"`, `"squaredcos_cap_v2"`, or `"exp"`, defaults to `"linear"`):
             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear` or `scaled_linear`.
+            `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `exp`.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
             Video](https://huggingface.co/papers/2210.02303) paper).
@@ -128,7 +128,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_beta_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
             Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
-        timestep_spacing (`str`, defaults to `"linspace"`):
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
@@ -144,17 +144,17 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps: int = 1000,
         beta_start: float = 0.00085,  # sensible defaults
         beta_end: float = 0.012,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "exp"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
-        prediction_type: str = "epsilon",
+        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
         clip_sample: Optional[bool] = False,
         clip_sample_range: float = 1.0,
-        timestep_spacing: str = "linspace",
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
         steps_offset: int = 0,
-    ):
+    ) -> None:
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
         if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
@@ -241,7 +241,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return self._begin_index

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -263,7 +263,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         Args:
             sample (`torch.Tensor`):
                 The input sample.
-            timestep (`int`, *optional*):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.

         Returns:
@@ -283,19 +283,19 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         device: Union[str, torch.device] = None,
         num_train_timesteps: Optional[int] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*, defaults to `None`):
                 The number of diffusion steps used when generating samples with a pre-trained model.
-            device (`str` or `torch.device`, *optional*):
+            device (`str`, `torch.device`, *optional*, defaults to `None`):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
-            num_train_timesteps (`int`, *optional*):
+            num_train_timesteps (`int`, *optional*, defaults to `None`):
                 The number of diffusion steps used when training the model. If `None`, the default
                 `num_train_timesteps` attribute is used.
-            timesteps (`List[int]`, *optional*):
+            timesteps (`List[int]`, *optional*, defaults to `None`):
                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
                 generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
                 must be `None`, and `timestep_spacing` attribute will be ignored.
@@ -370,7 +370,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         """
         Convert sigma values to corresponding timestep values through interpolation.
@@ -407,7 +407,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
         """
         Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
         Models](https://huggingface.co/papers/2206.00364).
@@ -700,5 +700,5 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         return noisy_samples

-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
```
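
For context, here is a minimal usage sketch (not part of this diff) exercising the option values the updated docstrings enumerate; it assumes `diffusers` and `torch` are installed:

```python
from diffusers import HeunDiscreteScheduler

# Configure the scheduler with values drawn from the documented option sets.
scheduler = HeunDiscreteScheduler(
    beta_schedule="scaled_linear",
    prediction_type="v_prediction",
    timestep_spacing="trailing",
)

# Prepare the discrete timesteps before running a denoising loop.
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.timesteps[:5])
```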