Unverified Commit be4afa0b authored by Mark Van Aken, committed by GitHub

#7535 Update FloatTensor type hints to Tensor (#7883)

* find & replace all FloatTensors to Tensor

* apply formatting

* Update torch.FloatTensor to torch.Tensor in the remaining files

* formatting

* Fix the rest of the places where FloatTensor is used as well as in documentation

* formatting

* Update new file from FloatTensor to Tensor
parent 04f4bd54
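For context (not part of the diff): `torch.FloatTensor` names one concrete tensor class, float32 on CPU, while every tensor is an instance of `torch.Tensor`, so the latter is the accurate annotation for arguments that may arrive in fp16/bf16 or on GPU. A minimal sketch:

```python
import torch

x32 = torch.zeros(2)                       # float32, CPU
x16 = torch.zeros(2, dtype=torch.float16)

print(isinstance(x32, torch.FloatTensor))  # True: float32 on CPU only
print(isinstance(x16, torch.FloatTensor))  # False: dtype does not match
print(isinstance(x16, torch.Tensor))       # True: any dtype, any device
```

Note that where the diff touches `isinstance` checks (as in the UnCLIP interpolation pipeline below), the new check also accepts non-float32 tensors, which is the sensible behavior for half-precision pipelines.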
@@ -190,7 +190,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
ref_image: Union[torch.Tensor, PIL.Image.Image] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
@@ -201,14 +201,14 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
@@ -335,10 +335,10 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
def hacked_basic_transformer_inner_forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
@@ -453,12 +453,12 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
def hack_CrossAttnDownBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
):
eps = 1e-6
@@ -549,14 +549,14 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
def hacked_CrossAttnUpBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
):
eps = 1e-6
# TODO(Patrick, William) - attention mask is not used
@@ -191,7 +191,7 @@ class StableUnCLIPPipeline(DiffusionPipeline):
num_images_per_prompt: int = 1,
prior_num_inference_steps: int = 25,
generator: Optional[torch.Generator] = None,
prior_latents: Optional[torch.FloatTensor] = None,
prior_latents: Optional[torch.Tensor] = None,
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
text_attention_mask: Optional[torch.Tensor] = None,
prior_guidance_scale: float = 4.0,
@@ -125,7 +125,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
image: Union[torch.Tensor, PIL.Image.Image],
text: str,
height: int = 512,
width: int = 512,
@@ -135,10 +135,10 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
**kwargs,
):
@@ -177,7 +177,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
@@ -189,7 +189,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
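For reference, a minimal callback matching the documented `callback(step: int, timestep: int, latents: torch.Tensor)` signature; the pipeline call in the comment is hypothetical:

```python
import torch

def log_latents(step: int, timestep: int, latents: torch.Tensor) -> None:
    # Invoked every `callback_steps` denoising steps with the current latents.
    print(f"step={step} t={timestep} latent std={latents.std().item():.4f}")

# pipe(prompt, image, text, callback=log_latents, callback_steps=5)  # sketch
```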
@@ -193,8 +193,8 @@ class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
latents: Optional[torch.Tensor] = None,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
tile_size: int = 128,
tile_border: int = 32,
@@ -206,7 +206,7 @@ class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.Tensor`):
`Image` or tensor representing an image batch to be upscaled.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -228,7 +228,7 @@ class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
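A sketch of pre-generating the `latents` argument for reproducibility; the shape is illustrative only and must match what the pipeline would otherwise sample (batch, UNet in-channels, height/8, width/8 for SD-style models):

```python
import torch

generator = torch.Generator(device="cpu").manual_seed(0)
latents = torch.randn((1, 4, 64, 64), generator=generator)  # example shape
# pipe(prompt, image=image, latents=latents, generator=generator)  # hypothetical
```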
@@ -207,14 +207,14 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
image: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
image: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None,
steps: int = 5,
decoder_num_inference_steps: int = 25,
super_res_num_inference_steps: int = 7,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
image_embeddings: Optional[torch.Tensor] = None,
decoder_latents: Optional[torch.FloatTensor] = None,
super_res_latents: Optional[torch.FloatTensor] = None,
decoder_latents: Optional[torch.Tensor] = None,
super_res_latents: Optional[torch.Tensor] = None,
decoder_guidance_scale: float = 8.0,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -223,7 +223,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
image (`List[PIL.Image.Image]` or `torch.FloatTensor`):
image (`List[PIL.Image.Image]` or `torch.Tensor`):
The images to use for the image interpolation. Accepts either a list of exactly two PIL Images or a
tensor that complies with the configuration of
[this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
@@ -242,9 +242,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
image_embeddings (`torch.Tensor`, *optional*):
Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
can be passed for tasks like image interpolations. `image` can then be left as `None`.
decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*):
Pre-generated noisy latents to be used as inputs for the decoder.
super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*):
Pre-generated noisy latents to be used as inputs for the super resolution unet.
decoder_guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
@@ -272,19 +272,19 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
raise AssertionError(
f"Expected 'image' List to contain PIL.Image.Image, but passed 'image' contents are {type(image[0])} and {type(image[1])}"
)
elif isinstance(image, torch.FloatTensor):
elif isinstance(image, torch.Tensor):
if image.shape[0] != 2:
raise AssertionError(
f"Expected 'image' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}"
f"Expected 'image' to be torch.Tensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}"
)
elif isinstance(image_embeddings, torch.Tensor):
if image_embeddings.shape[0] != 2:
raise AssertionError(
f"Expected 'image_embeddings' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
f"Expected 'image_embeddings' to be torch.Tensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
)
else:
raise AssertionError(
f"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or Torch.FloatTensor respectively. Received {type(image)} and {type(image_embeddings)} repsectively"
f"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or torch.Tensor respectively. Received {type(image)} and {type(image_embeddings)} respectively"
)
original_image_embeddings = self._encode_image(
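The checks above accept either exactly two PIL images or a tensor stacking two images in dim 0. A hedged sketch of both forms (the 224x224 resolution is an assumption based on CLIP-style feature extractors):

```python
import torch
from PIL import Image

# images = [Image.open("a.png"), Image.open("b.png")]  # two PIL images, or:
images = torch.randn(2, 3, 224, 224)                   # a (2, C, H, W) tensor
# out = pipe(image=images, steps=5)                    # hypothetical call
```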
@@ -166,10 +166,10 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
wildcard_option_dict: Dict[str, List[str]] = {},
wildcard_files: List[str] = [],
@@ -206,7 +206,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
@@ -218,7 +218,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
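A usage sketch for the wildcard arguments: each `__token__` in the prompt is replaced by a random choice from `wildcard_option_dict[token]` or from the lines of a matching file in `wildcard_files`. The `num_prompt_samples` argument is assumed from the community pipeline's documentation:

```python
# out = pipe(
#     prompt="__animal__ sitting on a __object__",
#     wildcard_option_dict={"animal": ["cat", "fox", "owl"]},
#     wildcard_files=["object.txt"],
#     num_prompt_samples=1,
# )
```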
@@ -336,7 +336,7 @@ def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
data type of the generated embeddings
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
`torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
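The hunk cuts this helper off after its first two lines; a self-contained sketch of the full sinusoidal embedding, consistent with the shown prologue and the documented return shape:

```python
import torch

def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512, dtype=torch.float32) -> torch.Tensor:
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]                 # (len(w), half_dim)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)  # (len(w), 2 * half_dim)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))            # zero-pad odd dims
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
```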
@@ -314,7 +314,7 @@ def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
data type of the generated embeddings
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
`torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
@@ -406,7 +406,7 @@ def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
data type of the generated embeddings
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
`torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
@@ -126,7 +126,7 @@ def get_karras_sigmas(
return sigmas
def get_discretized_lognormal_weights(noise_levels: torch.FloatTensor, p_mean: float = -1.1, p_std: float = 2.0):
def get_discretized_lognormal_weights(noise_levels: torch.Tensor, p_mean: float = -1.1, p_std: float = 2.0):
"""
Calculates the unnormalized weights for a 1D array of noise levels sigma_i based on the discretized lognormal
distribution used in the iCT paper (given in Equation 10).
@@ -137,14 +137,14 @@ def get_discretized_lognormal_weights(noise_levels: torch.FloatTensor, p_mean: f
return weights
def get_loss_weighting_schedule(noise_levels: torch.FloatTensor):
def get_loss_weighting_schedule(noise_levels: torch.Tensor):
"""
Calculates the loss weighting schedule lambda given a set of noise levels.
"""
return 1.0 / (noise_levels[1:] - noise_levels[:-1])
def add_noise(original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor):
def add_noise(original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
# Make sure timesteps (Karras sigmas) have the same device and dtype as original_samples
sigmas = timesteps.to(device=original_samples.device, dtype=original_samples.dtype)
while len(sigmas.shape) < len(original_samples.shape):
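A hedged sketch of the truncated bodies: the iCT Equation (10) weights are differences of the lognormal CDF evaluated via the error function, and `add_noise` presumably finishes with an EDM-style `x + sigma * n` (the final line is an assumption):

```python
import math
import torch

def get_discretized_lognormal_weights(noise_levels: torch.Tensor, p_mean: float = -1.1, p_std: float = 2.0):
    # P(sigma_i <= sigma < sigma_{i+1}) under a lognormal over sigma (iCT Eq. 10).
    upper = torch.erf((torch.log(noise_levels[1:]) - p_mean) / (math.sqrt(2.0) * p_std))
    lower = torch.erf((torch.log(noise_levels[:-1]) - p_mean) / (math.sqrt(2.0) * p_std))
    return upper - lower

def add_noise(original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
    sigmas = timesteps.to(device=original_samples.device, dtype=original_samples.dtype)
    while len(sigmas.shape) < len(original_samples.shape):
        sigmas = sigmas.unsqueeze(-1)         # broadcast sigma over sample dims
    return original_samples + noise * sigmas  # assumed completion
```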
@@ -737,11 +737,11 @@
"class MoleculeGNNOutput(BaseOutput):\n",
" \"\"\"\n",
" Args:\n",
" sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n",
" sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n",
" Hidden states output. Output of last layer of model.\n",
" \"\"\"\n",
"\n",
" sample: torch.FloatTensor\n",
" sample: torch.Tensor\n",
"\n",
"\n",
"class MultiLayerPerceptron(nn.Module):\n",
@@ -1354,7 +1354,7 @@
 " r\"\"\"\n",
 " Args:\n",
 " sample: packed torch geometric object\n",
 " timestep (`torch.FloatTensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n",
 " timestep (`torch.Tensor` or `float` or `int`): TODO verify type and shape (batch) timesteps\n",
" return_dict (`bool`, *optional*, defaults to `True`):\n",
" Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n",
" Returns:\n",
@@ -1404,7 +1404,7 @@
" if not return_dict:\n",
" return (-eps_pos,)\n",
"\n",
" return MoleculeGNNOutput(sample=torch.FloatTensor(-eps_pos).to(pos.device))"
" return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
],
"metadata": {
"id": "MCeZA1qQXzoK"
@@ -279,8 +279,8 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
**kwargs,
):
@@ -312,8 +312,8 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
@@ -333,10 +333,10 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
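A hypothetical sketch of pre-computing the embeddings documented above so the same text features can be reused across calls (`pipe` is a loaded PromptDiffusionPipeline; the call follows the `encode_prompt` signature shown earlier):

```python
# prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
#     "a sleek sports car",
#     device=pipe.device,
#     num_images_per_prompt=1,
#     do_classifier_free_guidance=True,
#     negative_prompt="blurry",
# )
# pipe(prompt=None, prompt_embeds=prompt_embeds,
#      negative_prompt_embeds=negative_prompt_embeds, ...)
```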
@@ -852,7 +852,7 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
data type of the generated embeddings
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
`torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
@@ -906,9 +906,9 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -928,10 +928,10 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
`init`, images must be passed as a list such that each element of the list can be correctly batched for
@@ -963,14 +963,14 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -981,7 +981,7 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
@@ -181,11 +181,11 @@ class PromptDiffusionControlNetModel(ControlNetModel):
def forward(
self,
sample: torch.FloatTensor,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
controlnet_query_cond: torch.FloatTensor,
controlnet_cond: torch.Tensor,
controlnet_query_cond: torch.Tensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
@@ -194,20 +194,20 @@ class PromptDiffusionControlNetModel(ControlNetModel):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
"""
The [`~PromptDiffusionControlNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
sample (`torch.Tensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
controlnet_query_cond (`torch.FloatTensor`):
controlnet_query_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
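An illustrative forward call matching the signature above; all tensor names are placeholders, and with `return_dict=False` the model returns the `(down_block_res_samples, mid_block_res_sample)` tuple from the annotation:

```python
# down_block_res_samples, mid_block_res_sample = controlnet(
#     sample, timestep, encoder_hidden_states,
#     controlnet_cond=controlnet_cond,
#     controlnet_query_cond=controlnet_query_cond,
#     conditioning_scale=1.0,
#     return_dict=False,
# )
```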
@@ -163,11 +163,11 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: Optional[int] = 1,
knn: Optional[int] = 10,
**kwargs,
@@ -199,11 +199,11 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
@@ -213,7 +213,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
@@ -20,7 +20,7 @@ def normalize_images(images: List[Image.Image]):
return images
def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtractor) -> torch.FloatTensor:
def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtractor) -> torch.Tensor:
"""
Preprocesses a list of images into a batch of tensors.
@@ -29,7 +29,7 @@ def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtr
A list of images to preprocess.
Returns:
:obj:`torch.FloatTensor`: A batch of tensors.
:obj:`torch.Tensor`: A batch of tensors.
"""
images = [np.array(image) for image in images]
images = [(image + 1.0) / 2.0 for image in images]
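The helper is truncated after the rescaling step; a plausible completion, assuming the final step hands the arrays to the CLIP feature extractor (that last line is an assumption):

```python
import numpy as np
import torch
from transformers import CLIPFeatureExtractor

def preprocess_images(images, feature_extractor: CLIPFeatureExtractor) -> torch.Tensor:
    images = [np.array(image) for image in images]
    images = [(image + 1.0) / 2.0 for image in images]  # [-1, 1] -> [0, 1]
    return feature_extractor(images=images, return_tensors="pt").pixel_values
```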
@@ -17,109 +17,99 @@ MODEL = "base_with_context"
def load_notes_encoder(weights, model):
model.token_embedder.weight = nn.Parameter(torch.FloatTensor(weights["token_embedder"]["embedding"]))
model.position_encoding.weight = nn.Parameter(
torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
)
model.token_embedder.weight = nn.Parameter(torch.Tensor(weights["token_embedder"]["embedding"]))
model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
for lyr_num, lyr in enumerate(model.encoders):
ly_weight = weights[f"layers_{lyr_num}"]
lyr.layer[0].layer_norm.weight = nn.Parameter(
torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
)
lyr.layer[0].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_attention_layer_norm"]["scale"]))
attention_weights = ly_weight["attention"]
lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
lyr.layer[1].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))
lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))
model.layer_norm.weight = nn.Parameter(torch.Tensor(weights["encoder_norm"]["scale"]))
return model
def load_continuous_encoder(weights, model):
model.input_proj.weight = nn.Parameter(torch.FloatTensor(weights["input_proj"]["kernel"].T))
model.input_proj.weight = nn.Parameter(torch.Tensor(weights["input_proj"]["kernel"].T))
model.position_encoding.weight = nn.Parameter(
torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
)
model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
for lyr_num, lyr in enumerate(model.encoders):
ly_weight = weights[f"layers_{lyr_num}"]
attention_weights = ly_weight["attention"]
lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
lyr.layer[0].layer_norm.weight = nn.Parameter(
torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
)
lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
lyr.layer[0].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_attention_layer_norm"]["scale"]))
lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))
lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
lyr.layer[1].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))
model.layer_norm.weight = nn.Parameter(torch.Tensor(weights["encoder_norm"]["scale"]))
return model
def load_decoder(weights, model):
model.conditioning_emb[0].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense0"]["kernel"].T))
model.conditioning_emb[2].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense1"]["kernel"].T))
model.conditioning_emb[0].weight = nn.Parameter(torch.Tensor(weights["time_emb_dense0"]["kernel"].T))
model.conditioning_emb[2].weight = nn.Parameter(torch.Tensor(weights["time_emb_dense1"]["kernel"].T))
model.position_encoding.weight = nn.Parameter(
torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
)
model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
model.continuous_inputs_projection.weight = nn.Parameter(
torch.FloatTensor(weights["continuous_inputs_projection"]["kernel"].T)
torch.Tensor(weights["continuous_inputs_projection"]["kernel"].T)
)
for lyr_num, lyr in enumerate(model.decoders):
ly_weight = weights[f"layers_{lyr_num}"]
lyr.layer[0].layer_norm.weight = nn.Parameter(
torch.FloatTensor(ly_weight["pre_self_attention_layer_norm"]["scale"])
torch.Tensor(ly_weight["pre_self_attention_layer_norm"]["scale"])
)
lyr.layer[0].FiLMLayer.scale_bias.weight = nn.Parameter(
torch.FloatTensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T)
torch.Tensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T)
)
attention_weights = ly_weight["self_attention"]
lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
attention_weights = ly_weight["MultiHeadDotProductAttention_0"]
lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
lyr.layer[1].layer_norm.weight = nn.Parameter(
torch.FloatTensor(ly_weight["pre_cross_attention_layer_norm"]["scale"])
torch.Tensor(ly_weight["pre_cross_attention_layer_norm"]["scale"])
)
lyr.layer[2].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
lyr.layer[2].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
lyr.layer[2].film.scale_bias.weight = nn.Parameter(
torch.FloatTensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T)
torch.Tensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T)
)
lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))
lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
model.decoder_norm.weight = nn.Parameter(torch.FloatTensor(weights["decoder_norm"]["scale"]))
model.decoder_norm.weight = nn.Parameter(torch.Tensor(weights["decoder_norm"]["scale"]))
model.spec_out.weight = nn.Parameter(torch.FloatTensor(weights["spec_out_dense"]["kernel"].T))
model.spec_out.weight = nn.Parameter(torch.Tensor(weights["spec_out_dense"]["kernel"].T))
return model
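Worth noting: in these loaders `torch.FloatTensor(...)` was a constructor call rather than a type hint, so this hunk changes executed code. The two constructors behave identically here as long as the global default dtype is float32 (PyTorch's default); a quick check:

```python
import numpy as np
import torch

w = np.ones((2, 3), dtype=np.float32)
assert torch.equal(torch.Tensor(w), torch.FloatTensor(w))  # both float32 copies
```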
@@ -282,15 +282,15 @@ class BasicTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.FloatTensor:
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
@@ -477,10 +477,10 @@ class TemporalBasicTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
hidden_states: torch.Tensor,
num_frames: int,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
@@ -503,9 +503,9 @@ class Attention(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
r"""
@@ -751,10 +751,10 @@ class AttnProcessor:
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
@@ -863,9 +863,9 @@ class CustomDiffusionAttnProcessor(nn.Module):
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -928,9 +928,9 @@ class AttnAddedKVProcessor:
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
@@ -1001,9 +1001,9 @@ class AttnAddedKVProcessor2_0:
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
@@ -1080,9 +1080,9 @@ class XFormersAttnAddedKVProcessor:
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
@@ -1151,13 +1151,13 @@ class XFormersAttnProcessor:
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
@@ -1243,13 +1243,13 @@ class AttnProcessorNPU:
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
@@ -1349,13 +1349,13 @@ class AttnProcessor2_0:
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
@@ -1448,13 +1448,13 @@ class FusedAttnProcessor2_0:
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
@@ -1581,10 +1581,10 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
@@ -1692,10 +1692,10 @@ class CustomDiffusionAttnProcessor2_0(nn.Module):
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out:
@@ -1773,10 +1773,10 @@ class SlicedAttnProcessor:
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
input_ndim = hidden_states.ndim
@@ -1860,11 +1860,11 @@ class SlicedAttnAddedKVProcessor:
def __call__(
self,
attn: "Attention",
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
if attn.spatial_norm is not None:
@@ -1957,7 +1957,7 @@ class SpatialNorm(nn.Module):
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
f_size = f.shape[-2:]
zq = F.interpolate(zq, size=f_size, mode="nearest")
norm_f = self.norm_layer(f)
@@ -2003,7 +2003,7 @@ class LoRAAttnProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -2064,7 +2064,7 @@ class LoRAAttnProcessor2_0(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -2143,7 +2143,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -2202,7 +2202,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -2264,12 +2264,12 @@ class IPAdapterAttnProcessor(nn.Module):
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
):
residual = hidden_states
@@ -2467,12 +2467,12 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
):
residual = hidden_states
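These processors are normally installed on a model rather than called directly; a short sketch using the public API (the checkpoint name is only an example):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
# AttnProcessor2_0 routes attention through torch's scaled_dot_product_attention.
unet.set_attn_processor(AttnProcessor2_0())
```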
@@ -112,9 +112,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self.register_to_config(force_upcast=False)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
@@ -126,11 +124,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
def _decode(
self,
z: torch.FloatTensor,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
z: torch.Tensor,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
z = self.post_quant_conv(z)
dec = self.decoder(z, image, mask)
@@ -142,12 +140,12 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
z: torch.Tensor,
generator: Optional[torch.Generator] = None,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
decoded = self._decode(z, image, mask).sample
if not return_dict:
@@ -157,16 +155,16 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
def forward(
self,
sample: torch.FloatTensor,
mask: Optional[torch.FloatTensor] = None,
sample: torch.Tensor,
mask: Optional[torch.Tensor] = None,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
sample (`torch.Tensor`): Input sample.
mask (`torch.Tensor`, *optional*, defaults to `None`): Optional inpainting mask.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
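An illustrative round trip through this VAE; per the signatures above, `decode` also takes the original image and the inpainting mask (shapes and mask convention are assumptions):

```python
import torch

x = torch.randn(1, 3, 512, 512)    # image batch in [-1, 1]
mask = torch.ones(1, 1, 512, 512)  # assumed: 1 = keep, 0 = inpaint

# With a loaded AsymmetricAutoencoderKL `vae` (hypothetical):
# z = vae.encode(x).latent_dist.sample()
# rec = vae.decode(z, image=x, mask=mask).sample
```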
@@ -237,13 +237,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
@@ -268,7 +268,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
@@ -281,14 +281,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
def decode(self, z: torch.Tensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
@@ -321,7 +319,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -331,7 +329,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images.
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
@@ -375,12 +373,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
@@ -425,14 +423,14 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def forward(
self,
sample: torch.FloatTensor,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
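A sketch of switching on the tiled paths documented above; `enable_tiling` is the public toggle, and the checkpoint name is only an example:

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_tiling()  # encode/decode now use tiled_encode/tiled_decode on large inputs
```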