Unverified commit be4afa0b authored by Mark Van Aken, committed by GitHub

#7535 Update FloatTensor type hints to Tensor (#7883)

* find & replace all FloatTensors to Tensor

* apply formatting

* Update torch.FloatTensor to torch.Tensor in the remaining files

* formatting

* Fix the rest of the places where FloatTensor is used as well as in documentation

* formatting

* Update new file from FloatTensor to Tensor
parent 04f4bd54
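
For context on the change itself: `torch.FloatTensor` is the legacy tensor type for one concrete case, a float32 tensor on CPU, so as a type hint it silently excludes the fp16/bf16 and CUDA tensors these modules handle in practice, whereas `torch.Tensor` is the base class covering every dtype and device. A minimal sketch of the pattern applied throughout the diff (the function names below are illustrative, not taken from the codebase):

import torch

# Before: over-narrow hint; only a CPU float32 tensor is a torch.FloatTensor.
def decode_old(z: torch.FloatTensor) -> torch.FloatTensor:
    return z

# After: torch.Tensor matches any dtype/device, which is what the code already accepts.
def decode_new(z: torch.Tensor) -> torch.Tensor:
    return z

z = torch.randn(1, 4, 8, 8, dtype=torch.float16)
print(isinstance(z, torch.FloatTensor))  # False: half precision is not a FloatTensor
print(isinstance(z, torch.Tensor))       # True
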
@@ -86,10 +86,10 @@ class TemporalDecoder(nn.Module):
def forward(
self,
sample: torch.FloatTensor,
image_only_indicator: torch.FloatTensor,
sample: torch.Tensor,
image_only_indicator: torch.Tensor,
num_frames: int = 1,
) -> torch.FloatTensor:
) -> torch.Tensor:
r"""The forward method of the `Decoder` class."""
sample = self.conv_in(sample)
@@ -315,13 +315,13 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
@@ -341,15 +341,15 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
z: torch.Tensor,
num_frames: int,
return_dict: bool = True,
) -> Union[DecoderOutput, torch.FloatTensor]:
) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
@@ -370,15 +370,15 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
def forward(
self,
sample: torch.FloatTensor,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
num_frames: int = 1,
) -> Union[DecoderOutput, torch.FloatTensor]:
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -157,11 +157,11 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
if isinstance(module, (EncoderTiny, DecoderTiny)):
module.gradient_checkpointing = value
def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
"""raw latents -> [0, 1]"""
return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
def unscale_latents(self, x: torch.Tensor) -> torch.Tensor:
"""[0, 1] -> raw latents"""
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
@@ -194,7 +194,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
"""
self.enable_tiling(False)
def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -202,10 +202,10 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
tiles overlap and are blended together to form a smooth output.
Args:
x (`torch.FloatTensor`): Input batch of images.
x (`torch.Tensor`): Input batch of images.
Returns:
`torch.FloatTensor`: Encoded batch of images.
`torch.Tensor`: Encoded batch of images.
"""
# scale of encoder output relative to input
sf = self.spatial_scale_factor
@@ -242,7 +242,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
return out
def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -250,10 +250,10 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
tiles overlap and are blended together to form a smooth output.
Args:
x (`torch.FloatTensor`): Input batch of images.
x (`torch.Tensor`): Input batch of images.
Returns:
`torch.FloatTensor`: Decoded batch of images.
`torch.Tensor`: Decoded batch of images.
"""
# scale of decoder output relative to input
sf = self.spatial_scale_factor
@@ -290,9 +290,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
return out
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
if self.use_slicing and x.shape[0] > 1:
output = [
self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
@@ -308,8 +306,8 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
@apply_forward_hook
def decode(
self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
if self.use_slicing and x.shape[0] > 1:
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)]
output = torch.cat(output)
@@ -323,12 +321,12 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
def forward(
self,
sample: torch.FloatTensor,
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
@@ -276,13 +276,13 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
self, x: torch.Tensor, return_dict: bool = True
) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
tuple.
@@ -312,22 +312,22 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
z: torch.Tensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
num_inference_steps: int = 2,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
"""
Decodes the input latent vector `z` using the consistency decoder VAE model.
Args:
z (torch.FloatTensor): The input latent vector.
z (torch.Tensor): The input latent vector.
generator (Optional[torch.Generator]): The random number generator. Default is None.
return_dict (bool): Whether to return the output as a dictionary. Default is True.
num_inference_steps (int): The number of inference steps. Default is 2.
Returns:
Union[DecoderOutput, Tuple[torch.FloatTensor]]: The decoded output.
Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
"""
z = (z * self.config.scaling_factor - self.means) / self.stds
@@ -370,9 +370,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def tiled_encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -382,7 +380,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images.
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
plain tuple.
@@ -429,14 +427,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
def forward(
self,
sample: torch.FloatTensor,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -36,11 +36,11 @@ class DecoderOutput(BaseOutput):
Output of decoding method.
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
The decoded output sample from the last layer of the model.
"""
sample: torch.FloatTensor
sample: torch.Tensor
class Encoder(nn.Module):
@@ -136,7 +136,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, sample: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `Encoder` class."""
sample = self.conv_in(sample)
@@ -282,9 +282,9 @@ class Decoder(nn.Module):
def forward(
self,
sample: torch.FloatTensor,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
sample: torch.Tensor,
latent_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""The forward method of the `Decoder` class."""
sample = self.conv_in(sample)
@@ -367,7 +367,7 @@ class UpSample(nn.Module):
self.out_channels = out_channels
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `UpSample` class."""
x = torch.relu(x)
x = self.deconv(x)
@@ -416,7 +416,7 @@ class MaskConditionEncoder(nn.Module):
self.layers = nn.Sequential(*layers)
def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
r"""The forward method of the `MaskConditionEncoder` class."""
out = {}
for l in range(len(self.layers)):
@@ -533,11 +533,11 @@ class MaskConditionDecoder(nn.Module):
def forward(
self,
z: torch.FloatTensor,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
z: torch.Tensor,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
latent_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""The forward method of the `MaskConditionDecoder` class."""
sample = z
sample = self.conv_in(sample)
@@ -711,7 +711,7 @@ class VectorQuantizer(nn.Module):
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim)
@@ -730,7 +730,7 @@ class VectorQuantizer(nn.Module):
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q: torch.FloatTensor = z + (z_q - z).detach()
z_q: torch.Tensor = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
@@ -745,7 +745,7 @@ class VectorQuantizer(nn.Module):
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
@@ -753,7 +753,7 @@ class VectorQuantizer(nn.Module):
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q: torch.FloatTensor = self.embedding(indices)
z_q: torch.Tensor = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
@@ -776,7 +776,7 @@ class DiagonalGaussianDistribution(object):
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
@@ -873,7 +873,7 @@ class EncoderTiny(nn.Module):
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `EncoderTiny` class."""
if self.training and self.gradient_checkpointing:
@@ -956,7 +956,7 @@ class DecoderTiny(nn.Module):
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `DecoderTiny` class."""
# Clamp.
x = torch.tanh(x / 3) * 3
@@ -665,10 +665,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def forward(
self,
sample: torch.FloatTensor,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
@@ -677,18 +677,18 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
"""
The [`ControlNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
@@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import FloatTensor, nn
from torch import Tensor, nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, is_torch_version, logging
@@ -54,12 +54,12 @@ class ControlNetXSOutput(BaseOutput):
The output of [`UNetControlNetXSModel`].
Args:
sample (`FloatTensor` of shape `(batch_size, num_channels, height, width)`):
sample (`Tensor` of shape `(batch_size, num_channels, height, width)`):
The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base
model output, but is already the final output.
"""
sample: FloatTensor = None
sample: Tensor = None
class DownBlockControlNetXSAdapter(nn.Module):
@@ -1001,7 +1001,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
def forward(
self,
sample: FloatTensor,
sample: Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: Optional[torch.Tensor] = None,
@@ -1018,13 +1018,13 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
The [`ControlNetXSModel`] forward method.
Args:
sample (`FloatTensor`):
sample (`Tensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`FloatTensor`):
controlnet_cond (`Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
How much the control model affects the base model outputs.
@@ -1402,16 +1402,16 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
def forward(
self,
hidden_states_base: FloatTensor,
temb: FloatTensor,
encoder_hidden_states: Optional[FloatTensor] = None,
hidden_states_ctrl: Optional[FloatTensor] = None,
hidden_states_base: Tensor,
temb: Tensor,
encoder_hidden_states: Optional[Tensor] = None,
hidden_states_ctrl: Optional[Tensor] = None,
conditioning_scale: Optional[float] = 1.0,
attention_mask: Optional[FloatTensor] = None,
attention_mask: Optional[Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[FloatTensor] = None,
encoder_attention_mask: Optional[Tensor] = None,
apply_control: bool = True,
) -> Tuple[FloatTensor, FloatTensor, Tuple[FloatTensor, ...], Tuple[FloatTensor, ...]]:
) -> Tuple[Tensor, Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -1626,16 +1626,16 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
def forward(
self,
hidden_states_base: FloatTensor,
temb: FloatTensor,
encoder_hidden_states: FloatTensor,
hidden_states_ctrl: Optional[FloatTensor] = None,
hidden_states_base: Tensor,
temb: Tensor,
encoder_hidden_states: Tensor,
hidden_states_ctrl: Optional[Tensor] = None,
conditioning_scale: Optional[float] = 1.0,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[FloatTensor] = None,
encoder_attention_mask: Optional[FloatTensor] = None,
attention_mask: Optional[Tensor] = None,
encoder_attention_mask: Optional[Tensor] = None,
apply_control: bool = True,
) -> Tuple[FloatTensor, FloatTensor]:
) -> Tuple[Tensor, Tensor]:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -1807,18 +1807,18 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
def forward(
self,
hidden_states: FloatTensor,
res_hidden_states_tuple_base: Tuple[FloatTensor, ...],
res_hidden_states_tuple_ctrl: Tuple[FloatTensor, ...],
temb: FloatTensor,
encoder_hidden_states: Optional[FloatTensor] = None,
hidden_states: Tensor,
res_hidden_states_tuple_base: Tuple[Tensor, ...],
res_hidden_states_tuple_ctrl: Tuple[Tensor, ...],
temb: Tensor,
encoder_hidden_states: Optional[Tensor] = None,
conditioning_scale: Optional[float] = 1.0,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[FloatTensor] = None,
attention_mask: Optional[Tensor] = None,
upsample_size: Optional[int] = None,
encoder_attention_mask: Optional[FloatTensor] = None,
encoder_attention_mask: Optional[Tensor] = None,
apply_control: bool = True,
) -> FloatTensor:
) -> Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -129,7 +129,7 @@ class Downsample2D(nn.Module):
else:
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
@@ -180,24 +180,24 @@ class FirDownsample2D(nn.Module):
def _downsample_2d(
self,
hidden_states: torch.FloatTensor,
weight: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
weight: Optional[torch.Tensor] = None,
kernel: Optional[torch.Tensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
) -> torch.Tensor:
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
hidden_states (`torch.FloatTensor`):
hidden_states (`torch.Tensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight (`torch.FloatTensor`, *optional*):
weight (`torch.Tensor`, *optional*):
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
kernel (`torch.FloatTensor`, *optional*):
kernel (`torch.Tensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to average pooling.
factor (`int`, *optional*, default to `2`):
@@ -206,7 +206,7 @@ class FirDownsample2D(nn.Module):
Scaling factor for signal magnitude.
Returns:
output (`torch.FloatTensor`):
output (`torch.Tensor`):
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
datatype as `x`.
"""
@@ -244,7 +244,7 @@ class FirDownsample2D(nn.Module):
return output
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.use_conv:
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -286,11 +286,11 @@ class KDownsample2D(nn.Module):
def downsample_2d(
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
kernel: Optional[torch.Tensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
) -> torch.Tensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -298,9 +298,9 @@ def downsample_2d(
shape is a multiple of the downsampling factor.
Args:
hidden_states (`torch.FloatTensor`)
hidden_states (`torch.Tensor`)
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel (`torch.FloatTensor`, *optional*):
kernel (`torch.Tensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to average pooling.
factor (`int`, *optional*, default to `2`):
@@ -309,7 +309,7 @@ def downsample_2d(
Scaling factor for signal magnitude.
Returns:
output (`torch.FloatTensor`):
output (`torch.Tensor`):
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
@@ -424,7 +424,7 @@ class TextImageProjection(nn.Module):
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
batch_size = text_embeds.shape[0]
# image
@@ -450,7 +450,7 @@ class ImageProjection(nn.Module):
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds: torch.FloatTensor):
def forward(self, image_embeds: torch.Tensor):
batch_size = image_embeds.shape[0]
# image
@@ -468,7 +468,7 @@ class IPAdapterFullImageProjection(nn.Module):
self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds: torch.FloatTensor):
def forward(self, image_embeds: torch.Tensor):
return self.norm(self.ff(image_embeds))
@@ -482,7 +482,7 @@ class IPAdapterFaceIDImageProjection(nn.Module):
self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds: torch.FloatTensor):
def forward(self, image_embeds: torch.Tensor):
x = self.ff(image_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
return self.norm(x)
@@ -530,7 +530,7 @@ class TextImageTimeEmbedding(nn.Module):
self.text_norm = nn.LayerNorm(time_embed_dim)
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
# text
time_text_embeds = self.text_proj(text_embeds)
time_text_embeds = self.text_norm(time_text_embeds)
@@ -547,7 +547,7 @@ class ImageTimeEmbedding(nn.Module):
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
self.image_norm = nn.LayerNorm(time_embed_dim)
def forward(self, image_embeds: torch.FloatTensor):
def forward(self, image_embeds: torch.Tensor):
# image
time_image_embeds = self.image_proj(image_embeds)
time_image_embeds = self.image_norm(time_image_embeds)
@@ -577,7 +577,7 @@ class ImageHintTimeEmbedding(nn.Module):
nn.Conv2d(256, 4, 3, padding=1),
)
def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
# image
time_image_embeds = self.image_proj(image_embeds)
time_image_embeds = self.image_norm(time_image_embeds)
@@ -1007,7 +1007,7 @@ class MultiIPAdapterImageProjection(nn.Module):
super().__init__()
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
def forward(self, image_embeds: List[torch.FloatTensor]):
def forward(self, image_embeds: List[torch.Tensor]):
projected_image_embeds = []
# currently, we accept `image_embeds` as
@@ -58,7 +58,7 @@ class ResnetBlockCondNorm2D(nn.Module):
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"ada_group"` ):
The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial".
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
kernel (`torch.Tensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
@@ -146,7 +146,7 @@ class ResnetBlockCondNorm2D(nn.Module):
bias=conv_shortcut_bias,
)
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
@@ -204,7 +204,7 @@ class ResnetBlock2D(nn.Module):
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
kernel (`torch.Tensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
@@ -232,7 +232,7 @@ class ResnetBlock2D(nn.Module):
non_linearity: str = "swish",
skip_time_act: bool = False,
time_embedding_norm: str = "default", # default, scale_shift,
kernel: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.Tensor] = None,
output_scale_factor: float = 1.0,
use_in_shortcut: Optional[bool] = None,
up: bool = False,
@@ -317,7 +317,7 @@ class ResnetBlock2D(nn.Module):
bias=conv_shortcut_bias,
)
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
@@ -605,7 +605,7 @@ class TemporalResnetBlock(nn.Module):
padding=0,
)
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
@@ -685,8 +685,8 @@ class SpatioTemporalResBlock(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
):
num_frames = image_only_indicator.shape[-1]
@@ -106,14 +106,13 @@ class DualTransformer2DModel(nn.Module):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states.
When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
attention_mask (`torch.FloatTensor`, *optional*):
attention_mask (`torch.Tensor`, *optional*):
Optional attention mask to be applied in Attention.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -26,11 +26,11 @@ class PriorTransformerOutput(BaseOutput):
The output of [`PriorTransformer`].
Args:
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
predicted_image_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
"""
predicted_image_embedding: torch.FloatTensor
predicted_image_embedding: torch.Tensor
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
@@ -246,8 +246,8 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
self,
hidden_states,
timestep: Union[torch.Tensor, float, int],
proj_embedding: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
proj_embedding: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
return_dict: bool = True,
):
@@ -255,13 +255,13 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
The [`PriorTransformer`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
The currently predicted image embeddings.
timestep (`torch.LongTensor`):
Current denoising step.
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
Projected embedding vector the denoising process is conditioned on.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
Hidden states of the text embeddings the denoising process is conditioned on.
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
Text mask for the text embeddings.
@@ -86,7 +86,7 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
self.post_dropout = nn.Dropout(p=dropout_rate)
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor:
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
return mask.unsqueeze(-3)
@@ -195,13 +195,13 @@ class DecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
conditioning_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
encoder_decoder_position_bias=None,
) -> Tuple[torch.FloatTensor]:
) -> Tuple[torch.Tensor]:
hidden_states = self.layer[0](
hidden_states,
conditioning_emb=conditioning_emb,
@@ -249,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
conditioning_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# pre_self_attention_layer_norm
normed_hidden_states = self.layer_norm(hidden_states)
@@ -292,10 +292,10 @@ class T5LayerCrossAttention(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
key_value_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention(
normed_hidden_states,
@@ -328,9 +328,7 @@ class T5LayerFFCond(nn.Module):
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)
def forward(
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
forwarded_states = self.layer_norm(hidden_states)
if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb)
@@ -361,7 +359,7 @@ class T5DenseGatedActDense(nn.Module):
self.dropout = nn.Dropout(dropout_rate)
self.act = NewGELUActivation()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
@@ -390,7 +388,7 @@ class T5LayerNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
@@ -431,7 +429,7 @@ class T5FiLMLayer(nn.Module):
super().__init__()
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor:
emb = self.scale_bias(conditioning_emb)
scale, shift = torch.chunk(emb, 2, -1)
x = x * (1 + scale) + shift
@@ -35,12 +35,12 @@ class Transformer2DModelOutput(BaseOutput):
The output of [`Transformer2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
sample: torch.Tensor
class Transformer2DModel(ModelMixin, ConfigMixin):
@@ -346,9 +346,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
@@ -31,11 +31,11 @@ class TransformerTemporalModelOutput(BaseOutput):
The output of [`TransformerTemporalModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
sample (`torch.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input.
"""
sample: torch.FloatTensor
sample: torch.Tensor
class TransformerTemporalModel(ModelMixin, ConfigMixin):
@@ -120,7 +120,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
def forward(
self,
hidden_states: torch.FloatTensor,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: torch.LongTensor = None,
@@ -132,7 +132,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
The [`TransformerTemporal`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
Input hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
@@ -283,7 +283,7 @@ class TransformerSpatioTemporalModel(nn.Module):
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
Input hidden_states.
num_frames (`int`):
The number of frames to be processed per batch. This is used to reshape the hidden states.
@@ -31,11 +31,11 @@ class UNet1DOutput(BaseOutput):
The output of [`UNet1DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
sample (`torch.Tensor` of shape `(batch_size, num_channels, sample_size)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
sample: torch.Tensor
class UNet1DModel(ModelMixin, ConfigMixin):
@@ -194,7 +194,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
def forward(
self,
sample: torch.FloatTensor,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet1DOutput, Tuple]:
@@ -202,9 +202,9 @@ class UNet1DModel(ModelMixin, ConfigMixin):
The [`UNet1DModel`] forward method.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
@@ -66,7 +66,7 @@ class DownResnetBlock1D(nn.Module):
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
output_states = ()
hidden_states = self.resnets[0](hidden_states, temb)
@@ -128,10 +128,10 @@ class UpResnetBlock1D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
res_hidden_states_tuple: Optional[Tuple[torch.Tensor, ...]] = None,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
@@ -161,7 +161,7 @@ class ValueFunctionMidBlock1D(nn.Module):
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
x = self.res1(x, temb)
x = self.down1(x)
x = self.res2(x, temb)
@@ -209,7 +209,7 @@ class MidResTemporalBlock1D(nn.Module):
if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample")
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
@@ -230,7 +230,7 @@ class OutConv1DBlock(nn.Module):
self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states)
@@ -251,7 +251,7 @@ class OutValueFunctionBlock(nn.Module):
]
)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block:
@@ -288,7 +288,7 @@ class Downsample1d(nn.Module):
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -305,7 +305,7 @@ class Upsample1d(nn.Module):
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -335,7 +335,7 @@ class SelfAttention1d(nn.Module):
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
batch, channel_dim, seq = hidden_states.shape
@@ -390,7 +390,7 @@ class ResConvBlock(nn.Module):
self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states)
@@ -435,7 +435,7 @@ class UNetMidBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
@@ -466,7 +466,7 @@ class AttnDownBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.down(hidden_states)
for resnet, attn in zip(self.resnets, self.attentions):
@@ -490,7 +490,7 @@ class DownBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.down(hidden_states)
for resnet in self.resnets:
@@ -512,7 +512,7 @@ class DownBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
@@ -542,10 +542,10 @@ class AttnUpBlock1D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -574,10 +574,10 @@ class UpBlock1D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -604,10 +604,10 @@ class UpBlock1DNoSkip(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -30,11 +30,11 @@ class UNet2DOutput(BaseOutput):
The output of [`UNet2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
sample: torch.Tensor
class UNet2DModel(ModelMixin, ConfigMixin):
@@ -242,7 +242,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
def forward(
self,
sample: torch.FloatTensor,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
@@ -251,10 +251,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
The [`UNet2DModel`] forward method.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
@@ -561,7 +561,7 @@ class AutoencoderTinyBlock(nn.Module):
The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
Returns:
`torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
`torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
`out_channels`.
"""
@@ -582,7 +582,7 @@ class AutoencoderTinyBlock(nn.Module):
)
self.fuse = nn.ReLU()
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.fuse(self.conv(x) + self.skip(x))
@@ -612,8 +612,8 @@ class UNetMidBlock2D(nn.Module):
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
in_channels, height, width)`.
`torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
height, width)`.
"""
@@ -731,7 +731,7 @@ class UNetMidBlock2D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
@@ -846,13 +846,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -986,13 +986,13 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -1118,11 +1118,11 @@ class AttnDownBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -1240,14 +1240,14 @@ class CrossAttnDownBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
encoder_attention_mask: Optional[torch.Tensor] = None,
additional_residuals: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
......@@ -1362,8 +1362,8 @@ class DownBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -1465,7 +1465,7 @@ class DownEncoderBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -1567,7 +1567,7 @@ class AttnDownEncoderBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -1666,12 +1666,12 @@ class AttnSkipDownBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
skip_sample: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
skip_sample: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -1757,12 +1757,12 @@ class SkipDownBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
skip_sample: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
skip_sample: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -1850,8 +1850,8 @@ class ResnetDownsampleBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -1986,13 +1986,13 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
......@@ -2097,8 +2097,8 @@ class KDownBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -2201,13 +2201,13 @@ class KCrossAttnDownBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
......@@ -2358,13 +2358,13 @@ class AttnUpBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -2481,15 +2481,15 @@ class CrossAttnUpBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
......@@ -2616,13 +2616,13 @@ class UpBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -2741,7 +2741,7 @@ class UpDecoderBlock2D(nn.Module):
self.resolution_idx = resolution_idx
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb)
......@@ -2839,7 +2839,7 @@ class AttnUpDecoderBlock2D(nn.Module):
self.resolution_idx = resolution_idx
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb)
hidden_states = attn(hidden_states, temb=temb)
......@@ -2947,13 +2947,13 @@ class AttnSkipUpBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
skip_sample=None,
*args,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -3059,13 +3059,13 @@ class SkipUpBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
skip_sample=None,
*args,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -3166,13 +3166,13 @@ class ResnetUpsampleBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -3310,15 +3310,15 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
......@@ -3428,13 +3428,13 @@ class KUpBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -3558,15 +3558,15 @@ class KCrossAttnUpBlock2D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
......@@ -3684,23 +3684,23 @@ class KAttentionBlock(nn.Module):
cross_attention_norm=cross_attention_norm,
)
def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
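Since `_to_3d`/`_to_4d` only reshape between spatial and sequence layouts, a small round-trip sketch may help; the shapes are illustrative (note the source names the width argument `weight`):

import torch

x = torch.randn(2, 64, 8, 8)                        # (batch, channels, height, width)
flat = x.permute(0, 2, 3, 1).reshape(2, 8 * 8, -1)  # _to_3d: (batch, height*width, channels)
back = flat.permute(0, 2, 1).reshape(2, -1, 8, 8)   # _to_4d: back to (batch, channels, height, width)
assert torch.equal(x, back)                         # the round trip is lossless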
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
# TODO: mark emb as non-optional (self.norm2 requires it).
# requires assessing impact of change to positional param interface.
emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
......
......@@ -60,11 +60,11 @@ class UNet2DConditionOutput(BaseOutput):
The output of [`UNet2DConditionModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.FloatTensor = None
sample: torch.Tensor = None
class UNet2DConditionModel(
......@@ -1042,7 +1042,7 @@ class UNet2DConditionModel(
def forward(
self,
sample: torch.FloatTensor,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
......@@ -1060,10 +1060,10 @@ class UNet2DConditionModel(
The [`UNet2DConditionModel`] forward method.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
......
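To tie the documented shapes together, a minimal sketch of one denoising call through the forward method above; the checkpoint id and tensor sizes are assumptions for a Stable Diffusion v1.5-style UNet:

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
sample = torch.randn(1, 4, 64, 64)    # (batch, channel, height, width)
text_emb = torch.randn(1, 77, 768)    # (batch, sequence_length, feature_dim)
# The return value is a UNet2DConditionOutput; the prediction lives in `.sample`.
noise_pred = unet(sample, timestep=10, encoder_hidden_states=text_emb).sample
assert noise_pred.shape == sample.shape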
......@@ -411,13 +411,13 @@ class UNetMidBlock3DCrossAttn(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.FloatTensor:
) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
for attn, temp_attn, resnet, temp_conv in zip(
......@@ -544,13 +544,13 @@ class CrossAttnDownBlock3D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
cross_attention_kwargs: Dict[str, Any] = None,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
# TODO(Patrick, William) - attention mask is not used
output_states = ()
......@@ -651,10 +651,10 @@ class DownBlock3D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
num_frames: int = 1,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
output_states = ()
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
......@@ -769,15 +769,15 @@ class CrossAttnUpBlock3D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
cross_attention_kwargs: Dict[str, Any] = None,
) -> torch.FloatTensor:
) -> torch.Tensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
......@@ -891,12 +891,12 @@ class UpBlock3D(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
) -> torch.Tensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
......@@ -1008,12 +1008,12 @@ class DownBlockMotion(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
num_frames: int = 1,
*args,
**kwargs,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -1174,14 +1174,14 @@ class CrossAttnDownBlockMotion(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
additional_residuals: Optional[torch.Tensor] = None,
):
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
......@@ -1357,16 +1357,16 @@ class CrossAttnUpBlockMotion(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
......@@ -1518,14 +1518,14 @@ class UpBlockMotion(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
upsample_size=None,
num_frames: int = 1,
*args,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
......@@ -1699,14 +1699,14 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
......@@ -1811,8 +1811,8 @@ class MidBlockTemporalDecoder(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
image_only_indicator: torch.FloatTensor,
hidden_states: torch.Tensor,
image_only_indicator: torch.Tensor,
):
hidden_states = self.resnets[0](
hidden_states,
......@@ -1862,9 +1862,9 @@ class UpBlockTemporalDecoder(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
image_only_indicator: torch.FloatTensor,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
image_only_indicator: torch.Tensor,
) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(
hidden_states,
......@@ -1935,11 +1935,11 @@ class UNetMidBlockSpatioTemporal(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
) -> torch.Tensor:
hidden_states = self.resnets[0](
hidden_states,
temb,
......@@ -2031,10 +2031,10 @@ class DownBlockSpatioTemporal(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
output_states = ()
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
......@@ -2141,11 +2141,11 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
output_states = ()
blocks = list(zip(self.resnets, self.attentions))
......@@ -2240,11 +2240,11 @@ class UpBlockSpatioTemporal(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
) -> torch.Tensor:
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
......@@ -2349,12 +2349,12 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
) -> torch.Tensor:
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
......