Unverified commit be4afa0b, authored by Mark Van Aken, committed by GitHub
Browse files

#7535 Update FloatTensor type hints to Tensor (#7883)

* find & replace all FloatTensors to Tensor

* apply formatting

* Update torch.FloatTensor to torch.Tensor in the remaining files

* formatting

* Fix the rest of the places where FloatTensor is used as well as in documentation

* formatting

* Update new file from FloatTensor to Tensor
parent 04f4bd54
...@@ -32,15 +32,15 @@ class SdeVeOutput(BaseOutput): ...@@ -32,15 +32,15 @@ class SdeVeOutput(BaseOutput):
Output class for the scheduler's `step` function output. Output class for the scheduler's `step` function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample_mean (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
`prev_sample` averaged over previous timesteps. `prev_sample` averaged over previous timesteps.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
prev_sample_mean: torch.FloatTensor prev_sample_mean: torch.Tensor
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
...@@ -86,19 +86,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -86,19 +86,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
return sample return sample
...@@ -159,9 +159,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -159,9 +159,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def step_pred( def step_pred(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SdeVeOutput, Tuple]: ) -> Union[SdeVeOutput, Tuple]:
...@@ -170,11 +170,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -170,11 +170,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A random number generator. A random number generator.
...@@ -227,8 +227,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -227,8 +227,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def step_correct( def step_correct(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
sample: torch.FloatTensor, sample: torch.Tensor,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
...@@ -237,9 +237,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -237,9 +237,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
making the prediction for the previous timestep. making the prediction for the previous timestep.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A random number generator. A random number generator.
...@@ -282,10 +282,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -282,10 +282,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.FloatTensor, timesteps: torch.Tensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps] sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
......
...@@ -37,15 +37,15 @@ class TCDSchedulerOutput(BaseOutput): ...@@ -37,15 +37,15 @@ class TCDSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output. Output class for the scheduler's `step` function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
pred_noised_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): pred_noised_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted noised sample `(x_{s})` based on the model output from the current timestep. The predicted noised sample `(x_{s})` based on the model output from the current timestep.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
pred_noised_sample: Optional[torch.FloatTensor] = None pred_noised_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
...@@ -94,17 +94,17 @@ def betas_for_alpha_bar( ...@@ -94,17 +94,17 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor: def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
""" """
Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1). Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
Args: Args:
betas (`torch.FloatTensor`): betas (`torch.Tensor`):
the betas that the scheduler is being initialized with. the betas that the scheduler is being initialized with.
Returns: Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR `torch.Tensor`: rescaled betas with zero terminal SNR
""" """
# Convert betas to alphas_bar_sqrt # Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas alphas = 1.0 - betas
...@@ -297,19 +297,19 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): ...@@ -297,19 +297,19 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
""" """
self._begin_index = begin_index self._begin_index = begin_index
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
timestep (`int`, *optional*): timestep (`int`, *optional*):
The current timestep in the diffusion chain. The current timestep in the diffusion chain.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
return sample return sample
...@@ -326,7 +326,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): ...@@ -326,7 +326,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
return variance return variance
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
""" """
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
...@@ -524,9 +524,9 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): ...@@ -524,9 +524,9 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
eta: float = 0.3, eta: float = 0.3,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
return_dict: bool = True, return_dict: bool = True,
...@@ -536,11 +536,11 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): ...@@ -536,11 +536,11 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
eta (`float`): eta (`float`):
A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every
...@@ -631,10 +631,10 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): ...@@ -631,10 +631,10 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.IntTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls # for the subsequent add_noise calls
...@@ -656,9 +656,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): ...@@ -656,9 +656,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
return noisy_samples return noisy_samples
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity( def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample # Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
......
...@@ -32,16 +32,16 @@ class UnCLIPSchedulerOutput(BaseOutput): ...@@ -32,16 +32,16 @@ class UnCLIPSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output. Output class for the scheduler's `step` function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep. The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance. `pred_original_sample` can be used to preview progress or for guidance.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
...@@ -146,17 +146,17 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): ...@@ -146,17 +146,17 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
self.variance_type = variance_type self.variance_type = variance_type
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): input sample sample (`torch.Tensor`): input sample
timestep (`int`, optional): current timestep timestep (`int`, optional): current timestep
Returns: Returns:
`torch.FloatTensor`: scaled input sample `torch.Tensor`: scaled input sample
""" """
return sample return sample
...@@ -215,9 +215,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): ...@@ -215,9 +215,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
prev_timestep: Optional[int] = None, prev_timestep: Optional[int] = None,
generator=None, generator=None,
return_dict: bool = True, return_dict: bool = True,
...@@ -227,9 +227,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): ...@@ -227,9 +227,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model. model_output (`torch.Tensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain. timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
current instance of sample being created by diffusion process. current instance of sample being created by diffusion process.
prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at. prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at.
Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used. Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used.
...@@ -327,10 +327,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): ...@@ -327,10 +327,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.IntTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls # for the subsequent add_noise calls
......
...@@ -78,11 +78,11 @@ def rescale_zero_terminal_snr(betas): ...@@ -78,11 +78,11 @@ def rescale_zero_terminal_snr(betas):
Args: Args:
betas (`torch.FloatTensor`): betas (`torch.Tensor`):
the betas that the scheduler is being initialized with. the betas that the scheduler is being initialized with.
Returns: Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR `torch.Tensor`: rescaled betas with zero terminal SNR
""" """
# Convert betas to alphas_bar_sqrt # Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas alphas = 1.0 - betas
...@@ -360,7 +360,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -360,7 +360,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
""" """
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
...@@ -425,7 +425,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -425,7 +425,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break # Hack to make sure that other schedulers which copy this function don't break
...@@ -452,24 +452,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -452,24 +452,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def convert_model_output( def convert_model_output(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
*args, *args,
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
r""" r"""
Convert the model output to the corresponding type the UniPC algorithm needs. Convert the model output to the corresponding type the UniPC algorithm needs.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The converted model output. The converted model output.
""" """
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
...@@ -522,27 +522,27 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -522,27 +522,27 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_uni_p_bh_update( def multistep_uni_p_bh_update(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
*args, *args,
sample: torch.FloatTensor = None, sample: torch.Tensor = None,
order: int = None, order: int = None,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified. One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from the learned diffusion model at the current timestep. The direct output from the learned diffusion model at the current timestep.
prev_timestep (`int`): prev_timestep (`int`):
The previous discrete timestep in the diffusion chain. The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
order (`int`): order (`int`):
The order of UniP at this timestep (corresponds to the *p* in UniPC-p). The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
...@@ -651,30 +651,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -651,30 +651,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_uni_c_bh_update( def multistep_uni_c_bh_update(
self, self,
this_model_output: torch.FloatTensor, this_model_output: torch.Tensor,
*args, *args,
last_sample: torch.FloatTensor = None, last_sample: torch.Tensor = None,
this_sample: torch.FloatTensor = None, this_sample: torch.Tensor = None,
order: int = None, order: int = None,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
One step for the UniC (B(h) version). One step for the UniC (B(h) version).
Args: Args:
this_model_output (`torch.FloatTensor`): this_model_output (`torch.Tensor`):
The model outputs at `x_t`. The model outputs at `x_t`.
this_timestep (`int`): this_timestep (`int`):
The current timestep `t`. The current timestep `t`.
last_sample (`torch.FloatTensor`): last_sample (`torch.Tensor`):
The generated sample before the last predictor `x_{t-1}`. The generated sample before the last predictor `x_{t-1}`.
this_sample (`torch.FloatTensor`): this_sample (`torch.Tensor`):
The generated sample after the last predictor `x_{t}`. The generated sample after the last predictor `x_{t}`.
order (`int`): order (`int`):
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The corrected sample tensor at the current timestep. The corrected sample tensor at the current timestep.
""" """
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
...@@ -821,9 +821,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -821,9 +821,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.Tensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
""" """
...@@ -831,11 +831,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -831,11 +831,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
the multistep UniPC. the multistep UniPC.
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`int`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
return_dict (`bool`): return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
...@@ -900,17 +900,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -900,17 +900,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample) return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. current timestep.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The input sample. The input sample.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
A scaled input sample. A scaled input sample.
""" """
return sample return sample
...@@ -918,10 +918,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -918,10 +918,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.Tensor,
noise: torch.FloatTensor, noise: torch.Tensor,
timesteps: torch.IntTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
...@@ -63,12 +63,12 @@ class SchedulerOutput(BaseOutput): ...@@ -63,12 +63,12 @@ class SchedulerOutput(BaseOutput):
Base class for the output of a scheduler's `step` function. Base class for the output of a scheduler's `step` function.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
class SchedulerMixin(PushToHubMixin): class SchedulerMixin(PushToHubMixin):
......
...@@ -38,7 +38,7 @@ class VQDiffusionSchedulerOutput(BaseOutput): ...@@ -38,7 +38,7 @@ class VQDiffusionSchedulerOutput(BaseOutput):
prev_sample: torch.LongTensor prev_sample: torch.LongTensor
def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor: def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.Tensor:
""" """
Convert batch of vector of class indices into batch of log onehot vectors Convert batch of vector of class indices into batch of log onehot vectors
...@@ -50,7 +50,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen ...@@ -50,7 +50,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
number of classes to be used for the onehot vectors number of classes to be used for the onehot vectors
Returns: Returns:
`torch.FloatTensor` of shape `(batch size, num classes, vector length)`: `torch.Tensor` of shape `(batch size, num classes, vector length)`:
Log onehot vectors Log onehot vectors
""" """
x_onehot = F.one_hot(x, num_classes) x_onehot = F.one_hot(x, num_classes)
...@@ -59,7 +59,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen ...@@ -59,7 +59,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
return log_x return log_x
def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor: def gumbel_noised(logits: torch.Tensor, generator: Optional[torch.Generator]) -> torch.Tensor:
""" """
Apply gumbel noise to `logits` Apply gumbel noise to `logits`
""" """
...@@ -199,7 +199,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): ...@@ -199,7 +199,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.Tensor,
timestep: torch.long, timestep: torch.long,
sample: torch.LongTensor, sample: torch.LongTensor,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
...@@ -210,7 +210,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): ...@@ -210,7 +210,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
[`~VQDiffusionScheduler.q_posterior`] for more details about how the distribution is computed. [`~VQDiffusionScheduler.q_posterior`] for more details about how the distribution is computed.
Args: Args:
log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): log_p_x_0: (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
The log probabilities for the predicted classes of the initial latent pixels. Does not include a The log probabilities for the predicted classes of the initial latent pixels. Does not include a
prediction for the masked class as the initial unnoised image cannot be masked. prediction for the masked class as the initial unnoised image cannot be masked.
t (`torch.long`): t (`torch.long`):
...@@ -251,7 +251,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): ...@@ -251,7 +251,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
``` ```
Args: Args:
log_p_x_0 (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): log_p_x_0 (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
The log probabilities for the predicted classes of the initial latent pixels. Does not include a The log probabilities for the predicted classes of the initial latent pixels. Does not include a
prediction for the masked class as the initial unnoised image cannot be masked. prediction for the masked class as the initial unnoised image cannot be masked.
x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
...@@ -260,7 +260,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): ...@@ -260,7 +260,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
The timestep that determines which transition matrix is used. The timestep that determines which transition matrix is used.
Returns: Returns:
`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`: `torch.Tensor` of shape `(batch size, num classes, num latent pixels)`:
The log probabilities for the predicted classes of the image at timestep `t-1`. The log probabilities for the predicted classes of the image at timestep `t-1`.
""" """
log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed) log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
...@@ -354,7 +354,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): ...@@ -354,7 +354,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
return log_p_x_t_min_1 return log_p_x_t_min_1
def log_Q_t_transitioning_to_known_class( def log_Q_t_transitioning_to_known_class(
self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.Tensor, cumulative: bool
): ):
""" """
Calculates the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each Calculates the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
...@@ -365,14 +365,14 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): ...@@ -365,14 +365,14 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
The timestep that determines which transition matrix is used. The timestep that determines which transition matrix is used.
x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
The classes of each latent pixel at time `t`. The classes of each latent pixel at time `t`.
log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`): log_onehot_x_t (`torch.Tensor` of shape `(batch size, num classes, num latent pixels)`):
The log one-hot vectors of `x_t`. The log one-hot vectors of `x_t`.
cumulative (`bool`): cumulative (`bool`):
If cumulative is `False`, the single step transition matrix `t-1`->`t` is used. If cumulative is If cumulative is `False`, the single step transition matrix `t-1`->`t` is used. If cumulative is
`True`, the cumulative transition matrix `0`->`t` is used. `True`, the cumulative transition matrix `0`->`t` is used.
Returns: Returns:
`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`: `torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`:
Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
transition matrix. transition matrix.
......
...@@ -32,16 +32,16 @@ REFERENCE_CODE = """ \""" ...@@ -32,16 +32,16 @@ REFERENCE_CODE = """ \"""
Output class for the scheduler's `step` function output. Output class for the scheduler's `step` function output.
Args: Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop. denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep. The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance. `pred_original_sample` can be used to preview progress or for guidance.
\""" \"""
prev_sample: torch.FloatTensor prev_sample: torch.Tensor
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.Tensor] = None
""" """
......
...@@ -1041,7 +1041,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase): ...@@ -1041,7 +1041,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_intermediate_state(self): def test_stable_diffusion_intermediate_state(self):
number_of_steps = 0 number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True callback_fn.has_been_called = True
nonlocal number_of_steps nonlocal number_of_steps
number_of_steps += 1 number_of_steps += 1
......
...@@ -472,7 +472,7 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase): ...@@ -472,7 +472,7 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_img2img_intermediate_state(self): def test_stable_diffusion_img2img_intermediate_state(self):
number_of_steps = 0 number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True callback_fn.has_been_called = True
nonlocal number_of_steps nonlocal number_of_steps
number_of_steps += 1 number_of_steps += 1
......
...@@ -353,7 +353,7 @@ class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase): ...@@ -353,7 +353,7 @@ class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_pix2pix_intermediate_state(self): def test_stable_diffusion_pix2pix_intermediate_state(self):
number_of_steps = 0 number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True callback_fn.has_been_called = True
nonlocal number_of_steps nonlocal number_of_steps
number_of_steps += 1 number_of_steps += 1
......
...@@ -416,7 +416,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase): ...@@ -416,7 +416,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_text2img_intermediate_state(self): def test_stable_diffusion_text2img_intermediate_state(self):
number_of_steps = 0 number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True callback_fn.has_been_called = True
nonlocal number_of_steps nonlocal number_of_steps
number_of_steps += 1 number_of_steps += 1
......
...@@ -461,7 +461,7 @@ class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase): ...@@ -461,7 +461,7 @@ class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_depth2img_intermediate_state(self): def test_stable_diffusion_depth2img_intermediate_state(self):
number_of_steps = 0 number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True callback_fn.has_been_called = True
nonlocal number_of_steps nonlocal number_of_steps
number_of_steps += 1 number_of_steps += 1
......
...@@ -475,7 +475,7 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase): ...@@ -475,7 +475,7 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
def test_stable_diffusion_text2img_intermediate_state_v_pred(self): def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
number_of_steps = 0 number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: def test_callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
test_callback_fn.has_been_called = True test_callback_fn.has_been_called = True
nonlocal number_of_steps nonlocal number_of_steps
number_of_steps += 1 number_of_steps += 1
......
...@@ -213,7 +213,7 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase): ...@@ -213,7 +213,7 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_img_variation_intermediate_state(self): def test_stable_diffusion_img_variation_intermediate_state(self):
number_of_steps = 0 number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True callback_fn.has_been_called = True
nonlocal number_of_steps nonlocal number_of_steps
number_of_steps += 1 number_of_steps += 1
......
...@@ -349,7 +349,7 @@ class StableDiffusionPanoramaNightlyTests(unittest.TestCase): ...@@ -349,7 +349,7 @@ class StableDiffusionPanoramaNightlyTests(unittest.TestCase):
def test_stable_diffusion_panorama_intermediate_state(self): def test_stable_diffusion_panorama_intermediate_state(self):
number_of_steps = 0 number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True callback_fn.has_been_called = True
nonlocal number_of_steps nonlocal number_of_steps
number_of_steps += 1 number_of_steps += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment