Unverified Commit be4afa0b authored by Mark Van Aken's avatar Mark Van Aken Committed by GitHub
Browse files

#7535 Update FloatTensor type hints to Tensor (#7883)

* find & replace all FloatTensors to Tensor

* apply formatting

* Update torch.FloatTensor to torch.Tensor in the remaining files

* formatting

* Fix the rest of the places where FloatTensor is used as well as in documentation

* formatting

* Update new file from FloatTensor to Tensor
parent 04f4bd54
......@@ -33,12 +33,12 @@ class DDPMWuerstchenSchedulerOutput(BaseOutput):
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
prev_sample: torch.Tensor
def betas_for_alpha_bar(
......@@ -125,17 +125,17 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
) ** 2 / self._init_alpha_cumprod.to(device)
return alpha_cumprod.clamp(0.0001, 0.9999)
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
sample (`torch.Tensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
`torch.Tensor`: scaled input sample
"""
return sample
......@@ -163,9 +163,9 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]:
......@@ -174,9 +174,9 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
model_output (`torch.Tensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMWuerstchenSchedulerOutput class
......@@ -209,10 +209,10 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
device = original_samples.device
dtype = original_samples.dtype
alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
......
......@@ -276,7 +276,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
......@@ -341,7 +341,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
......@@ -368,24 +368,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def convert_model_output(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DEIS algorithm needs.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
......@@ -425,26 +425,26 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def deis_first_order_update(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the first-order DEIS (equivalent to DDIM).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
......@@ -483,22 +483,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_deis_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
model_output_list: List[torch.Tensor],
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the second-order multistep DEIS.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
......@@ -552,22 +552,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_deis_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
model_output_list: List[torch.Tensor],
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the third-order multistep DEIS.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
......@@ -673,9 +673,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -683,11 +683,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
the multistep DEIS.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
......@@ -736,17 +736,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
......@@ -754,10 +754,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -78,11 +78,11 @@ def rescale_zero_terminal_snr(betas):
Args:
betas (`torch.FloatTensor`):
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
......@@ -408,7 +408,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
......@@ -472,7 +472,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
......@@ -497,7 +497,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def _convert_to_lu(self, in_lambdas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Lu et al. (2022)."""
lambda_min: float = in_lambdas[-1].item()
......@@ -512,11 +512,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def convert_model_output(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
......@@ -530,13 +530,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
</Tip>
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
......@@ -611,23 +611,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
noise: Optional[torch.FloatTensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the first-order DPMSolver (equivalent to DDIM).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
......@@ -680,23 +680,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
model_output_list: List[torch.Tensor],
*args,
sample: torch.FloatTensor = None,
noise: Optional[torch.FloatTensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the second-order multistep DPMSolver.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
......@@ -803,22 +803,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
model_output_list: List[torch.Tensor],
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the third-order multistep DPMSolver.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
......@@ -919,11 +919,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.FloatTensor] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -931,15 +931,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the multistep DPMSolver.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.FloatTensor`):
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`LEdits++`].
return_dict (`bool`):
......@@ -1006,27 +1006,27 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -295,7 +295,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
......@@ -360,7 +360,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
......@@ -388,11 +388,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
def convert_model_output(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
......@@ -406,13 +406,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
</Tip>
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
......@@ -488,23 +488,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
noise: Optional[torch.FloatTensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the first-order DPMSolver (equivalent to DDIM).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
......@@ -558,23 +558,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
model_output_list: List[torch.Tensor],
*args,
sample: torch.FloatTensor = None,
noise: Optional[torch.FloatTensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the second-order multistep DPMSolver.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
......@@ -682,22 +682,22 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
model_output_list: List[torch.Tensor],
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the third-order multistep DPMSolver.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
......@@ -786,11 +786,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.FloatTensor] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -798,15 +798,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
the multistep DPMSolver.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.FloatTensor`):
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`):
......@@ -867,27 +867,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -257,21 +257,21 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
def scale_model_input(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
sample: torch.Tensor,
timestep: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
if self.step_index is None:
......@@ -395,7 +395,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return t
# copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
......@@ -414,9 +414,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: Union[float, torch.FloatTensor],
sample: Union[torch.FloatTensor, np.ndarray],
model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True,
s_noise: float = 1.0,
) -> Union[SchedulerOutput, Tuple]:
......@@ -425,11 +425,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`):
model_output (`torch.Tensor` or `np.ndarray`):
The direct output from learned diffusion model.
timestep (`float` or `torch.FloatTensor`):
timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.Tensor` or `np.ndarray`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
......@@ -450,10 +450,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)
# Define functions to compute sigma and t from each other
def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor:
def sigma_fn(_t: torch.Tensor) -> torch.Tensor:
return _t.neg().exp()
def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor:
def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
return _sigma.log().neg()
if self.state_in_first_order:
......@@ -526,10 +526,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -361,7 +361,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
......@@ -426,7 +426,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
......@@ -453,11 +453,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def convert_model_output(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
......@@ -471,13 +471,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
</Tip>
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
......@@ -542,26 +542,26 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the first-order DPMSolver (equivalent to DDIM).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
......@@ -598,27 +598,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
model_output_list: List[torch.Tensor],
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
time `timestep_list[-2]`.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
......@@ -692,27 +692,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
model_output_list: List[torch.Tensor],
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
time `timestep_list[-3]`.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
......@@ -796,29 +796,29 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_update(
self,
model_output_list: List[torch.FloatTensor],
model_output_list: List[torch.Tensor],
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
order: int = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the singlestep DPMSolver.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
order (`int`):
The solver order at this step.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
......@@ -891,9 +891,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -901,11 +901,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
the singlestep DPMSolver.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
......@@ -950,17 +950,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
......@@ -968,10 +968,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -206,21 +206,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return denoised
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
if self.step_index is None:
......@@ -276,7 +274,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
......@@ -289,7 +287,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion.
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
......@@ -300,7 +298,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
......@@ -365,9 +363,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def convert_model_output(
self,
model_output: torch.FloatTensor,
sample: torch.FloatTensor = None,
) -> torch.FloatTensor:
model_output: torch.Tensor,
sample: torch.Tensor = None,
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
......@@ -381,13 +379,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
</Tip>
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The converted model output.
"""
sigma = self.sigmas[self.step_index]
......@@ -400,21 +398,21 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
sample: torch.FloatTensor = None,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
model_output: torch.Tensor,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
One step for the first-order DPMSolver (equivalent to DDIM).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
......@@ -438,21 +436,21 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
sample: torch.FloatTensor = None,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
model_output_list: List[torch.Tensor],
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
One step for the second-order multistep DPMSolver.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s0, sigma_s1 = (
......@@ -509,20 +507,20 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
sample: torch.FloatTensor = None,
) -> torch.FloatTensor:
model_output_list: List[torch.Tensor],
sample: torch.Tensor = None,
) -> torch.Tensor:
"""
One step for the third-order multistep DPMSolver.
Args:
model_output_list (`List[torch.FloatTensor]`):
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
......@@ -596,9 +594,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
......@@ -607,11 +605,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the multistep DPMSolver.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
......@@ -675,10 +673,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -35,16 +35,16 @@ class EDMEulerSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
......@@ -174,21 +174,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
return denoised
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
if self.step_index is None:
......@@ -227,7 +225,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
......@@ -239,7 +237,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
return sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion.
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
......@@ -275,9 +273,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
......@@ -290,11 +288,11 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
......@@ -375,10 +373,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -35,16 +35,16 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
......@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
Args:
betas (`torch.FloatTensor`):
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
......@@ -250,21 +250,19 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
......@@ -346,9 +344,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
......@@ -357,11 +355,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
......@@ -450,10 +448,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -35,16 +35,16 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
......@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
Args:
betas (`torch.FloatTensor`):
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
......@@ -274,21 +274,19 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
if self.step_index is None:
......@@ -445,7 +443,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return t
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
......@@ -494,9 +492,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
......@@ -509,11 +507,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
......@@ -606,10 +604,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......@@ -637,9 +635,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma
return noisy_samples
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor
) -> torch.FloatTensor:
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
if (
isinstance(timesteps, int)
or isinstance(timesteps, torch.IntTensor)
......
......@@ -198,21 +198,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
def scale_model_input(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
sample: torch.Tensor,
timestep: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
if self.step_index is None:
......@@ -329,7 +329,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
......@@ -369,9 +369,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: Union[float, torch.FloatTensor],
sample: Union[torch.FloatTensor, np.ndarray],
model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -379,11 +379,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
......@@ -469,10 +469,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -137,9 +137,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -147,11 +147,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
the linear multistep method. It performs one forward pass multiple times to approximate the solution.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
......@@ -193,17 +193,17 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
......
......@@ -175,21 +175,21 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def scale_model_input(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
sample: torch.Tensor,
timestep: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
if self.step_index is None:
......@@ -321,7 +321,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
......@@ -376,9 +376,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: Union[float, torch.FloatTensor],
sample: Union[torch.FloatTensor, np.ndarray],
model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
......@@ -387,11 +387,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
......@@ -477,10 +477,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -175,21 +175,21 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
def scale_model_input(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
sample: torch.Tensor,
timestep: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
if self.step_index is None:
......@@ -334,7 +334,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
......@@ -361,9 +361,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: Union[float, torch.FloatTensor],
sample: Union[torch.FloatTensor, np.ndarray],
model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -371,11 +371,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
......@@ -452,10 +452,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -176,10 +176,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
Args:
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_hat (`torch.Tensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
Returns:
......@@ -213,12 +213,12 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
Args:
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
sample_hat (`torch.Tensor` or `np.ndarray`): TODO
sample_prev (`torch.Tensor` or `np.ndarray`): TODO
derivative (`torch.Tensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
Returns:
......
......@@ -37,16 +37,16 @@ class LCMSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
denoised: Optional[torch.FloatTensor] = None
prev_sample: torch.Tensor
denoised: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
......@@ -95,17 +95,17 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.FloatTensor`):
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
......@@ -296,24 +296,24 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
......@@ -497,9 +497,9 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[LCMSchedulerOutput, Tuple]:
......@@ -508,11 +508,11 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
......@@ -594,10 +594,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
......@@ -619,9 +619,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
return noisy_samples
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
......
......@@ -32,16 +32,16 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
......@@ -202,21 +202,19 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`float` or `torch.FloatTensor`):
timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
......@@ -351,7 +349,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return t
# copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
......@@ -366,9 +364,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
order: int = 4,
return_dict: bool = True,
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
......@@ -377,11 +375,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float` or `torch.FloatTensor`):
timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
order (`int`, defaults to 4):
The order of the linear multistep method.
......@@ -444,10 +442,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -225,9 +225,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -236,11 +236,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
or [`~PNDMScheduler.step_plms`] depending on the internal variable `counter`.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
......@@ -258,9 +258,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def step_prk(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -269,11 +269,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
equation.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
......@@ -318,9 +318,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def step_plms(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -328,11 +328,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
the linear multistep method. It performs one forward pass multiple times to approximate the solution.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
......@@ -387,17 +387,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
......@@ -448,10 +448,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
......
......@@ -31,16 +31,16 @@ class RePaintSchedulerOutput(BaseOutput):
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from
the current timestep. `pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: torch.FloatTensor
prev_sample: torch.Tensor
pred_original_sample: torch.Tensor
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
......@@ -160,19 +160,19 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
self.eta = eta
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
......@@ -245,11 +245,11 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
original_image: torch.FloatTensor,
mask: torch.FloatTensor,
sample: torch.Tensor,
original_image: torch.Tensor,
mask: torch.Tensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[RePaintSchedulerOutput, Tuple]:
......@@ -258,15 +258,15 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
original_image (`torch.FloatTensor`):
original_image (`torch.Tensor`):
The original image to inpaint on.
mask (`torch.FloatTensor`):
mask (`torch.Tensor`):
The mask where a value of 0.0 indicates which part of the original image to inpaint.
generator (`torch.Generator`, *optional*):
A random number generator.
......@@ -351,10 +351,10 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.")
def __len__(self):
......
......@@ -305,7 +305,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
......@@ -370,7 +370,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
......@@ -397,11 +397,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
def convert_model_output(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs.
Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is
......@@ -415,13 +415,13 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
</Tip>
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
......@@ -686,29 +686,29 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
def stochastic_adams_bashforth_update(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor,
noise: torch.FloatTensor,
sample: torch.Tensor,
noise: torch.Tensor,
order: int,
tau: torch.FloatTensor,
tau: torch.Tensor,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the SA-Predictor.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model at the current timestep.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
order (`int`):
The order of SA-Predictor at this timestep.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
......@@ -813,32 +813,32 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
def stochastic_adams_moulton_update(
self,
this_model_output: torch.FloatTensor,
this_model_output: torch.Tensor,
*args,
last_sample: torch.FloatTensor,
last_noise: torch.FloatTensor,
this_sample: torch.FloatTensor,
last_sample: torch.Tensor,
last_noise: torch.Tensor,
this_sample: torch.Tensor,
order: int,
tau: torch.FloatTensor,
tau: torch.Tensor,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the SA-Corrector.
Args:
this_model_output (`torch.FloatTensor`):
this_model_output (`torch.Tensor`):
The model outputs at `x_t`.
this_timestep (`int`):
The current timestep `t`.
last_sample (`torch.FloatTensor`):
last_sample (`torch.Tensor`):
The generated sample before the last predictor `x_{t-1}`.
this_sample (`torch.FloatTensor`):
this_sample (`torch.Tensor`):
The generated sample after the last predictor `x_{t}`.
order (`int`):
The order of SA-Corrector at this step.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The corrected sample tensor at the current timestep.
"""
......@@ -979,9 +979,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
......@@ -990,11 +990,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
the SA-Solver.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
......@@ -1079,17 +1079,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
......@@ -1097,10 +1097,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment