Unverified Commit be4afa0b authored by Mark Van Aken and committed by GitHub

#7535 Update FloatTensor type hints to Tensor (#7883)

* Find & replace all FloatTensors with Tensor

* apply formatting

* Update torch.FloatTensor to torch.Tensor in the remaining files

* formatting

* Fix the remaining places where FloatTensor is used, including in documentation

* formatting

* Update the new file from FloatTensor to Tensor
parent 04f4bd54
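For context on why the sweep matters: `torch.FloatTensor` names one concrete tensor type, float32 on CPU, so half-precision, bfloat16, and CUDA tensors never satisfy the old hints even though these signatures accept them, while `torch.Tensor` covers every dtype and device. A minimal sketch of the difference (illustrative only, not part of the diff):

```python
import torch

x_fp32 = torch.zeros(2, 3)                       # float32 on CPU
x_fp16 = torch.zeros(2, 3, dtype=torch.float16)  # half precision

# The legacy type only matches float32 CPU tensors ...
print(isinstance(x_fp32, torch.FloatTensor))  # True
print(isinstance(x_fp16, torch.FloatTensor))  # False: fp16 is a HalfTensor

# ... while torch.Tensor matches any dtype/device, which is what these
# models actually receive under fp16/bf16 or GPU inference.
print(isinstance(x_fp16, torch.Tensor))  # True
```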
@@ -86,10 +86,10 @@ class TemporalDecoder(nn.Module):
     def forward(
         self,
-        sample: torch.FloatTensor,
-        image_only_indicator: torch.FloatTensor,
+        sample: torch.Tensor,
+        image_only_indicator: torch.Tensor,
         num_frames: int = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         r"""The forward method of the `Decoder` class."""

         sample = self.conv_in(sample)
@@ -315,13 +315,13 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
+        self, x: torch.Tensor, return_dict: bool = True
     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
         """
         Encode a batch of images into latents.

         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
@@ -341,15 +341,15 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def decode(
         self,
-        z: torch.FloatTensor,
+        z: torch.Tensor,
         num_frames: int,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, torch.Tensor]:
         """
         Decode a batch of images.

         Args:
-            z (`torch.FloatTensor`): Input batch of latent vectors.
+            z (`torch.Tensor`): Input batch of latent vectors.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
@@ -370,15 +370,15 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
         num_frames: int = 1,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             sample_posterior (`bool`, *optional*, defaults to `False`):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -157,11 +157,11 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value

-    def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
         """raw latents -> [0, 1]"""
         return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

-    def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def unscale_latents(self, x: torch.Tensor) -> torch.Tensor:
         """[0, 1] -> raw latents"""
         return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
@@ -194,7 +194,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         """
         self.enable_tiling(False)

-    def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.

         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -202,10 +202,10 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         tiles overlap and are blended together to form a smooth output.

         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.

         Returns:
-            `torch.FloatTensor`: Encoded batch of images.
+            `torch.Tensor`: Encoded batch of images.
         """
         # scale of encoder output relative to input
         sf = self.spatial_scale_factor
@@ -242,7 +242,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
                 tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
         return out

-    def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.

         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -250,10 +250,10 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         tiles overlap and are blended together to form a smooth output.

         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.

         Returns:
-            `torch.FloatTensor`: Encoded batch of images.
+            `torch.Tensor`: Encoded batch of images.
         """
         # scale of decoder output relative to input
         sf = self.spatial_scale_factor
@@ -290,9 +290,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         return out

     @apply_forward_hook
-    def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
-    ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
         if self.use_slicing and x.shape[0] > 1:
             output = [
                 self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
@@ -308,8 +306,8 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def decode(
-        self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+        self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         if self.use_slicing and x.shape[0] > 1:
             output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
             output = torch.cat(output)
@@ -323,12 +321,12 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
         """
@@ -276,13 +276,13 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
+        self, x: torch.Tensor, return_dict: bool = True
     ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
         """
         Encode a batch of images into latents.

         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
                 tuple.
@@ -312,22 +312,22 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def decode(
         self,
-        z: torch.FloatTensor,
+        z: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
         num_inference_steps: int = 2,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         """
         Decodes the input latent vector `z` using the consistency decoder VAE model.

         Args:
-            z (torch.FloatTensor): The input latent vector.
+            z (torch.Tensor): The input latent vector.
             generator (Optional[torch.Generator]): The random number generator. Default is None.
             return_dict (bool): Whether to return the output as a dictionary. Default is True.
             num_inference_steps (int): The number of inference steps. Default is 2.

         Returns:
-            Union[DecoderOutput, Tuple[torch.FloatTensor]]: The decoded output.
+            Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
         """
         z = (z * self.config.scaling_factor - self.means) / self.stds
@@ -370,9 +370,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b

-    def tiled_encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
-    ) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
+    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
         r"""Encode a batch of images using a tiled encoder.

         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -382,7 +380,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         output, but they should be much less noticeable.

         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
                 plain tuple.
@@ -429,14 +427,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             sample_posterior (`bool`, *optional*, defaults to `False`):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -36,11 +36,11 @@ class DecoderOutput(BaseOutput):
     Output of decoding method.

     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
             The decoded output sample from the last layer of the model.
     """

-    sample: torch.FloatTensor
+    sample: torch.Tensor


 class Encoder(nn.Module):
@@ -136,7 +136,7 @@ class Encoder(nn.Module):
         self.gradient_checkpointing = False

-    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, sample: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `Encoder` class."""

         sample = self.conv_in(sample)
@@ -282,9 +282,9 @@ class Decoder(nn.Module):
     def forward(
         self,
-        sample: torch.FloatTensor,
-        latent_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        sample: torch.Tensor,
+        latent_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `Decoder` class."""

         sample = self.conv_in(sample)
@@ -367,7 +367,7 @@ class UpSample(nn.Module):
         self.out_channels = out_channels
         self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `UpSample` class."""
         x = torch.relu(x)
         x = self.deconv(x)
@@ -416,7 +416,7 @@ class MaskConditionEncoder(nn.Module):
         self.layers = nn.Sequential(*layers)

-    def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
         r"""The forward method of the `MaskConditionEncoder` class."""
         out = {}
         for l in range(len(self.layers)):
@@ -533,11 +533,11 @@ class MaskConditionDecoder(nn.Module):
     def forward(
         self,
-        z: torch.FloatTensor,
-        image: Optional[torch.FloatTensor] = None,
-        mask: Optional[torch.FloatTensor] = None,
-        latent_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        z: torch.Tensor,
+        image: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        latent_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `MaskConditionDecoder` class."""
         sample = z
         sample = self.conv_in(sample)
@@ -711,7 +711,7 @@ class VectorQuantizer(nn.Module):
         back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
         return back.reshape(ishape)

-    def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
+    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
         # reshape z -> (batch, height, width, channel) and flatten
         z = z.permute(0, 2, 3, 1).contiguous()
         z_flattened = z.view(-1, self.vq_embed_dim)
@@ -730,7 +730,7 @@ class VectorQuantizer(nn.Module):
             loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

         # preserve gradients
-        z_q: torch.FloatTensor = z + (z_q - z).detach()
+        z_q: torch.Tensor = z + (z_q - z).detach()

         # reshape back to match original input shape
         z_q = z_q.permute(0, 3, 1, 2).contiguous()
@@ -745,7 +745,7 @@ class VectorQuantizer(nn.Module):
         return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

-    def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
+    def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
         # shape specifying (batch, height, width, channel)
         if self.remap is not None:
             indices = indices.reshape(shape[0], -1)  # add batch axis
@@ -753,7 +753,7 @@ class VectorQuantizer(nn.Module):
             indices = indices.reshape(-1)  # flatten again

         # get quantized latent vectors
-        z_q: torch.FloatTensor = self.embedding(indices)
+        z_q: torch.Tensor = self.embedding(indices)

         if shape is not None:
             z_q = z_q.view(shape)
@@ -776,7 +776,7 @@ class DiagonalGaussianDistribution(object):
                 self.mean, device=self.parameters.device, dtype=self.parameters.dtype
             )

-    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
         # make sure sample is on the same device as the parameters and has same dtype
         sample = randn_tensor(
             self.mean.shape,
@@ -873,7 +873,7 @@ class EncoderTiny(nn.Module):
         self.layers = nn.Sequential(*layers)
         self.gradient_checkpointing = False

-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `EncoderTiny` class."""
         if self.training and self.gradient_checkpointing:
@@ -956,7 +956,7 @@ class DecoderTiny(nn.Module):
         self.layers = nn.Sequential(*layers)
         self.gradient_checkpointing = False

-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `DecoderTiny` class."""
         # Clamp.
         x = torch.tanh(x / 3) * 3
@@ -665,10 +665,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        controlnet_cond: torch.FloatTensor,
+        controlnet_cond: torch.Tensor,
         conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
@@ -677,18 +677,18 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guess_mode: bool = False,
         return_dict: bool = True,
-    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
+    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
         """
         The [`ControlNetModel`] forward method.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor.
             timestep (`Union[torch.Tensor, float, int]`):
                 The number of timesteps to denoise an input.
             encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states.
-            controlnet_cond (`torch.FloatTensor`):
+            controlnet_cond (`torch.Tensor`):
                 The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
@@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from torch import FloatTensor, nn
+from torch import Tensor, nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, is_torch_version, logging
@@ -54,12 +54,12 @@ class ControlNetXSOutput(BaseOutput):
     The output of [`UNetControlNetXSModel`].

     Args:
-        sample (`FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        sample (`Tensor` of shape `(batch_size, num_channels, height, width)`):
             The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base
             model output, but is already the final output.
     """

-    sample: FloatTensor = None
+    sample: Tensor = None


 class DownBlockControlNetXSAdapter(nn.Module):
@@ -1001,7 +1001,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
     def forward(
         self,
-        sample: FloatTensor,
+        sample: Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         controlnet_cond: Optional[torch.Tensor] = None,
@@ -1018,13 +1018,13 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         The [`ControlNetXSModel`] forward method.

         Args:
-            sample (`FloatTensor`):
+            sample (`Tensor`):
                 The noisy input tensor.
             timestep (`Union[torch.Tensor, float, int]`):
                 The number of timesteps to denoise an input.
             encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states.
-            controlnet_cond (`FloatTensor`):
+            controlnet_cond (`Tensor`):
                 The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
             conditioning_scale (`float`, defaults to `1.0`):
                 How much the control model affects the base model outputs.
@@ -1402,16 +1402,16 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
     def forward(
         self,
-        hidden_states_base: FloatTensor,
-        temb: FloatTensor,
-        encoder_hidden_states: Optional[FloatTensor] = None,
-        hidden_states_ctrl: Optional[FloatTensor] = None,
+        hidden_states_base: Tensor,
+        temb: Tensor,
+        encoder_hidden_states: Optional[Tensor] = None,
+        hidden_states_ctrl: Optional[Tensor] = None,
         conditioning_scale: Optional[float] = 1.0,
-        attention_mask: Optional[FloatTensor] = None,
+        attention_mask: Optional[Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_mask: Optional[Tensor] = None,
         apply_control: bool = True,
-    ) -> Tuple[FloatTensor, FloatTensor, Tuple[FloatTensor, ...], Tuple[FloatTensor, ...]]:
+    ) -> Tuple[Tensor, Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
                 logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -1626,16 +1626,16 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
     def forward(
         self,
-        hidden_states_base: FloatTensor,
-        temb: FloatTensor,
-        encoder_hidden_states: FloatTensor,
-        hidden_states_ctrl: Optional[FloatTensor] = None,
+        hidden_states_base: Tensor,
+        temb: Tensor,
+        encoder_hidden_states: Tensor,
+        hidden_states_ctrl: Optional[Tensor] = None,
         conditioning_scale: Optional[float] = 1.0,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        attention_mask: Optional[FloatTensor] = None,
-        encoder_attention_mask: Optional[FloatTensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        encoder_attention_mask: Optional[Tensor] = None,
         apply_control: bool = True,
-    ) -> Tuple[FloatTensor, FloatTensor]:
+    ) -> Tuple[Tensor, Tensor]:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
                 logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -1807,18 +1807,18 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
     def forward(
         self,
-        hidden_states: FloatTensor,
-        res_hidden_states_tuple_base: Tuple[FloatTensor, ...],
-        res_hidden_states_tuple_ctrl: Tuple[FloatTensor, ...],
-        temb: FloatTensor,
-        encoder_hidden_states: Optional[FloatTensor] = None,
+        hidden_states: Tensor,
+        res_hidden_states_tuple_base: Tuple[Tensor, ...],
+        res_hidden_states_tuple_ctrl: Tuple[Tensor, ...],
+        temb: Tensor,
+        encoder_hidden_states: Optional[Tensor] = None,
         conditioning_scale: Optional[float] = 1.0,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        attention_mask: Optional[FloatTensor] = None,
+        attention_mask: Optional[Tensor] = None,
         upsample_size: Optional[int] = None,
-        encoder_attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_mask: Optional[Tensor] = None,
         apply_control: bool = True,
-    ) -> FloatTensor:
+    ) -> Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
                 logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -129,7 +129,7 @@ class Downsample2D(nn.Module):
         else:
             self.conv = conv

-    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -180,24 +180,24 @@ class FirDownsample2D(nn.Module):
     def _downsample_2d(
         self,
-        hidden_states: torch.FloatTensor,
-        weight: Optional[torch.FloatTensor] = None,
-        kernel: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        kernel: Optional[torch.Tensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """Fused `Conv2d()` followed by `downsample_2d()`.

         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
         efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
         arbitrary order.

         Args:
-            hidden_states (`torch.FloatTensor`):
+            hidden_states (`torch.Tensor`):
                 Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight (`torch.FloatTensor`, *optional*):
+            weight (`torch.Tensor`, *optional*):
                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                 performed by `inChannels = x.shape[0] // numGroups`.
-            kernel (`torch.FloatTensor`, *optional*):
+            kernel (`torch.Tensor`, *optional*):
                 FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                 corresponds to average pooling.
             factor (`int`, *optional*, default to `2`):
@@ -206,7 +206,7 @@ class FirDownsample2D(nn.Module):
                 Scaling factor for signal magnitude.

         Returns:
-            output (`torch.FloatTensor`):
+            output (`torch.Tensor`):
                 Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
                 datatype as `x`.
         """
@@ -244,7 +244,7 @@ class FirDownsample2D(nn.Module):
         return output

-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.use_conv:
             downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
             hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -286,11 +286,11 @@ class KDownsample2D(nn.Module):
 def downsample_2d(
-    hidden_states: torch.FloatTensor,
-    kernel: Optional[torch.FloatTensor] = None,
+    hidden_states: torch.Tensor,
+    kernel: Optional[torch.Tensor] = None,
     factor: int = 2,
     gain: float = 1,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
     r"""Downsample2D a batch of 2D images with the given filter.

     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -298,9 +298,9 @@ def downsample_2d(
     shape is a multiple of the downsampling factor.

     Args:
-        hidden_states (`torch.FloatTensor`)
+        hidden_states (`torch.Tensor`)
             Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel (`torch.FloatTensor`, *optional*):
+        kernel (`torch.Tensor`, *optional*):
             FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
             corresponds to average pooling.
         factor (`int`, *optional*, default to `2`):
@@ -309,7 +309,7 @@ def downsample_2d(
             Scaling factor for signal magnitude.

     Returns:
-        output (`torch.FloatTensor`):
+        output (`torch.Tensor`):
             Tensor of the shape `[N, C, H // factor, W // factor]`
     """
@@ -424,7 +424,7 @@ class TextImageProjection(nn.Module):
         self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
         self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

-    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         batch_size = text_embeds.shape[0]

         # image
@@ -450,7 +450,7 @@ class ImageProjection(nn.Module):
         self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
         self.norm = nn.LayerNorm(cross_attention_dim)

-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         batch_size = image_embeds.shape[0]

         # image
@@ -468,7 +468,7 @@ class IPAdapterFullImageProjection(nn.Module):
         self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
         self.norm = nn.LayerNorm(cross_attention_dim)

-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         return self.norm(self.ff(image_embeds))
@@ -482,7 +482,7 @@ class IPAdapterFaceIDImageProjection(nn.Module):
         self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
         self.norm = nn.LayerNorm(cross_attention_dim)

-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         x = self.ff(image_embeds)
         x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
         return self.norm(x)
@@ -530,7 +530,7 @@ class TextImageTimeEmbedding(nn.Module):
         self.text_norm = nn.LayerNorm(time_embed_dim)
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

-    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         # text
         time_text_embeds = self.text_proj(text_embeds)
         time_text_embeds = self.text_norm(time_text_embeds)
@@ -547,7 +547,7 @@ class ImageTimeEmbedding(nn.Module):
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
         self.image_norm = nn.LayerNorm(time_embed_dim)

-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -577,7 +577,7 @@ class ImageHintTimeEmbedding(nn.Module):
             nn.Conv2d(256, 4, 3, padding=1),
         )

-    def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -1007,7 +1007,7 @@ class MultiIPAdapterImageProjection(nn.Module):
         super().__init__()
         self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)

-    def forward(self, image_embeds: List[torch.FloatTensor]):
+    def forward(self, image_embeds: List[torch.Tensor]):
         projected_image_embeds = []

         # currently, we accept `image_embeds` as
@@ -58,7 +58,7 @@ class ResnetBlockCondNorm2D(nn.Module):
         non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
         time_embedding_norm (`str`, *optional*, default to `"ada_group"` ):
             The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial".
-        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
+        kernel (`torch.Tensor`, optional, default to None): FIR filter, see
             [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
         use_in_shortcut (`bool`, *optional*, default to `True`):
@@ -146,7 +146,7 @@ class ResnetBlockCondNorm2D(nn.Module):
             bias=conv_shortcut_bias,
         )

-    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -204,7 +204,7 @@ class ResnetBlock2D(nn.Module):
         time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
             By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
             stronger conditioning with scale and shift.
-        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
+        kernel (`torch.Tensor`, optional, default to None): FIR filter, see
             [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
         use_in_shortcut (`bool`, *optional*, default to `True`):
@@ -232,7 +232,7 @@ class ResnetBlock2D(nn.Module):
         non_linearity: str = "swish",
         skip_time_act: bool = False,
         time_embedding_norm: str = "default",  # default, scale_shift,
-        kernel: Optional[torch.FloatTensor] = None,
+        kernel: Optional[torch.Tensor] = None,
         output_scale_factor: float = 1.0,
         use_in_shortcut: Optional[bool] = None,
         up: bool = False,
@@ -317,7 +317,7 @@ class ResnetBlock2D(nn.Module):
             bias=conv_shortcut_bias,
         )

-    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -605,7 +605,7 @@ class TemporalResnetBlock(nn.Module):
             padding=0,
         )

-    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
         hidden_states = input_tensor

         hidden_states = self.norm1(hidden_states)
@@ -685,8 +685,8 @@ class SpatioTemporalResBlock(nn.Module):
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
     ):
         num_frames = image_only_indicator.shape[-1]
@@ -106,14 +106,13 @@ class DualTransformer2DModel(nn.Module):
         """
         Args:
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
-                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
-                hidden_states.
+                When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states.
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.long`, *optional*):
                 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
-            attention_mask (`torch.FloatTensor`, *optional*):
+            attention_mask (`torch.Tensor`, *optional*):
                 Optional attention mask to be applied in Attention.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -26,11 +26,11 @@ class PriorTransformerOutput(BaseOutput):
     The output of [`PriorTransformer`].

     Args:
-        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+        predicted_image_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
             The predicted CLIP image embedding conditioned on the CLIP text embedding input.
     """

-    predicted_image_embedding: torch.FloatTensor
+    predicted_image_embedding: torch.Tensor


 class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
@@ -246,8 +246,8 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
         self,
         hidden_states,
         timestep: Union[torch.Tensor, float, int],
-        proj_embedding: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        proj_embedding: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,
         return_dict: bool = True,
     ):
@@ -255,13 +255,13 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
         The [`PriorTransformer`] forward method.

         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                 The currently predicted image embeddings.
             timestep (`torch.LongTensor`):
                 Current denoising step.
-            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                 Projected embedding vector the denoising process is conditioned on.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                 Hidden states of the text embeddings the denoising process is conditioned on.
             attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                 Text mask for the text embeddings.
...@@ -86,7 +86,7 @@ class T5FilmDecoder(ModelMixin, ConfigMixin): ...@@ -86,7 +86,7 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
self.post_dropout = nn.Dropout(p=dropout_rate) self.post_dropout = nn.Dropout(p=dropout_rate)
self.spec_out = nn.Linear(d_model, input_dims, bias=False) self.spec_out = nn.Linear(d_model, input_dims, bias=False)
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor: def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor:
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
return mask.unsqueeze(-3) return mask.unsqueeze(-3)
...@@ -195,13 +195,13 @@ class DecoderLayer(nn.Module): ...@@ -195,13 +195,13 @@ class DecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
conditioning_emb: Optional[torch.FloatTensor] = None, conditioning_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
encoder_decoder_position_bias=None, encoder_decoder_position_bias=None,
) -> Tuple[torch.FloatTensor]: ) -> Tuple[torch.Tensor]:
hidden_states = self.layer[0]( hidden_states = self.layer[0](
hidden_states, hidden_states,
conditioning_emb=conditioning_emb, conditioning_emb=conditioning_emb,
...@@ -249,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module): ...@@ -249,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
conditioning_emb: Optional[torch.FloatTensor] = None, conditioning_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
# pre_self_attention_layer_norm # pre_self_attention_layer_norm
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
...@@ -292,10 +292,10 @@ class T5LayerCrossAttention(nn.Module): ...@@ -292,10 +292,10 @@ class T5LayerCrossAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
key_value_states: Optional[torch.FloatTensor] = None, key_value_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention( attention_output = self.attention(
normed_hidden_states, normed_hidden_states,
...@@ -328,9 +328,7 @@ class T5LayerFFCond(nn.Module): ...@@ -328,9 +328,7 @@ class T5LayerFFCond(nn.Module):
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate) self.dropout = nn.Dropout(dropout_rate)
def forward( def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.layer_norm(hidden_states)
if conditioning_emb is not None: if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb) forwarded_states = self.film(forwarded_states, conditioning_emb)
...@@ -361,7 +359,7 @@ class T5DenseGatedActDense(nn.Module): ...@@ -361,7 +359,7 @@ class T5DenseGatedActDense(nn.Module):
self.dropout = nn.Dropout(dropout_rate) self.dropout = nn.Dropout(dropout_rate)
self.act = NewGELUActivation() self.act = NewGELUActivation()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_gelu = self.act(self.wi_0(hidden_states)) hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states) hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear hidden_states = hidden_gelu * hidden_linear
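This is the gated-GELU feed-forward from T5 v1.1: one projection is passed through GELU and used to gate a second, linear projection. A standalone sketch with made-up sizes, where nn.GELU(approximate="tanh") stands in for NewGELUActivation (both compute the tanh approximation):

import torch
import torch.nn as nn

d_model, d_ff = 8, 32
wi_0 = nn.Linear(d_model, d_ff, bias=False)   # gate branch
wi_1 = nn.Linear(d_model, d_ff, bias=False)   # value branch
act = nn.GELU(approximate="tanh")

x = torch.randn(2, 4, d_model)
hidden_states = act(wi_0(x)) * wi_1(x)        # element-wise gating
print(hidden_states.shape)                    # torch.Size([2, 4, 32])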
...@@ -390,7 +388,7 @@ class T5LayerNorm(nn.Module): ...@@ -390,7 +388,7 @@ class T5LayerNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size)) self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps self.variance_epsilon = eps
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
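(The diff context cuts the comment off; in the source it continues that accumulation for half-precision inputs is done in fp32.) A distilled sketch of the RMS-style normalization the comment describes, not a verbatim copy of the class:

import torch

def rms_layer_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # variance without mean subtraction (RMSNorm), accumulated in fp32 for stability
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)       # scale only -- no shift, no bias
    return weight * x.to(weight.dtype)

hidden = torch.randn(2, 4, 8, dtype=torch.float16)
print(rms_layer_norm(hidden, torch.ones(8, dtype=torch.float16)).dtype)  # torch.float16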
...@@ -431,7 +429,7 @@ class T5FiLMLayer(nn.Module): ...@@ -431,7 +429,7 @@ class T5FiLMLayer(nn.Module):
super().__init__() super().__init__()
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor: def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor:
emb = self.scale_bias(conditioning_emb) emb = self.scale_bias(conditioning_emb)
scale, shift = torch.chunk(emb, 2, -1) scale, shift = torch.chunk(emb, 2, -1)
x = x * (1 + scale) + shift x = x * (1 + scale) + shift
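FiLM applies a feature-wise affine transform whose scale and shift are predicted from the conditioning embedding. A toy run with hypothetical sizes:

import torch
import torch.nn as nn

features, cond_dim = 8, 16
scale_bias = nn.Linear(cond_dim, features * 2, bias=False)

x = torch.randn(2, 4, features)
conditioning_emb = torch.randn(2, 1, cond_dim)        # broadcasts across the sequence axis

scale, shift = torch.chunk(scale_bias(conditioning_emb), 2, dim=-1)
x = x * (1 + scale) + shift                           # identity when the conditioning is zero
print(x.shape)                                        # torch.Size([2, 4, 8])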
...@@ -35,12 +35,12 @@ class Transformer2DModelOutput(BaseOutput): ...@@ -35,12 +35,12 @@ class Transformer2DModelOutput(BaseOutput):
The output of [`Transformer2DModel`]. The output of [`Transformer2DModel`].
Args: Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels. distributions for the unnoised latent pixels.
""" """
sample: torch.FloatTensor sample: torch.Tensor
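For context on the rename running through this whole commit: `torch.FloatTensor` names only the fp32 CPU tensor type, so the old hints were wrong whenever a model ran in half precision or on GPU, while `torch.Tensor` covers every dtype and device. A quick check:

import torch

sample = torch.randn(1, 4, 8, 8)                      # fp32 on CPU
print(isinstance(sample, torch.FloatTensor))          # True
print(isinstance(sample.half(), torch.FloatTensor))   # False -- fp16 is not a FloatTensor
print(isinstance(sample.half(), torch.Tensor))        # True  -- torch.Tensor matches them all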
class Transformer2DModel(ModelMixin, ConfigMixin): class Transformer2DModel(ModelMixin, ConfigMixin):
...@@ -346,9 +346,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin): ...@@ -346,9 +346,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
The [`Transformer2DModel`] forward method. The [`Transformer2DModel`] forward method.
Args: Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`. Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention. self-attention.
timestep ( `torch.LongTensor`, *optional*): timestep ( `torch.LongTensor`, *optional*):
...@@ -31,11 +31,11 @@ class TransformerTemporalModelOutput(BaseOutput): ...@@ -31,11 +31,11 @@ class TransformerTemporalModelOutput(BaseOutput):
The output of [`TransformerTemporalModel`]. The output of [`TransformerTemporalModel`].
Args: Args:
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): sample (`torch.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input. The hidden states output conditioned on `encoder_hidden_states` input.
""" """
sample: torch.FloatTensor sample: torch.Tensor
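The `(batch_size x num_frames, ...)` shape signals the convention used throughout the temporal models: frames are folded into the batch axis before the transformer runs and unfolded afterwards. A shape-only sketch:

import torch

batch_size, num_frames, channels, height, width = 2, 8, 4, 16, 16
video = torch.randn(batch_size, num_frames, channels, height, width)

sample = video.reshape(batch_size * num_frames, channels, height, width)     # fold frames into batch
print(sample.shape)                                   # torch.Size([16, 4, 16, 16])

video_out = sample.reshape(batch_size, num_frames, channels, height, width)  # unfold afterwards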
class TransformerTemporalModel(ModelMixin, ConfigMixin): class TransformerTemporalModel(ModelMixin, ConfigMixin):
...@@ -120,7 +120,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): ...@@ -120,7 +120,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep: Optional[torch.LongTensor] = None, timestep: Optional[torch.LongTensor] = None,
class_labels: torch.LongTensor = None, class_labels: torch.LongTensor = None,
...@@ -132,7 +132,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): ...@@ -132,7 +132,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
The [`TransformerTemporal`] forward method. The [`TransformerTemporal`] forward method.
Args: Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
Input hidden_states. Input hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
...@@ -283,7 +283,7 @@ class TransformerSpatioTemporalModel(nn.Module): ...@@ -283,7 +283,7 @@ class TransformerSpatioTemporalModel(nn.Module):
): ):
""" """
Args: Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
Input hidden_states. Input hidden_states.
num_frames (`int`): num_frames (`int`):
The number of frames to be processed per batch. This is used to reshape the hidden states. The number of frames to be processed per batch. This is used to reshape the hidden states.
...@@ -31,11 +31,11 @@ class UNet1DOutput(BaseOutput): ...@@ -31,11 +31,11 @@ class UNet1DOutput(BaseOutput):
The output of [`UNet1DModel`]. The output of [`UNet1DModel`].
Args: Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): sample (`torch.Tensor` of shape `(batch_size, num_channels, sample_size)`):
The hidden states output from the last layer of the model. The hidden states output from the last layer of the model.
""" """
sample: torch.FloatTensor sample: torch.Tensor
class UNet1DModel(ModelMixin, ConfigMixin): class UNet1DModel(ModelMixin, ConfigMixin):
...@@ -194,7 +194,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): ...@@ -194,7 +194,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
def forward( def forward(
self, self,
sample: torch.FloatTensor, sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int], timestep: Union[torch.Tensor, float, int],
return_dict: bool = True, return_dict: bool = True,
) -> Union[UNet1DOutput, Tuple]: ) -> Union[UNet1DOutput, Tuple]:
...@@ -202,9 +202,9 @@ class UNet1DModel(ModelMixin, ConfigMixin): ...@@ -202,9 +202,9 @@ class UNet1DModel(ModelMixin, ConfigMixin):
The [`UNet1DModel`] forward method. The [`UNet1DModel`] forward method.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
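A minimal sketch of the call the docstring describes, assuming `unet` is an already-constructed UNet1DModel (the channel and length values are illustrative):

import torch

batch_size, num_channels, sample_size = 2, 14, 128
sample = torch.randn(batch_size, num_channels, sample_size)   # noisy 1D input
timestep = 50                                                 # an int is accepted; so is a tensor

# out = unet(sample, timestep).sample                         # UNet1DOutput when return_dict=True
# (out,) = unet(sample, timestep, return_dict=False)          # plain tuple otherwise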
...@@ -66,7 +66,7 @@ class DownResnetBlock1D(nn.Module): ...@@ -66,7 +66,7 @@ class DownResnetBlock1D(nn.Module):
if add_downsample: if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
output_states = () output_states = ()
hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.resnets[0](hidden_states, temb)
...@@ -128,10 +128,10 @@ class UpResnetBlock1D(nn.Module): ...@@ -128,10 +128,10 @@ class UpResnetBlock1D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None, res_hidden_states_tuple: Optional[Tuple[torch.Tensor, ...]] = None,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
if res_hidden_states_tuple is not None: if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
...@@ -161,7 +161,7 @@ class ValueFunctionMidBlock1D(nn.Module): ...@@ -161,7 +161,7 @@ class ValueFunctionMidBlock1D(nn.Module):
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True) self.down2 = Downsample1D(out_channels // 4, use_conv=True)
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
x = self.res1(x, temb) x = self.res1(x, temb)
x = self.down1(x) x = self.down1(x)
x = self.res2(x, temb) x = self.res2(x, temb)
...@@ -209,7 +209,7 @@ class MidResTemporalBlock1D(nn.Module): ...@@ -209,7 +209,7 @@ class MidResTemporalBlock1D(nn.Module):
if self.upsample and self.downsample: if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample") raise ValueError("Block cannot downsample and upsample")
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]: for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb) hidden_states = resnet(hidden_states, temb)
...@@ -230,7 +230,7 @@ class OutConv1DBlock(nn.Module): ...@@ -230,7 +230,7 @@ class OutConv1DBlock(nn.Module):
self.final_conv1d_act = get_activation(act_fn) self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.final_conv1d_1(hidden_states) hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states) hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states) hidden_states = self.final_conv1d_gn(hidden_states)
...@@ -251,7 +251,7 @@ class OutValueFunctionBlock(nn.Module): ...@@ -251,7 +251,7 @@ class OutValueFunctionBlock(nn.Module):
] ]
) )
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(hidden_states.shape[0], -1) hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1) hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block: for layer in self.final_block:
...@@ -288,7 +288,7 @@ class Downsample1d(nn.Module): ...@@ -288,7 +288,7 @@ class Downsample1d(nn.Module):
self.pad = kernel_1d.shape[0] // 2 - 1 self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d) self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
...@@ -305,7 +305,7 @@ class Upsample1d(nn.Module): ...@@ -305,7 +305,7 @@ class Upsample1d(nn.Module):
self.pad = kernel_1d.shape[0] // 2 - 1 self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d) self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
...@@ -335,7 +335,7 @@ class SelfAttention1d(nn.Module): ...@@ -335,7 +335,7 @@ class SelfAttention1d(nn.Module):
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection return new_projection
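The view-plus-permute above splits the channel projection into attention heads and moves heads in front of the sequence axis. A shape check with hypothetical sizes:

import torch

batch, seq, num_heads, head_dim = 2, 16, 4, 8
projection = torch.randn(batch, seq, num_heads * head_dim)

new_projection_shape = projection.size()[:-1] + (num_heads, head_dim)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
print(new_projection.shape)   # torch.Size([2, 4, 16, 8]) -> (batch, heads, seq, head_dim)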
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states residual = hidden_states
batch, channel_dim, seq = hidden_states.shape batch, channel_dim, seq = hidden_states.shape
...@@ -390,7 +390,7 @@ class ResConvBlock(nn.Module): ...@@ -390,7 +390,7 @@ class ResConvBlock(nn.Module):
self.group_norm_2 = nn.GroupNorm(1, out_channels) self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU() self.gelu_2 = nn.GELU()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states) hidden_states = self.conv_1(hidden_states)
...@@ -435,7 +435,7 @@ class UNetMidBlock1D(nn.Module): ...@@ -435,7 +435,7 @@ class UNetMidBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.down(hidden_states) hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets): for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states) hidden_states = resnet(hidden_states)
...@@ -466,7 +466,7 @@ class AttnDownBlock1D(nn.Module): ...@@ -466,7 +466,7 @@ class AttnDownBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.down(hidden_states) hidden_states = self.down(hidden_states)
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
...@@ -490,7 +490,7 @@ class DownBlock1D(nn.Module): ...@@ -490,7 +490,7 @@ class DownBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.down(hidden_states) hidden_states = self.down(hidden_states)
for resnet in self.resnets: for resnet in self.resnets:
...@@ -512,7 +512,7 @@ class DownBlock1DNoSkip(nn.Module): ...@@ -512,7 +512,7 @@ class DownBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = torch.cat([hidden_states, temb], dim=1) hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets: for resnet in self.resnets:
hidden_states = resnet(hidden_states) hidden_states = resnet(hidden_states)
...@@ -542,10 +542,10 @@ class AttnUpBlock1D(nn.Module): ...@@ -542,10 +542,10 @@ class AttnUpBlock1D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
...@@ -574,10 +574,10 @@ class UpBlock1D(nn.Module): ...@@ -574,10 +574,10 @@ class UpBlock1D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
...@@ -604,10 +604,10 @@ class UpBlock1DNoSkip(nn.Module): ...@@ -604,10 +604,10 @@ class UpBlock1DNoSkip(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
...@@ -30,11 +30,11 @@ class UNet2DOutput(BaseOutput): ...@@ -30,11 +30,11 @@ class UNet2DOutput(BaseOutput):
The output of [`UNet2DModel`]. The output of [`UNet2DModel`].
Args: Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output from the last layer of the model. The hidden states output from the last layer of the model.
""" """
sample: torch.FloatTensor sample: torch.Tensor
class UNet2DModel(ModelMixin, ConfigMixin): class UNet2DModel(ModelMixin, ConfigMixin):
...@@ -242,7 +242,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -242,7 +242,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
def forward( def forward(
self, self,
sample: torch.FloatTensor, sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int], timestep: Union[torch.Tensor, float, int],
class_labels: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True, return_dict: bool = True,
...@@ -251,10 +251,10 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -251,10 +251,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
The [`UNet2DModel`] forward method. The [`UNet2DModel`] forward method.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`. The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
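As above for the 1D model, but with class conditioning: labels are embedded and summed with the timestep embedding. A sketch assuming `unet` is a UNet2DModel built with class embeddings enabled (e.g. `num_class_embeds` set):

import torch

sample = torch.randn(2, 3, 64, 64)        # (batch, channel, height, width)
timestep = torch.tensor([10, 10])
class_labels = torch.tensor([1, 7])       # one label per batch element

# noise_pred = unet(sample, timestep, class_labels=class_labels).sample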
...@@ -561,7 +561,7 @@ class AutoencoderTinyBlock(nn.Module): ...@@ -561,7 +561,7 @@ class AutoencoderTinyBlock(nn.Module):
` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`. ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
Returns: Returns:
`torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to `torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
`out_channels`. `out_channels`.
""" """
...@@ -582,7 +582,7 @@ class AutoencoderTinyBlock(nn.Module): ...@@ -582,7 +582,7 @@ class AutoencoderTinyBlock(nn.Module):
) )
self.fuse = nn.ReLU() self.fuse = nn.ReLU()
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.fuse(self.conv(x) + self.skip(x)) return self.fuse(self.conv(x) + self.skip(x))
...@@ -612,8 +612,8 @@ class UNetMidBlock2D(nn.Module): ...@@ -612,8 +612,8 @@ class UNetMidBlock2D(nn.Module):
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
Returns: Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
in_channels, height, width)`. height, width)`.
""" """
...@@ -731,7 +731,7 @@ class UNetMidBlock2D(nn.Module): ...@@ -731,7 +731,7 @@ class UNetMidBlock2D(nn.Module):
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]): for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None: if attn is not None:
...@@ -846,13 +846,13 @@ class UNetMidBlock2DCrossAttn(nn.Module): ...@@ -846,13 +846,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
if cross_attention_kwargs is not None: if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
...@@ -986,13 +986,13 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -986,13 +986,13 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
...@@ -1118,11 +1118,11 @@ class AttnDownBlock2D(nn.Module): ...@@ -1118,11 +1118,11 @@ class AttnDownBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
...@@ -1240,14 +1240,14 @@ class CrossAttnDownBlock2D(nn.Module): ...@@ -1240,14 +1240,14 @@ class CrossAttnDownBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
additional_residuals: Optional[torch.FloatTensor] = None, additional_residuals: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
if cross_attention_kwargs is not None: if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
...@@ -1362,8 +1362,8 @@ class DownBlock2D(nn.Module): ...@@ -1362,8 +1362,8 @@ class DownBlock2D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward( def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
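This guard recurs in every block below: stray positional arguments or a `scale` kwarg trigger a deprecation warning and are otherwise ignored. A simplified stand-in for the mechanism (diffusers' real `deprecate` helper does more, e.g. comparing against the installed version and popping deprecated kwargs):

import warnings

def deprecate(name: str, removal_version: str, message: str) -> None:
    # minimal sketch: warn that `name` goes away in `removal_version`
    warnings.warn(
        f"`{name}` is deprecated and will be removed in {removal_version}. {message}",
        FutureWarning,
    )

def forward(hidden_states, *args, **kwargs):
    if len(args) > 0 or kwargs.get("scale", None) is not None:
        deprecate("scale", "1.0.0", "Pass `scale` via `cross_attention_kwargs` instead.")
    return hidden_states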
...@@ -1465,7 +1465,7 @@ class DownEncoderBlock2D(nn.Module): ...@@ -1465,7 +1465,7 @@ class DownEncoderBlock2D(nn.Module):
else: else:
self.downsamplers = None self.downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -1567,7 +1567,7 @@ class AttnDownEncoderBlock2D(nn.Module): ...@@ -1567,7 +1567,7 @@ class AttnDownEncoderBlock2D(nn.Module):
else: else:
self.downsamplers = None self.downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -1666,12 +1666,12 @@ class AttnSkipDownBlock2D(nn.Module): ...@@ -1666,12 +1666,12 @@ class AttnSkipDownBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
skip_sample: Optional[torch.FloatTensor] = None, skip_sample: Optional[torch.Tensor] = None,
*args, *args,
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -1757,12 +1757,12 @@ class SkipDownBlock2D(nn.Module): ...@@ -1757,12 +1757,12 @@ class SkipDownBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
skip_sample: Optional[torch.FloatTensor] = None, skip_sample: Optional[torch.Tensor] = None,
*args, *args,
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -1850,8 +1850,8 @@ class ResnetDownsampleBlock2D(nn.Module): ...@@ -1850,8 +1850,8 @@ class ResnetDownsampleBlock2D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward( def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -1986,13 +1986,13 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1986,13 +1986,13 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
...@@ -2097,8 +2097,8 @@ class KDownBlock2D(nn.Module): ...@@ -2097,8 +2097,8 @@ class KDownBlock2D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward( def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -2201,13 +2201,13 @@ class KCrossAttnDownBlock2D(nn.Module): ...@@ -2201,13 +2201,13 @@ class KCrossAttnDownBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
...@@ -2358,13 +2358,13 @@ class AttnUpBlock2D(nn.Module): ...@@ -2358,13 +2358,13 @@ class AttnUpBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
*args, *args,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -2481,15 +2481,15 @@ class CrossAttnUpBlock2D(nn.Module): ...@@ -2481,15 +2481,15 @@ class CrossAttnUpBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
if cross_attention_kwargs is not None: if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
...@@ -2616,13 +2616,13 @@ class UpBlock2D(nn.Module): ...@@ -2616,13 +2616,13 @@ class UpBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
*args, *args,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -2741,7 +2741,7 @@ class UpDecoderBlock2D(nn.Module): ...@@ -2741,7 +2741,7 @@ class UpDecoderBlock2D(nn.Module):
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
for resnet in self.resnets: for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb) hidden_states = resnet(hidden_states, temb=temb)
...@@ -2839,7 +2839,7 @@ class AttnUpDecoderBlock2D(nn.Module): ...@@ -2839,7 +2839,7 @@ class AttnUpDecoderBlock2D(nn.Module):
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb) hidden_states = resnet(hidden_states, temb=temb)
hidden_states = attn(hidden_states, temb=temb) hidden_states = attn(hidden_states, temb=temb)
...@@ -2947,13 +2947,13 @@ class AttnSkipUpBlock2D(nn.Module): ...@@ -2947,13 +2947,13 @@ class AttnSkipUpBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
skip_sample=None, skip_sample=None,
*args, *args,
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -3059,13 +3059,13 @@ class SkipUpBlock2D(nn.Module): ...@@ -3059,13 +3059,13 @@ class SkipUpBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
skip_sample=None, skip_sample=None,
*args, *args,
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -3166,13 +3166,13 @@ class ResnetUpsampleBlock2D(nn.Module): ...@@ -3166,13 +3166,13 @@ class ResnetUpsampleBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
*args, *args,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -3310,15 +3310,15 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -3310,15 +3310,15 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
...@@ -3428,13 +3428,13 @@ class KUpBlock2D(nn.Module): ...@@ -3428,13 +3428,13 @@ class KUpBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
*args, *args,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -3558,15 +3558,15 @@ class KCrossAttnUpBlock2D(nn.Module): ...@@ -3558,15 +3558,15 @@ class KCrossAttnUpBlock2D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
res_hidden_states_tuple = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None: if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
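The `torch.cat` on the last residual is the standard U-Net skip connection: the up block stacks the down block's features onto its own along the channel dimension. A toy shape check with arbitrary example sizes:

    import torch

    hidden_states = torch.randn(1, 320, 32, 32)      # current up-block features
    res_hidden_states = torch.randn(1, 320, 32, 32)  # matching down-block residual

    merged = torch.cat([hidden_states, res_hidden_states], dim=1)
    print(merged.shape)  # torch.Size([1, 640, 32, 32]): channels add, spatial dims unchanged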
...@@ -3684,23 +3684,23 @@ class KAttentionBlock(nn.Module): ...@@ -3684,23 +3684,23 @@ class KAttentionBlock(nn.Module):
cross_attention_norm=cross_attention_norm, cross_attention_norm=cross_attention_norm,
) )
def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
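These two helpers shuttle activations between the 4D conv layout `(batch, channels, height, width)` and the 3D sequence layout `(batch, height*width, channels)` that attention expects (note the source spells width as `weight`). A round-trip check with example sizes:

    import torch

    b, c, h, w = 2, 64, 16, 16
    x = torch.randn(b, c, h, w)

    # _to_3d: (B, C, H, W) -> (B, H*W, C)
    x3d = x.permute(0, 2, 3, 1).reshape(b, h * w, c)

    # _to_4d: (B, H*W, C) -> (B, C, H, W)
    x4d = x3d.permute(0, 2, 1).reshape(b, c, h, w)

    print(torch.equal(x, x4d))  # True: the two transforms are exact inverses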
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
# TODO: mark emb as non-optional (self.norm2 requires it). # TODO: mark emb as non-optional (self.norm2 requires it).
# requires assessing impact of change to positional param interface. # requires assessing impact of change to positional param interface.
emb: Optional[torch.FloatTensor] = None, emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
......
...@@ -60,11 +60,11 @@ class UNet2DConditionOutput(BaseOutput): ...@@ -60,11 +60,11 @@ class UNet2DConditionOutput(BaseOutput):
The output of [`UNet2DConditionModel`]. The output of [`UNet2DConditionModel`].
Args: Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
""" """
sample: torch.FloatTensor = None sample: torch.Tensor = None
class UNet2DConditionModel( class UNet2DConditionModel(
...@@ -1042,7 +1042,7 @@ class UNet2DConditionModel( ...@@ -1042,7 +1042,7 @@ class UNet2DConditionModel(
def forward( def forward(
self, self,
sample: torch.FloatTensor, sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int], timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
...@@ -1060,10 +1060,10 @@ class UNet2DConditionModel( ...@@ -1060,10 +1060,10 @@ class UNet2DConditionModel(
The [`UNet2DConditionModel`] forward method. The [`UNet2DConditionModel`] forward method.
Args: Args:
sample (`torch.FloatTensor`): sample (`torch.Tensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`. The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`): encoder_hidden_states (`torch.Tensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
class_labels (`torch.Tensor`, *optional*, defaults to `None`): class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
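Putting the documented shapes together, a smoke test of the forward interface might look like the following. The config is a deliberately tiny stand-in so the example runs quickly; real checkpoints use larger values:

    import torch
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        cross_attention_dim=32,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        layers_per_block=1,
    )

    sample = torch.randn(1, 4, 32, 32)              # (batch, channel, height, width)
    timestep = torch.tensor(10)                     # Tensor, float, or int all work
    encoder_hidden_states = torch.randn(1, 77, 32)  # (batch, sequence_length, feature_dim)

    out = unet(sample, timestep, encoder_hidden_states)
    print(out.sample.shape)  # torch.Size([1, 4, 32, 32])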
......
...@@ -411,13 +411,13 @@ class UNetMidBlock3DCrossAttn(nn.Module): ...@@ -411,13 +411,13 @@ class UNetMidBlock3DCrossAttn(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1, num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
for attn, temp_attn, resnet, temp_conv in zip( for attn, temp_attn, resnet, temp_conv in zip(
...@@ -544,13 +544,13 @@ class CrossAttnDownBlock3D(nn.Module): ...@@ -544,13 +544,13 @@ class CrossAttnDownBlock3D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1, num_frames: int = 1,
cross_attention_kwargs: Dict[str, Any] = None, cross_attention_kwargs: Dict[str, Any] = None,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
# TODO(Patrick, William) - attention mask is not used # TODO(Patrick, William) - attention mask is not used
output_states = () output_states = ()
...@@ -651,10 +651,10 @@ class DownBlock3D(nn.Module): ...@@ -651,10 +651,10 @@ class DownBlock3D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
num_frames: int = 1, num_frames: int = 1,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
output_states = () output_states = ()
for resnet, temp_conv in zip(self.resnets, self.temp_convs): for resnet, temp_conv in zip(self.resnets, self.temp_convs):
...@@ -769,15 +769,15 @@ class CrossAttnUpBlock3D(nn.Module): ...@@ -769,15 +769,15 @@ class CrossAttnUpBlock3D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1, num_frames: int = 1,
cross_attention_kwargs: Dict[str, Any] = None, cross_attention_kwargs: Dict[str, Any] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
is_freeu_enabled = ( is_freeu_enabled = (
getattr(self, "s1", None) getattr(self, "s1", None)
and getattr(self, "s2", None) and getattr(self, "s2", None)
...@@ -891,12 +891,12 @@ class UpBlock3D(nn.Module): ...@@ -891,12 +891,12 @@ class UpBlock3D(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
num_frames: int = 1, num_frames: int = 1,
) -> torch.FloatTensor: ) -> torch.Tensor:
is_freeu_enabled = ( is_freeu_enabled = (
getattr(self, "s1", None) getattr(self, "s1", None)
and getattr(self, "s2", None) and getattr(self, "s2", None)
...@@ -1008,12 +1008,12 @@ class DownBlockMotion(nn.Module): ...@@ -1008,12 +1008,12 @@ class DownBlockMotion(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
num_frames: int = 1, num_frames: int = 1,
*args, *args,
**kwargs, **kwargs,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -1174,14 +1174,14 @@ class CrossAttnDownBlockMotion(nn.Module): ...@@ -1174,14 +1174,14 @@ class CrossAttnDownBlockMotion(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1, num_frames: int = 1,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
additional_residuals: Optional[torch.FloatTensor] = None, additional_residuals: Optional[torch.Tensor] = None,
): ):
if cross_attention_kwargs is not None: if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
...@@ -1357,16 +1357,16 @@ class CrossAttnUpBlockMotion(nn.Module): ...@@ -1357,16 +1357,16 @@ class CrossAttnUpBlockMotion(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1, num_frames: int = 1,
) -> torch.FloatTensor: ) -> torch.Tensor:
if cross_attention_kwargs is not None: if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
...@@ -1518,14 +1518,14 @@ class UpBlockMotion(nn.Module): ...@@ -1518,14 +1518,14 @@ class UpBlockMotion(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
upsample_size=None, upsample_size=None,
num_frames: int = 1, num_frames: int = 1,
*args, *args,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None: if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message) deprecate("scale", "1.0.0", deprecation_message)
...@@ -1699,14 +1699,14 @@ class UNetMidBlockCrossAttnMotion(nn.Module): ...@@ -1699,14 +1699,14 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1, num_frames: int = 1,
) -> torch.FloatTensor: ) -> torch.Tensor:
if cross_attention_kwargs is not None: if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
...@@ -1811,8 +1811,8 @@ class MidBlockTemporalDecoder(nn.Module): ...@@ -1811,8 +1811,8 @@ class MidBlockTemporalDecoder(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
image_only_indicator: torch.FloatTensor, image_only_indicator: torch.Tensor,
): ):
hidden_states = self.resnets[0]( hidden_states = self.resnets[0](
hidden_states, hidden_states,
...@@ -1862,9 +1862,9 @@ class UpBlockTemporalDecoder(nn.Module): ...@@ -1862,9 +1862,9 @@ class UpBlockTemporalDecoder(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
image_only_indicator: torch.FloatTensor, image_only_indicator: torch.Tensor,
) -> torch.FloatTensor: ) -> torch.Tensor:
for resnet in self.resnets: for resnet in self.resnets:
hidden_states = resnet( hidden_states = resnet(
hidden_states, hidden_states,
...@@ -1935,11 +1935,11 @@ class UNetMidBlockSpatioTemporal(nn.Module): ...@@ -1935,11 +1935,11 @@ class UNetMidBlockSpatioTemporal(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
hidden_states = self.resnets[0]( hidden_states = self.resnets[0](
hidden_states, hidden_states,
temb, temb,
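Across the spatio-temporal blocks, `image_only_indicator` is a per-frame flag of shape `(batch_size, num_frames)` telling the temporal layers which frames to treat as still images rather than video. A hedged sketch of how a caller constructs it (sizes are arbitrary examples):

    import torch

    batch_size, num_frames = 2, 14

    # Zeros mean every frame is treated as video; ones would mark frames
    # the temporal layers should handle as still images.
    image_only_indicator = torch.zeros(batch_size, num_frames)
    print(image_only_indicator.shape)  # torch.Size([2, 14])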
...@@ -2031,10 +2031,10 @@ class DownBlockSpatioTemporal(nn.Module): ...@@ -2031,10 +2031,10 @@ class DownBlockSpatioTemporal(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
output_states = () output_states = ()
for resnet in self.resnets: for resnet in self.resnets:
if self.training and self.gradient_checkpointing: if self.training and self.gradient_checkpointing:
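The `self.training and self.gradient_checkpointing` branch trades compute for memory: each resnet's activations are recomputed during the backward pass instead of being stored. A minimal sketch of the underlying mechanism with a stand-in module:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    resnet = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.SiLU())
    hidden_states = torch.randn(1, 4, 8, 8, requires_grad=True)

    # Intermediate activations inside `resnet` are not kept; they are
    # recomputed when backward() reaches this segment.
    out = checkpoint(resnet, hidden_states, use_reentrant=False)
    out.sum().backward()
    print(hidden_states.grad.shape)  # torch.Size([1, 4, 8, 8])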
...@@ -2141,11 +2141,11 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module): ...@@ -2141,11 +2141,11 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
output_states = () output_states = ()
blocks = list(zip(self.resnets, self.attentions)) blocks = list(zip(self.resnets, self.attentions))
...@@ -2240,11 +2240,11 @@ class UpBlockSpatioTemporal(nn.Module): ...@@ -2240,11 +2240,11 @@ class UpBlockSpatioTemporal(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
for resnet in self.resnets: for resnet in self.resnets:
# pop res hidden states # pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
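Each up block consumes the residual tuple in reverse: every iteration pops the last entry the matching down block appended via `output_states += (hidden_states,)`. A small stand-in showing the stack discipline (tensor sizes are arbitrary):

    import torch

    # What a down path leaves behind: one residual per resnet, in order.
    res_hidden_states_tuple = (torch.randn(1, 32, 16, 16), torch.randn(1, 32, 8, 8))

    while res_hidden_states_tuple:
        res_hidden_states = res_hidden_states_tuple[-1]   # pop res hidden states
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        print(res_hidden_states.shape)  # last-appended residual comes out first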
...@@ -2349,12 +2349,12 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module): ...@@ -2349,12 +2349,12 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states # pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
......