Unverified Commit be4afa0b authored by Mark Van Aken's avatar Mark Van Aken Committed by GitHub
Browse files

#7535 Update FloatTensor type hints to Tensor (#7883)

* find & replace all FloatTensors to Tensor

* apply formatting

* Update torch.FloatTensor to torch.Tensor in the remaining files

* formatting

* Fix the rest of the places where FloatTensor is used as well as in documentation

* formatting

* Update new file from FloatTensor to Tensor
parent 04f4bd54
......@@ -32,15 +32,15 @@ class SdeVeOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample_mean (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Mean averaged `prev_sample` over previous timesteps.
"""
prev_sample: torch.FloatTensor
prev_sample_mean: torch.FloatTensor
prev_sample: torch.Tensor
prev_sample_mean: torch.Tensor
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
......@@ -86,19 +86,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
......@@ -159,9 +159,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def step_pred(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SdeVeOutput, Tuple]:
......@@ -170,11 +170,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
......@@ -227,8 +227,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def step_correct(
self,
model_output: torch.FloatTensor,
sample: torch.FloatTensor,
model_output: torch.Tensor,
sample: torch.Tensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
......@@ -237,9 +237,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
making the prediction for the previous timestep.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
......@@ -282,10 +282,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
timesteps = timesteps.to(original_samples.device)
sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
......
......@@ -37,15 +37,15 @@ class TCDSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_noised_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
pred_noised_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted noised sample `(x_{s})` based on the model output from the current timestep.
"""
prev_sample: torch.FloatTensor
pred_noised_sample: Optional[torch.FloatTensor] = None
prev_sample: torch.Tensor
pred_noised_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
......@@ -94,17 +94,17 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.FloatTensor`):
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
......@@ -297,19 +297,19 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
......@@ -326,7 +326,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
return variance
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
......@@ -524,9 +524,9 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
eta: float = 0.3,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
......@@ -536,11 +536,11 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every
......@@ -631,10 +631,10 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
......@@ -656,9 +656,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
return noisy_samples
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
......
......@@ -32,16 +32,16 @@ class UnCLIPSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
......@@ -146,17 +146,17 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
self.variance_type = variance_type
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
sample (`torch.Tensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
`torch.Tensor`: scaled input sample
"""
return sample
......@@ -215,9 +215,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
prev_timestep: Optional[int] = None,
generator=None,
return_dict: bool = True,
......@@ -227,9 +227,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
model_output (`torch.Tensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at.
Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used.
......@@ -327,10 +327,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
......
......@@ -78,11 +78,11 @@ def rescale_zero_terminal_snr(betas):
Args:
betas (`torch.FloatTensor`):
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
......@@ -360,7 +360,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
......@@ -425,7 +425,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
......@@ -452,24 +452,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def convert_model_output(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
r"""
Convert the model output to the corresponding type the UniPC algorithm needs.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
......@@ -522,27 +522,27 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_uni_p_bh_update(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
*args,
sample: torch.FloatTensor = None,
sample: torch.Tensor = None,
order: int = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model at the current timestep.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
order (`int`):
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
......@@ -651,30 +651,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_uni_c_bh_update(
self,
this_model_output: torch.FloatTensor,
this_model_output: torch.Tensor,
*args,
last_sample: torch.FloatTensor = None,
this_sample: torch.FloatTensor = None,
last_sample: torch.Tensor = None,
this_sample: torch.Tensor = None,
order: int = None,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""
One step for the UniC (B(h) version).
Args:
this_model_output (`torch.FloatTensor`):
this_model_output (`torch.Tensor`):
The model outputs at `x_t`.
this_timestep (`int`):
The current timestep `t`.
last_sample (`torch.FloatTensor`):
last_sample (`torch.Tensor`):
The generated sample before the last predictor `x_{t-1}`.
this_sample (`torch.FloatTensor`):
this_sample (`torch.Tensor`):
The generated sample after the last predictor `x_{t}`.
order (`int`):
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The corrected sample tensor at the current timestep.
"""
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
......@@ -821,9 +821,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: int,
sample: torch.FloatTensor,
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
......@@ -831,11 +831,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
the multistep UniPC.
Args:
model_output (`torch.FloatTensor`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
......@@ -900,17 +900,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
A scaled input sample.
"""
return sample
......@@ -918,10 +918,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
......
......@@ -63,12 +63,12 @@ class SchedulerOutput(BaseOutput):
Base class for the output of a scheduler's `step` function.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
prev_sample: torch.Tensor
class SchedulerMixin(PushToHubMixin):
......
......@@ -38,7 +38,7 @@ class VQDiffusionSchedulerOutput(BaseOutput):
prev_sample: torch.LongTensor
def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.Tensor:
"""
Convert batch of vector of class indices into batch of log onehot vectors
......@@ -50,7 +50,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
number of classes to be used for the onehot vectors
Returns:
`torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
`torch.Tensor` of shape `(batch size, num classes, vector length)`:
Log onehot vectors
"""
x_onehot = F.one_hot(x, num_classes)
......@@ -59,7 +59,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
return log_x
def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
def gumbel_noised(logits: torch.Tensor, generator: Optional[torch.Generator]) -> torch.Tensor:
"""
Apply gumbel noise to `logits`
"""
......@@ -199,7 +199,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.FloatTensor,
model_output: torch.Tensor,
timestep: torch.long,
sample: torch.LongTensor,
generator: Optional[torch.Generator] = None,
......@@ -210,7 +210,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
[`~VQDiffusionScheduler.q_posterior`] for more details about how the distribution is computer.
Args:
log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
log_p_x_0: (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
The log probabilities for the predicted classes of the initial latent pixels. Does not include a
prediction for the masked class as the initial unnoised image cannot be masked.
t (`torch.long`):
......@@ -251,7 +251,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
```
Args:
log_p_x_0 (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
log_p_x_0 (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
The log probabilities for the predicted classes of the initial latent pixels. Does not include a
prediction for the masked class as the initial unnoised image cannot be masked.
x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
......@@ -260,7 +260,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
The timestep that determines which transition matrix is used.
Returns:
`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
`torch.Tensor` of shape `(batch size, num classes, num latent pixels)`:
The log probabilities for the predicted classes of the image at timestep `t-1`.
"""
log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
......@@ -354,7 +354,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
return log_p_x_t_min_1
def log_Q_t_transitioning_to_known_class(
self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.Tensor, cumulative: bool
):
"""
Calculates the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
......@@ -365,14 +365,14 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
The timestep that determines which transition matrix is used.
x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
The classes of each latent pixel at time `t`.
log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
log_onehot_x_t (`torch.Tensor` of shape `(batch size, num classes, num latent pixels)`):
The log one-hot vectors of `x_t`.
cumulative (`bool`):
If cumulative is `False`, the single step transition matrix `t-1`->`t` is used. If cumulative is
`True`, the cumulative transition matrix `0`->`t` is used.
Returns:
`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`:
`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`:
Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
transition matrix.
......
......@@ -32,16 +32,16 @@ REFERENCE_CODE = """ \"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
\"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
"""
......
......@@ -1041,7 +1041,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
......
......@@ -472,7 +472,7 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_img2img_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
......
......@@ -353,7 +353,7 @@ class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_pix2pix_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
......
......@@ -416,7 +416,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_text2img_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
......
......@@ -461,7 +461,7 @@ class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_depth2img_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
......
......@@ -475,7 +475,7 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
def test_callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
test_callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
......
......@@ -213,7 +213,7 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
def test_stable_diffusion_img_variation_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
......
......@@ -349,7 +349,7 @@ class StableDiffusionPanoramaNightlyTests(unittest.TestCase):
def test_stable_diffusion_panorama_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment