Unverified Commit be4afa0b authored by Mark Van Aken, committed by GitHub

#7535 Update FloatTensor type hints to Tensor (#7883)

* find & replace all FloatTensors to Tensor

* apply formatting

* Update torch.FloatTensor to torch.Tensor in the remaining files

* formatting

* Fix the rest of the places where FloatTensor is used as well as in documentation

* formatting

* Update new file from FloatTensor to Tensor
parent 04f4bd54
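
The commit is a pure type-hint change: every `torch.FloatTensor` annotation in the touched scheduler files becomes the more general `torch.Tensor` (the base tensor class, which also covers non-float32 and non-CPU tensors), while `torch.IntTensor` hints and all runtime behavior are left untouched. As a minimal sketch of the pattern, using the `scale_model_input` signature that appears in the diff below (the method body is unchanged):

# Before: the hint names the legacy float32 tensor subclass
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
    return sample

# After: the hint uses the general torch.Tensor type
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
    return sample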
@@ -33,12 +33,12 @@ class DDPMWuerstchenSchedulerOutput(BaseOutput):
     Output class for the scheduler's step function output.
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
     """
-    prev_sample: torch.FloatTensor
+    prev_sample: torch.Tensor
 def betas_for_alpha_bar(
@@ -125,17 +125,17 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
         ) ** 2 / self._init_alpha_cumprod.to(device)
         return alpha_cumprod.clamp(0.0001, 0.9999)
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
         Args:
-            sample (`torch.FloatTensor`): input sample
+            sample (`torch.Tensor`): input sample
             timestep (`int`, optional): current timestep
         Returns:
-            `torch.FloatTensor`: scaled input sample
+            `torch.Tensor`: scaled input sample
         """
         return sample
@@ -163,9 +163,9 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         generator=None,
         return_dict: bool = True,
     ) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]:
@@ -174,9 +174,9 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
         Args:
-            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            model_output (`torch.Tensor`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.
             generator: random number generator.
             return_dict (`bool`): option for returning tuple rather than DDPMWuerstchenSchedulerOutput class
@@ -209,10 +209,10 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         device = original_samples.device
         dtype = original_samples.dtype
         alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
......
@@ -276,7 +276,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -341,7 +341,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         return alpha_t, sigma_t
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
         # Hack to make sure that other schedulers which copy this function don't break
@@ -368,24 +368,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     def convert_model_output(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         Convert the model output to the corresponding type the DEIS algorithm needs.
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The converted model output.
         """
         timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -425,26 +425,26 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     def deis_first_order_update(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the first-order DEIS (equivalent to DDIM).
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
             prev_timestep (`int`):
                 The previous discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
         timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -483,22 +483,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     def multistep_deis_second_order_update(
         self,
-        model_output_list: List[torch.FloatTensor],
+        model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the second-order multistep DEIS.
         Args:
-            model_output_list (`List[torch.FloatTensor]`):
+            model_output_list (`List[torch.Tensor]`):
                 The direct outputs from learned diffusion model at current and latter timesteps.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
         timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
@@ -552,22 +552,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     def multistep_deis_third_order_update(
         self,
-        model_output_list: List[torch.FloatTensor],
+        model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the third-order multistep DEIS.
         Args:
-            model_output_list (`List[torch.FloatTensor]`):
+            model_output_list (`List[torch.Tensor]`):
                 The direct outputs from learned diffusion model at current and latter timesteps.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
@@ -673,9 +673,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -683,11 +683,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         the multistep DEIS.
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
@@ -736,17 +736,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -754,10 +754,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
@@ -78,11 +78,11 @@ def rescale_zero_terminal_snr(betas):
     Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
             the betas that the scheduler is being initialized with.
     Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
     """
     # Convert betas to alphas_bar_sqrt
     alphas = 1.0 - betas
@@ -408,7 +408,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -472,7 +472,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return alpha_t, sigma_t
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
         # Hack to make sure that other schedulers which copy this function don't break
@@ -497,7 +497,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
-    def _convert_to_lu(self, in_lambdas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Lu et al. (2022)."""
         lambda_min: float = in_lambdas[-1].item()
@@ -512,11 +512,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def convert_model_output(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
         designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
@@ -530,13 +530,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         </Tip>
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The converted model output.
         """
         timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -611,23 +611,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def dpm_solver_first_order_update(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor = None,
-        noise: Optional[torch.FloatTensor] = None,
+        sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the first-order DPMSolver (equivalent to DDIM).
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
         timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -680,23 +680,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def multistep_dpm_solver_second_order_update(
         self,
-        model_output_list: List[torch.FloatTensor],
+        model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.FloatTensor = None,
-        noise: Optional[torch.FloatTensor] = None,
+        sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the second-order multistep DPMSolver.
         Args:
-            model_output_list (`List[torch.FloatTensor]`):
+            model_output_list (`List[torch.Tensor]`):
                 The direct outputs from learned diffusion model at current and latter timesteps.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
         timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
@@ -803,22 +803,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def multistep_dpm_solver_third_order_update(
         self,
-        model_output_list: List[torch.FloatTensor],
+        model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the third-order multistep DPMSolver.
         Args:
-            model_output_list (`List[torch.FloatTensor]`):
+            model_output_list (`List[torch.Tensor]`):
                 The direct outputs from learned diffusion model at current and latter timesteps.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
@@ -919,11 +919,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         generator=None,
-        variance_noise: Optional[torch.FloatTensor] = None,
+        variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -931,15 +931,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         the multistep DPMSolver.
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            variance_noise (`torch.FloatTensor`):
+            variance_noise (`torch.Tensor`):
                 Alternative to generating noise with `generator` by directly providing the noise for the variance
                 itself. Useful for methods such as [`LEdits++`].
             return_dict (`bool`):
@@ -1006,27 +1006,27 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
@@ -295,7 +295,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -360,7 +360,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         return alpha_t, sigma_t
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
         # Hack to make sure that other schedulers which copy this function don't break
@@ -388,11 +388,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
     def convert_model_output(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
         designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
@@ -406,13 +406,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         </Tip>
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The converted model output.
         """
         timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -488,23 +488,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
     def dpm_solver_first_order_update(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor = None,
-        noise: Optional[torch.FloatTensor] = None,
+        sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the first-order DPMSolver (equivalent to DDIM).
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
         timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -558,23 +558,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
     def multistep_dpm_solver_second_order_update(
         self,
-        model_output_list: List[torch.FloatTensor],
+        model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.FloatTensor = None,
-        noise: Optional[torch.FloatTensor] = None,
+        sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the second-order multistep DPMSolver.
         Args:
-            model_output_list (`List[torch.FloatTensor]`):
+            model_output_list (`List[torch.Tensor]`):
                 The direct outputs from learned diffusion model at current and latter timesteps.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
         timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
@@ -682,22 +682,22 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
     def multistep_dpm_solver_third_order_update(
         self,
-        model_output_list: List[torch.FloatTensor],
+        model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the third-order multistep DPMSolver.
         Args:
-            model_output_list (`List[torch.FloatTensor]`):
+            model_output_list (`List[torch.Tensor]`):
                 The direct outputs from learned diffusion model at current and latter timesteps.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by diffusion process.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
@@ -786,11 +786,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         generator=None,
-        variance_noise: Optional[torch.FloatTensor] = None,
+        variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -798,15 +798,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         the multistep DPMSolver.
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            variance_noise (`torch.FloatTensor`):
+            variance_noise (`torch.Tensor`):
                 Alternative to generating noise with `generator` by directly providing the noise for the variance
                 itself. Useful for methods such as [`CycleDiffusion`].
             return_dict (`bool`):
@@ -867,27 +867,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
@@ -257,21 +257,21 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
     def scale_model_input(
         self,
-        sample: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-    ) -> torch.FloatTensor:
+        sample: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         if self.step_index is None:
@@ -395,7 +395,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         return t
     # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
         sigma_min: float = in_sigmas[-1].item()
@@ -414,9 +414,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
-        timestep: Union[float, torch.FloatTensor],
-        sample: Union[torch.FloatTensor, np.ndarray],
+        model_output: Union[torch.Tensor, np.ndarray],
+        timestep: Union[float, torch.Tensor],
+        sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
         s_noise: float = 1.0,
     ) -> Union[SchedulerOutput, Tuple]:
@@ -425,11 +425,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
         Args:
-            model_output (`torch.FloatTensor` or `np.ndarray`):
+            model_output (`torch.Tensor` or `np.ndarray`):
                 The direct output from learned diffusion model.
-            timestep (`float` or `torch.FloatTensor`):
+            timestep (`float` or `torch.Tensor`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor` or `np.ndarray`):
+            sample (`torch.Tensor` or `np.ndarray`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -450,10 +450,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
             self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)
         # Define functions to compute sigma and t from each other
-        def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor:
+        def sigma_fn(_t: torch.Tensor) -> torch.Tensor:
             return _t.neg().exp()
-        def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor:
+        def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
             return _sigma.log().neg()
         if self.state_in_first_order:
@@ -526,10 +526,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
...@@ -361,7 +361,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -361,7 +361,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
""" """
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
...@@ -426,7 +426,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -426,7 +426,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break # Hack to make sure that other schedulers which copy this function don't break
...@@ -453,11 +453,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -453,11 +453,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def convert_model_output( def convert_model_output(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
*args, *args,
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
...@@ -471,13 +471,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -471,13 +471,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
</Tip> </Tip>
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The converted model output. The converted model output.
""" """
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
...@@ -542,26 +542,26 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -542,26 +542,26 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
*args, *args,
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the first-order DPMSolver (equivalent to DDIM). One step for the first-order DPMSolver (equivalent to DDIM).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
prev_timestep (`int`): prev_timestep (`int`):
The previous discrete timestep in the diffusion chain. The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
...@@ -598,27 +598,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -598,27 +598,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_second_order_update( def singlestep_dpm_solver_second_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.Tensor],
*args, *args,
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
time `timestep_list[-2]`. time `timestep_list[-2]`.
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`): timestep (`int`):
The current and latter discrete timestep in the diffusion chain. The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`): prev_timestep (`int`):
The previous discrete timestep in the diffusion chain. The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
...@@ -692,27 +692,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -692,27 +692,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_third_order_update( def singlestep_dpm_solver_third_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.Tensor],
*args, *args,
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
time `timestep_list[-3]`. time `timestep_list[-3]`.
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`): timestep (`int`):
The current and latter discrete timestep in the diffusion chain. The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`): prev_timestep (`int`):
The previous discrete timestep in the diffusion chain. The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by diffusion process. A current instance of a sample created by diffusion process.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
...@@ -796,29 +796,29 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -796,29 +796,29 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_update( def singlestep_dpm_solver_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.Tensor],
*args, *args,
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
order: int = None, order: int = None,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the singlestep DPMSolver. One step for the singlestep DPMSolver.
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`): timestep (`int`):
The current and latter discrete timestep in the diffusion chain. The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`): prev_timestep (`int`):
The previous discrete timestep in the diffusion chain. The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by diffusion process. A current instance of a sample created by diffusion process.
order (`int`): order (`int`):
The solver order at this step. The solver order at this step.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
...@@ -891,9 +891,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -891,9 +891,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
...@@ -901,11 +901,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -901,11 +901,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
the singlestep DPMSolver. the singlestep DPMSolver.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
...@@ -950,17 +950,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -950,17 +950,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample) return SchedulerOutput(prev_sample=prev_sample)
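For orientation, a minimal sketch of the denoising loop that typically drives this `step` API; the scheduler construction, latent shape, step count, and the random stand-in for a UNet prediction are placeholders, not code from this commit:

import torch
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler()
scheduler.set_timesteps(25)
latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma  # start from pure noise

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(latents, t)  # arguments are now typed as torch.Tensor
    noise_pred = torch.randn_like(model_input)              # stand-in for a real UNet prediction
    latents = scheduler.step(noise_pred, t, latents).prev_sample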
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
return sample return sample
...@@ -968,10 +968,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -968,10 +968,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.IntTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
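The `add_noise` bodies are mostly collapsed in this view (the same copied implementation recurs in several schedulers below). As a hedged sketch, these sigma-based schedulers broadcast the per-timestep sigma and form x_t = x_0 + sigma_t * eps, which matches the `noisy_samples = original_samples + noise * sigma` line visible later in the EulerDiscreteScheduler hunk:

import torch

def add_noise_sketch(original_samples: torch.Tensor, noise: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # Broadcast per-sample sigmas over channel/spatial dimensions, then add scaled noise.
    while sigma.ndim < original_samples.ndim:
        sigma = sigma.unsqueeze(-1)
    return original_samples + noise * sigma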
...@@ -206,21 +206,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -206,21 +206,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return denoised return denoised
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input( def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
if self.step_index is None: if self.step_index is None:
...@@ -276,7 +274,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -276,7 +274,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor: def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min = sigma_min or self.config.sigma_min sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max sigma_max = sigma_max or self.config.sigma_max
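The Karras schedule cited in the docstring interpolates between the sigma bounds in 1/rho space; a hedged sketch (rho = 7 is the paper's default and an assumption here):

import torch

def karras_sigmas_sketch(ramp: torch.Tensor, sigma_min: float, sigma_max: float, rho: float = 7.0) -> torch.Tensor:
    # ramp is expected to run from 0 to 1; sigmas run from sigma_max down to sigma_min.
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho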
...@@ -289,7 +287,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -289,7 +287,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor: def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion. """Implementation closely follows k-diffusion.
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
...@@ -300,7 +298,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -300,7 +298,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas return sigmas
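The exponential variant referenced just above amounts, in a hedged sketch, to spacing sigmas evenly in log space (descending order assumed, to match the Karras convention):

import math
import torch

def exponential_sigmas_sketch(ramp: torch.Tensor, sigma_min: float, sigma_max: float) -> torch.Tensor:
    # Evenly spaced log-sigmas, returned from sigma_max down to sigma_min.
    return torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)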
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
""" """
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
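A hedged sketch of the dynamic thresholding procedure described (and truncated) above, following Imagen (https://arxiv.org/abs/2205.11487); the default ratio and cap are assumptions:

import torch

def dynamic_threshold_sketch(sample: torch.Tensor, ratio: float = 0.995, max_value: float = 1.0) -> torch.Tensor:
    batch_size = sample.shape[0]
    flat = sample.reshape(batch_size, -1).abs().float()
    s = torch.quantile(flat, ratio, dim=1)
    s = torch.clamp(s, min=1.0, max=max_value)        # only threshold when the percentile exceeds 1
    s = s.view(batch_size, *([1] * (sample.ndim - 1)))
    return torch.clamp(sample, -s, s) / s             # clip to [-s, s] and rescale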
...@@ -365,9 +363,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -365,9 +363,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def convert_model_output( def convert_model_output(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
...@@ -381,13 +379,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -381,13 +379,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
</Tip> </Tip>
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The converted model output. The converted model output.
""" """
sigma = self.sigmas[self.step_index] sigma = self.sigmas[self.step_index]
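In the EDM parameterization, the conversion described above is essentially output preconditioning; a hedged sketch for an epsilon-style prediction, with coefficients following Karras et al. (2022) and an assumed `sigma_data` default:

import torch

def precondition_outputs_sketch(
    sample: torch.Tensor, model_output: torch.Tensor, sigma: torch.Tensor, sigma_data: float = 0.5
) -> torch.Tensor:
    # Blend the raw network output back into an x0 estimate via c_skip and c_out.
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    return c_skip * sample + c_out * model_output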
...@@ -400,21 +398,21 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -400,21 +398,21 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
noise: Optional[torch.FloatTensor] = None, noise: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the first-order DPMSolver (equivalent to DDIM). One step for the first-order DPMSolver (equivalent to DDIM).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
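A hedged sketch of the first-order (DDIM-equivalent) DPM-Solver++ update the docstring describes, written in terms of the converted x0 prediction and the half-log-SNR variables lambda = log(alpha / sigma), h = lambda_t - lambda_s:

import torch

def first_order_update_sketch(x0_pred, sample, sigma_t, sigma_s, alpha_t, lambda_t, lambda_s):
    # x_t = (sigma_t / sigma_s) * x_s - alpha_t * (exp(-h) - 1) * x0_pred
    h = lambda_t - lambda_s
    return (sigma_t / sigma_s) * sample - alpha_t * (torch.exp(-h) - 1.0) * x0_pred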
...@@ -438,21 +436,21 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -438,21 +436,21 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_second_order_update( def multistep_dpm_solver_second_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.Tensor],
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
noise: Optional[torch.FloatTensor] = None, noise: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the second-order multistep DPMSolver. One step for the second-order multistep DPMSolver.
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
sigma_t, sigma_s0, sigma_s1 = ( sigma_t, sigma_s0, sigma_s1 = (
...@@ -509,20 +507,20 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -509,20 +507,20 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_third_order_update( def multistep_dpm_solver_third_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.Tensor],
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the third-order multistep DPMSolver. One step for the third-order multistep DPMSolver.
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
...@@ -596,9 +594,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -596,9 +594,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
generator=None, generator=None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
...@@ -607,11 +605,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -607,11 +605,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the multistep DPMSolver. the multistep DPMSolver.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A random number generator. A random number generator.
...@@ -675,10 +673,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -675,10 +673,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.FloatTensor, timesteps: torch.Tensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
...@@ -35,16 +35,16 @@ class EDMEulerSchedulerOutput(BaseOutput): ...@@ -35,16 +35,16 @@ class EDMEulerSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output. Output class for the scheduler's `step` function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep. The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance. `pred_original_sample` can be used to preview progress or for guidance.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.Tensor] = None
class EDMEulerScheduler(SchedulerMixin, ConfigMixin): class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
...@@ -174,21 +174,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): ...@@ -174,21 +174,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
return denoised return denoised
def scale_model_input( def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
if self.step_index is None: if self.step_index is None:
...@@ -227,7 +225,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): ...@@ -227,7 +225,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor: def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min = sigma_min or self.config.sigma_min sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max sigma_max = sigma_max or self.config.sigma_max
...@@ -239,7 +237,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): ...@@ -239,7 +237,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
return sigmas return sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor: def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion. """Implementation closely follows k-diffusion.
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
...@@ -275,9 +273,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): ...@@ -275,9 +273,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: Union[float, torch.FloatTensor], timestep: Union[float, torch.Tensor],
sample: torch.FloatTensor, sample: torch.Tensor,
s_churn: float = 0.0, s_churn: float = 0.0,
s_tmin: float = 0.0, s_tmin: float = 0.0,
s_tmax: float = float("inf"), s_tmax: float = float("inf"),
...@@ -290,11 +288,11 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): ...@@ -290,11 +288,11 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`float`): timestep (`float`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
s_churn (`float`): s_churn (`float`):
s_tmin (`float`): s_tmin (`float`):
...@@ -375,10 +373,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): ...@@ -375,10 +373,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.FloatTensor, timesteps: torch.Tensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
...@@ -35,16 +35,16 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput): ...@@ -35,16 +35,16 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output. Output class for the scheduler's `step` function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep. The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance. `pred_original_sample` can be used to preview progress or for guidance.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
...@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas): ...@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
Args: Args:
betas (`torch.FloatTensor`): betas (`torch.Tensor`):
the betas that the scheduler is being initialized with. the betas that the scheduler is being initialized with.
Returns: Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR `torch.Tensor`: rescaled betas with zero terminal SNR
""" """
# Convert betas to alphas_bar_sqrt # Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas alphas = 1.0 - betas
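The rest of this function is collapsed here; as a hedged sketch, Algorithm 1 of the cited paper shifts and rescales the square-root cumulative alphas so the last step has exactly zero SNR, then converts back to betas:

import torch

def rescale_zero_terminal_snr_sketch(betas: torch.Tensor) -> torch.Tensor:
    alphas_bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()
    first, last = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
    # Shift so the terminal value is 0, then rescale so the first value is unchanged.
    alphas_bar_sqrt = (alphas_bar_sqrt - last) * first / (first - last)
    alphas_bar = alphas_bar_sqrt**2
    alphas = torch.cat([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas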
...@@ -250,21 +250,19 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -250,21 +250,19 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
self._begin_index = begin_index self._begin_index = begin_index
def scale_model_input( def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
...@@ -346,9 +344,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -346,9 +344,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: Union[float, torch.FloatTensor], timestep: Union[float, torch.Tensor],
sample: torch.FloatTensor, sample: torch.Tensor,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
...@@ -357,11 +355,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -357,11 +355,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`float`): timestep (`float`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A random number generator. A random number generator.
...@@ -450,10 +448,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -450,10 +448,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.FloatTensor, timesteps: torch.Tensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
...@@ -35,16 +35,16 @@ class EulerDiscreteSchedulerOutput(BaseOutput): ...@@ -35,16 +35,16 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output. Output class for the scheduler's `step` function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep. The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance. `pred_original_sample` can be used to preview progress or for guidance.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
...@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas): ...@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
Args: Args:
betas (`torch.FloatTensor`): betas (`torch.Tensor`):
the betas that the scheduler is being initialized with. the betas that the scheduler is being initialized with.
Returns: Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR `torch.Tensor`: rescaled betas with zero terminal SNR
""" """
# Convert betas to alphas_bar_sqrt # Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas alphas = 1.0 - betas
...@@ -274,21 +274,19 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -274,21 +274,19 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
self._begin_index = begin_index self._begin_index = begin_index
def scale_model_input( def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
if self.step_index is None: if self.step_index is None:
...@@ -445,7 +443,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -445,7 +443,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return t return t
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break # Hack to make sure that other schedulers which copy this function don't break
...@@ -494,9 +492,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -494,9 +492,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: Union[float, torch.FloatTensor], timestep: Union[float, torch.Tensor],
sample: torch.FloatTensor, sample: torch.Tensor,
s_churn: float = 0.0, s_churn: float = 0.0,
s_tmin: float = 0.0, s_tmin: float = 0.0,
s_tmax: float = float("inf"), s_tmax: float = float("inf"),
...@@ -509,11 +507,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -509,11 +507,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`float`): timestep (`float`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
s_churn (`float`): s_churn (`float`):
s_tmin (`float`): s_tmin (`float`):
...@@ -606,10 +604,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -606,10 +604,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.FloatTensor, timesteps: torch.Tensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
...@@ -637,9 +635,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -637,9 +635,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma noisy_samples = original_samples + noise * sigma
return noisy_samples return noisy_samples
def get_velocity( def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor
) -> torch.FloatTensor:
if ( if (
isinstance(timesteps, int) isinstance(timesteps, int)
or isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.IntTensor)
......
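The `get_velocity` body above is collapsed; as a hedged sketch, it computes the standard v-prediction target from the cumulative alphas gathered (and broadcast) at the given timesteps:

import torch

def get_velocity_sketch(sample: torch.Tensor, noise: torch.Tensor, alphas_cumprod_t: torch.Tensor) -> torch.Tensor:
    # alphas_cumprod_t: alpha-bar values already gathered and broadcast to sample's shape.
    return alphas_cumprod_t**0.5 * noise - (1.0 - alphas_cumprod_t) ** 0.5 * sample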
...@@ -198,21 +198,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -198,21 +198,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
def scale_model_input( def scale_model_input(
self, self,
sample: torch.FloatTensor, sample: torch.Tensor,
timestep: Union[float, torch.FloatTensor], timestep: Union[float, torch.Tensor],
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
if self.step_index is None: if self.step_index is None:
...@@ -329,7 +329,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -329,7 +329,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
return t return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break # Hack to make sure that other schedulers which copy this function don't break
...@@ -369,9 +369,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -369,9 +369,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.FloatTensor], timestep: Union[float, torch.Tensor],
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
...@@ -379,11 +379,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -379,11 +379,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`float`): timestep (`float`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
...@@ -469,10 +469,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -469,10 +469,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.FloatTensor, timesteps: torch.Tensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
...@@ -137,9 +137,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -137,9 +137,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
...@@ -147,11 +147,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -147,11 +147,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
the linear multistep method. It performs one forward pass multiple times to approximate the solution. the linear multistep method. It performs one forward pass multiple times to approximate the solution.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
...@@ -193,17 +193,17 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -193,17 +193,17 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample) return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
return sample return sample
......
...@@ -175,21 +175,21 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -175,21 +175,21 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def scale_model_input( def scale_model_input(
self, self,
sample: torch.FloatTensor, sample: torch.Tensor,
timestep: Union[float, torch.FloatTensor], timestep: Union[float, torch.Tensor],
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
if self.step_index is None: if self.step_index is None:
...@@ -321,7 +321,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -321,7 +321,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
return t return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break # Hack to make sure that other schedulers which copy this function don't break
...@@ -376,9 +376,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -376,9 +376,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.FloatTensor], timestep: Union[float, torch.Tensor],
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.Tensor, np.ndarray],
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
...@@ -387,11 +387,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -387,11 +387,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`float`): timestep (`float`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A random number generator. A random number generator.
...@@ -477,10 +477,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -477,10 +477,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.FloatTensor, timesteps: torch.Tensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
...@@ -175,21 +175,21 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -175,21 +175,21 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
def scale_model_input( def scale_model_input(
self, self,
sample: torch.FloatTensor, sample: torch.Tensor,
timestep: Union[float, torch.FloatTensor], timestep: Union[float, torch.Tensor],
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
if self.step_index is None: if self.step_index is None:
...@@ -334,7 +334,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -334,7 +334,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
return t return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break # Hack to make sure that other schedulers which copy this function don't break
...@@ -361,9 +361,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -361,9 +361,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.FloatTensor], timestep: Union[float, torch.Tensor],
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
...@@ -371,11 +371,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -371,11 +371,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`float`): timestep (`float`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
...@@ -452,10 +452,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -452,10 +452,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.FloatTensor, timesteps: torch.Tensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
...@@ -176,10 +176,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -176,10 +176,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
Args: Args:
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model.
sigma_hat (`float`): TODO sigma_hat (`float`): TODO
sigma_prev (`float`): TODO sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO sample_hat (`torch.Tensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
Returns: Returns:
...@@ -213,12 +213,12 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -213,12 +213,12 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
Args: Args:
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model.
sigma_hat (`float`): TODO sigma_hat (`float`): TODO
sigma_prev (`float`): TODO sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO sample_hat (`torch.Tensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO sample_prev (`torch.Tensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO derivative (`torch.Tensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
Returns: Returns:
......
...@@ -37,16 +37,16 @@ class LCMSchedulerOutput(BaseOutput): ...@@ -37,16 +37,16 @@ class LCMSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output. Output class for the scheduler's `step` function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep. The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance. `pred_original_sample` can be used to preview progress or for guidance.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
denoised: Optional[torch.FloatTensor] = None denoised: Optional[torch.Tensor] = None
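A hedged usage note: unlike the other output classes in this diff, the x0 preview here is exposed as `denoised` rather than `pred_original_sample`. Assuming `scheduler` is an `LCMScheduler` inside a denoising loop (placeholder names, not code from this commit):

out = scheduler.step(noise_pred, t, latents, return_dict=True)
latents = out.prev_sample
x0_preview = out.denoised  # the denoised (x0) estimate lives under this name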
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
...@@ -95,17 +95,17 @@ def betas_for_alpha_bar( ...@@ -95,17 +95,17 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor: def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
""" """
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args: Args:
betas (`torch.FloatTensor`): betas (`torch.Tensor`):
the betas that the scheduler is being initialized with. the betas that the scheduler is being initialized with.
Returns: Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR `torch.Tensor`: rescaled betas with zero terminal SNR
""" """
# Convert betas to alphas_bar_sqrt # Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas alphas = 1.0 - betas
...@@ -296,24 +296,24 @@ class LCMScheduler(SchedulerMixin, ConfigMixin): ...@@ -296,24 +296,24 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
""" """
self._begin_index = begin_index self._begin_index = begin_index
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
""" """
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
...@@ -497,9 +497,9 @@ class LCMScheduler(SchedulerMixin, ConfigMixin): ...@@ -497,9 +497,9 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[LCMSchedulerOutput, Tuple]: ) -> Union[LCMSchedulerOutput, Tuple]:
...@@ -508,11 +508,11 @@ class LCMScheduler(SchedulerMixin, ConfigMixin): ...@@ -508,11 +508,11 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`float`): timestep (`float`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A random number generator. A random number generator.
...@@ -594,10 +594,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin): ...@@ -594,10 +594,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.IntTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls # for the subsequent add_noise calls
...@@ -619,9 +619,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin): ...@@ -619,9 +619,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
return noisy_samples return noisy_samples
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity( def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample # Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
......
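The `add_noise` and `get_velocity` signatures updated above follow the standard DDPM forward process and the v-prediction target. The following is a standalone sketch under the assumption that `alphas_cumprod` is a 1-D tensor indexed by integer timesteps; the library's device/dtype bookkeeping and broadcasting helpers are omitted.

```python
import torch


def add_noise(original_samples: torch.Tensor, noise: torch.Tensor,
              timesteps: torch.IntTensor, alphas_cumprod: torch.Tensor) -> torch.Tensor:
    # q(x_t | x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    a = alphas_cumprod[timesteps].sqrt()
    s = (1.0 - alphas_cumprod[timesteps]).sqrt()
    while a.dim() < original_samples.dim():  # broadcast over channel/spatial dims
        a, s = a.unsqueeze(-1), s.unsqueeze(-1)
    return a * original_samples + s * noise


def get_velocity(sample: torch.Tensor, noise: torch.Tensor,
                 timesteps: torch.IntTensor, alphas_cumprod: torch.Tensor) -> torch.Tensor:
    # v-prediction target: v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0
    a = alphas_cumprod[timesteps].sqrt()
    s = (1.0 - alphas_cumprod[timesteps]).sqrt()
    while a.dim() < sample.dim():
        a, s = a.unsqueeze(-1), s.unsqueeze(-1)
    return a * noise - s * sample
```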
...@@ -32,16 +32,16 @@ class LMSDiscreteSchedulerOutput(BaseOutput): ...@@ -32,16 +32,16 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output. Output class for the scheduler's `step` function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep. The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance. `pred_original_sample` can be used to preview progress or for guidance.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
...@@ -202,21 +202,19 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -202,21 +202,19 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
self._begin_index = begin_index self._begin_index = begin_index
def scale_model_input( def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`float` or `torch.FloatTensor`): timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
...@@ -351,7 +349,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -351,7 +349,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return t return t
# copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item() sigma_min: float = in_sigmas[-1].item()
...@@ -366,9 +364,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -366,9 +364,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: Union[float, torch.FloatTensor], timestep: Union[float, torch.Tensor],
sample: torch.FloatTensor, sample: torch.Tensor,
order: int = 4, order: int = 4,
return_dict: bool = True, return_dict: bool = True,
) -> Union[LMSDiscreteSchedulerOutput, Tuple]: ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
...@@ -377,11 +375,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -377,11 +375,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`float` or `torch.FloatTensor`): timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
order (`int`, defaults to 4): order (`int`, defaults to 4):
The order of the linear multistep method. The order of the linear multistep method.
...@@ -444,10 +442,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -444,10 +442,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.FloatTensor, timesteps: torch.Tensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
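The `order` argument in `LMSDiscreteScheduler.step` above controls how many stored derivatives the linear multistep update blends. One way to obtain the blending weights, shown here as a sketch of the general Adams-Bashforth construction rather than a verbatim copy of the scheduler, is to integrate each derivative's Lagrange basis polynomial over the current sigma interval; the `sigmas` argument is a hypothetical stand-in for the scheduler's precomputed sigma table.

```python
from typing import Sequence

from scipy import integrate


def lms_coefficient(sigmas: Sequence[float], t: int, order: int, current_order: int) -> float:
    """Weight applied to the derivative stored `current_order` steps back.

    The weight is the integral, from sigma_t to sigma_{t+1}, of the Lagrange
    basis polynomial that equals 1 at sigma_{t - current_order} and 0 at the
    other retained sigmas (the classic Adams-Bashforth construction).
    """
    def lms_derivative(tau: float) -> float:
        prod = 1.0
        for k in range(order):
            if k == current_order:
                continue
            prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
        return prod

    return integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]
```

The previous sample is then `sample + sum(coeff_i * derivative_i)` over the last `order` stored derivatives.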
...@@ -225,9 +225,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -225,9 +225,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
...@@ -236,11 +236,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -236,11 +236,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
or [`~PNDMScheduler.step_plms`] depending on the internal variable `counter`. or [`~PNDMScheduler.step_plms`] depending on the internal variable `counter`.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
...@@ -258,9 +258,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -258,9 +258,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def step_prk( def step_prk(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
...@@ -269,11 +269,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -269,11 +269,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
equation. equation.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
...@@ -318,9 +318,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -318,9 +318,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def step_plms( def step_plms(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
...@@ -328,11 +328,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -328,11 +328,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
the linear multistep method. It performs one forward pass multiple times to approximate the solution. the linear multistep method. It performs one forward pass multiple times to approximate the solution.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
...@@ -387,17 +387,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -387,17 +387,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample) return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
return sample return sample
...@@ -448,10 +448,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -448,10 +448,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.IntTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls # for the subsequent add_noise calls
......
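`PNDMScheduler.step` above dispatches to `step_prk` during the Runge-Kutta warm-up and to `step_plms` afterwards, based on an internal counter. The distinctive part of the PLMS branch is the linear multistep blend of stored noise predictions. The class below is a purely illustrative sketch of that blend; the warm-up phase and the conversion from the blended epsilon to `prev_sample` are omitted.

```python
from collections import deque

import torch


class PlmsBlend:
    """Keep recent noise predictions and return the multistep estimate that a
    PLMS-style update would feed into the x_{t-1} formula (sketch)."""

    def __init__(self) -> None:
        self.ets = deque(maxlen=4)  # most recent model outputs, oldest dropped first

    def __call__(self, model_output: torch.Tensor) -> torch.Tensor:
        self.ets.append(model_output)
        e = list(self.ets)
        if len(e) == 1:  # not enough history yet: fall back to lower order
            return e[-1]
        if len(e) == 2:  # 2nd-order Adams-Bashforth
            return (3 * e[-1] - e[-2]) / 2
        if len(e) == 3:  # 3rd-order
            return (23 * e[-1] - 16 * e[-2] + 5 * e[-3]) / 12
        # 4th-order: the steady-state PLMS case
        return (55 * e[-1] - 59 * e[-2] + 37 * e[-3] - 9 * e[-4]) / 24
```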
...@@ -31,16 +31,16 @@ class RePaintSchedulerOutput(BaseOutput): ...@@ -31,16 +31,16 @@ class RePaintSchedulerOutput(BaseOutput):
Output class for the scheduler's step function output. Output class for the scheduler's step function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from The predicted denoised sample (x_{0}) based on the model output from
the current timestep. `pred_original_sample` can be used to preview progress or for guidance. the current timestep. `pred_original_sample` can be used to preview progress or for guidance.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
pred_original_sample: torch.FloatTensor pred_original_sample: torch.Tensor
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
...@@ -160,19 +160,19 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): ...@@ -160,19 +160,19 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
self.eta = eta self.eta = eta
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
return sample return sample
...@@ -245,11 +245,11 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): ...@@ -245,11 +245,11 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
original_image: torch.FloatTensor, original_image: torch.Tensor,
mask: torch.FloatTensor, mask: torch.Tensor,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[RePaintSchedulerOutput, Tuple]: ) -> Union[RePaintSchedulerOutput, Tuple]:
...@@ -258,15 +258,15 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): ...@@ -258,15 +258,15 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
original_image (`torch.FloatTensor`): original_image (`torch.Tensor`):
The original image to inpaint on. The original image to inpaint on.
mask (`torch.FloatTensor`): mask (`torch.Tensor`):
The mask where a value of 0.0 indicates which part of the original image to inpaint. The mask where a value of 0.0 indicates which part of the original image to inpaint.
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A random number generator. A random number generator.
...@@ -351,10 +351,10 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): ...@@ -351,10 +351,10 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.IntTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.") raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.")
def __len__(self): def __len__(self):
......
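`RePaintScheduler.step` above takes `original_image` and `mask` because, at every timestep, the known region is re-noised from the original image and merged with the model's proposal for the region being inpainted (RePaint, Lugmayr et al., 2022). Below is a hedged sketch of that merge, with `prev_unknown` standing in for the scheduler's own x_{t-1} estimate and `alpha_cumprod_prev` for the cumulative alpha at the previous timestep; neither name is taken from the library.

```python
from typing import Optional

import torch


def repaint_merge(prev_unknown: torch.Tensor, original_image: torch.Tensor,
                  mask: torch.Tensor, alpha_cumprod_prev: float,
                  generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Blend the re-noised known region with the model's inpainted proposal."""
    # Known region: sample q(x_{t-1} | x_0) directly from the original image.
    noise = torch.randn(original_image.shape, generator=generator,
                        dtype=original_image.dtype, device=original_image.device)
    prev_known = (alpha_cumprod_prev ** 0.5) * original_image + \
                 ((1.0 - alpha_cumprod_prev) ** 0.5) * noise

    # Per the docstring above, mask == 0.0 marks the area to inpaint, so `mask`
    # keeps the known pixels and (1 - mask) takes the model's proposal.
    return mask * prev_known + (1.0 - mask) * prev_unknown
```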
...@@ -305,7 +305,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -305,7 +305,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
""" """
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
...@@ -370,7 +370,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -370,7 +370,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break # Hack to make sure that other schedulers which copy this function don't break
...@@ -397,11 +397,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -397,11 +397,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
def convert_model_output( def convert_model_output(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
*args, *args,
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs. Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs.
Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is
...@@ -415,13 +415,13 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -415,13 +415,13 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
</Tip> </Tip>
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The converted model output. The converted model output.
""" """
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
...@@ -686,29 +686,29 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -686,29 +686,29 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
def stochastic_adams_bashforth_update( def stochastic_adams_bashforth_update(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
*args, *args,
sample: torch.FloatTensor, sample: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
order: int, order: int,
tau: torch.FloatTensor, tau: torch.Tensor,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the SA-Predictor. One step for the SA-Predictor.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from the learned diffusion model at the current timestep. The direct output from the learned diffusion model at the current timestep.
prev_timestep (`int`): prev_timestep (`int`):
The previous discrete timestep in the diffusion chain. The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
order (`int`): order (`int`):
The order of SA-Predictor at this timestep. The order of SA-Predictor at this timestep.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
...@@ -813,32 +813,32 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -813,32 +813,32 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
def stochastic_adams_moulton_update( def stochastic_adams_moulton_update(
self, self,
this_model_output: torch.FloatTensor, this_model_output: torch.Tensor,
*args, *args,
last_sample: torch.FloatTensor, last_sample: torch.Tensor,
last_noise: torch.FloatTensor, last_noise: torch.Tensor,
this_sample: torch.FloatTensor, this_sample: torch.Tensor,
order: int, order: int,
tau: torch.FloatTensor, tau: torch.Tensor,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the SA-Corrector. One step for the SA-Corrector.
Args: Args:
this_model_output (`torch.FloatTensor`): this_model_output (`torch.Tensor`):
The model outputs at `x_t`. The model outputs at `x_t`.
this_timestep (`int`): this_timestep (`int`):
The current timestep `t`. The current timestep `t`.
last_sample (`torch.FloatTensor`): last_sample (`torch.Tensor`):
The generated sample before the last predictor `x_{t-1}`. The generated sample before the last predictor `x_{t-1}`.
this_sample (`torch.FloatTensor`): this_sample (`torch.Tensor`):
The generated sample after the last predictor `x_{t}`. The generated sample after the last predictor `x_{t}`.
order (`int`): order (`int`):
The order of SA-Corrector at this step. The order of SA-Corrector at this step.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The corrected sample tensor at the current timestep. The corrected sample tensor at the current timestep.
""" """
...@@ -979,9 +979,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -979,9 +979,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
generator=None, generator=None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
...@@ -990,11 +990,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -990,11 +990,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
the SA-Solver. the SA-Solver.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A random number generator. A random number generator.
...@@ -1079,17 +1079,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -1079,17 +1079,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample) return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
return sample return sample
...@@ -1097,10 +1097,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -1097,10 +1097,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.IntTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls # for the subsequent add_noise calls
......
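Both the LCM and SA-Solver hunks above carry the `_threshold_sample` docstring describing Imagen-style dynamic thresholding of the predicted x_0. A compact sketch of that operation follows; `dynamic_thresholding_ratio` is a hypothetical stand-in for the scheduler's config value, and dtype restoration is left out.

```python
import torch


def threshold_sample(sample: torch.Tensor, dynamic_thresholding_ratio: float = 0.995) -> torch.Tensor:
    """Dynamic thresholding of a predicted x_0 (sketch of the quoted algorithm)."""
    batch_size = sample.shape[0]
    flat = sample.abs().reshape(batch_size, -1).float()

    # s = per-sample percentile of |x_0|; only rescale when s > 1, as in the
    # docstring quoted in the hunks above.
    s = torch.quantile(flat, dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(s, min=1.0)

    # Threshold to [-s, s] and divide by s, broadcasting over the sample dims.
    s = s.reshape(batch_size, *([1] * (sample.dim() - 1)))
    return torch.clamp(sample, -s, s) / s
```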