Unverified Commit 859b8090 authored by David El Malih's avatar David El Malih Committed by GitHub
Browse files

Improve docstrings and type hints in scheduling_euler_ancestral_discrete.py (#12766)

refactor: add type hints to methods and update docstrings for parameters.
parent d769d8a1
...@@ -94,7 +94,7 @@ def betas_for_alpha_bar( ...@@ -94,7 +94,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas): def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
""" """
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
...@@ -144,16 +144,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -144,16 +144,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference. The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02): beta_end (`float`, defaults to 0.02):
The final `beta` value. The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`): beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`. `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*): trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
prediction_type (`str`, defaults to `epsilon`, *optional*): prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper). Video](https://huggingface.co/papers/2210.02303) paper).
timestep_spacing (`str`, defaults to `"linspace"`): timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0): steps_offset (`int`, defaults to 0):
...@@ -173,13 +173,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -173,13 +173,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000, num_train_timesteps: int = 1000,
beta_start: float = 0.0001, beta_start: float = 0.0001,
beta_end: float = 0.02, beta_end: float = 0.02,
beta_schedule: str = "linear", beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon", prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
timestep_spacing: str = "linspace", timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0, steps_offset: int = 0,
rescale_betas_zero_snr: bool = False, rescale_betas_zero_snr: bool = False,
): ) -> None:
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear": elif beta_schedule == "linear":
...@@ -219,7 +219,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -219,7 +219,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property @property
def init_noise_sigma(self): def init_noise_sigma(self) -> torch.Tensor:
# standard deviation of the initial noise distribution # standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]: if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max() return self.sigmas.max()
...@@ -227,21 +227,21 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -227,21 +227,21 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
return (self.sigmas.max() ** 2 + 1) ** 0.5 return (self.sigmas.max() ** 2 + 1) ** 0.5
@property @property
def step_index(self): def step_index(self) -> Optional[int]:
""" """
The index counter for current timestep. It will increase 1 after each scheduler step. The index counter for current timestep. It will increase 1 after each scheduler step.
""" """
return self._step_index return self._step_index
@property @property
def begin_index(self): def begin_index(self) -> Optional[int]:
""" """
The index for the first timestep. It should be set from pipeline with `set_begin_index` method. The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
""" """
return self._begin_index return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0): def set_begin_index(self, begin_index: int = 0) -> None:
""" """
Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
...@@ -259,7 +259,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -259,7 +259,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
sample (`torch.Tensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
...@@ -275,7 +275,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -275,7 +275,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = True self.is_scale_input_called = True
return sample return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...@@ -381,13 +381,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -381,13 +381,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.Tensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`float`): timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A random number generator. A random number generator.
return_dict (`bool`): return_dict (`bool`, defaults to `True`):
Whether or not to return a Whether or not to return a
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
...@@ -517,5 +517,5 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -517,5 +517,5 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma noisy_samples = original_samples + noise * sigma
return noisy_samples return noisy_samples
def __len__(self): def __len__(self) -> int:
return self.config.num_train_timesteps return self.config.num_train_timesteps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment