Unverified Commit be4afa0b authored by Mark Van Aken, committed by GitHub

#7535 Update FloatTensor type hints to Tensor (#7883)

* find & replace all FloatTensors with Tensor

* apply formatting

* Update torch.FloatTensor to torch.Tensor in the remaining files

* formatting

* Fix the rest of the places where FloatTensor is used as well as in documentation

* formatting

* Update new file from FloatTensor to Tensor
parent 04f4bd54
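
The diff below applies one mechanical substitution across community pipelines, conversion scripts, and model code: `torch.FloatTensor` in annotations only names the legacy float32 CPU tensor class, while these functions routinely receive fp16/bf16 and CUDA tensors, so `torch.Tensor` is the accurate hint. As a rough illustration of the "find & replace" pass described in the commit message (a hypothetical helper sketch, not part of this PR; the function name and path are placeholders, and formatting is applied as a separate step):

    # Hypothetical helper, not part of this PR: rewrite torch.FloatTensor
    # mentions to torch.Tensor across a source tree.
    import pathlib
    import re

    def rewrite_float_tensor_hints(root: str) -> int:
        """Replace `torch.FloatTensor` with `torch.Tensor` in every .py file under `root`."""
        changed = 0
        for path in pathlib.Path(root).rglob("*.py"):
            text = path.read_text(encoding="utf-8")
            new_text = re.sub(r"\btorch\.FloatTensor\b", "torch.Tensor", text)
            if new_text != text:
                path.write_text(new_text, encoding="utf-8")
                changed += 1
        return changed

    if __name__ == "__main__":
        # Example: run from a repository checkout; "." is a placeholder root.
        print(rewrite_float_tensor_hints("."), "files updated")
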
@@ -190,7 +190,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
- ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ ref_image: Union[torch.Tensor, PIL.Image.Image] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
@@ -201,14 +201,14 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
@@ -335,10 +335,10 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
def hacked_basic_transformer_inner_forward(
self,
- hidden_states: torch.FloatTensor,
- attention_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
@@ -453,12 +453,12 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
def hack_CrossAttnDownBlock2D_forward(
self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
):
eps = 1e-6
@@ -549,14 +549,14 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
def hacked_CrossAttnUpBlock2D_forward(
self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ hidden_states: torch.Tensor,
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
):
eps = 1e-6
# TODO(Patrick, William) - attention mask is not used
...
@@ -191,7 +191,7 @@ class StableUnCLIPPipeline(DiffusionPipeline):
num_images_per_prompt: int = 1,
prior_num_inference_steps: int = 25,
generator: Optional[torch.Generator] = None,
- prior_latents: Optional[torch.FloatTensor] = None,
+ prior_latents: Optional[torch.Tensor] = None,
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
text_attention_mask: Optional[torch.Tensor] = None,
prior_guidance_scale: float = 4.0,
...
@@ -125,7 +125,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
def __call__(
self,
prompt: Union[str, List[str]],
- image: Union[torch.FloatTensor, PIL.Image.Image],
+ image: Union[torch.Tensor, PIL.Image.Image],
text: str,
height: int = 512,
width: int = 512,
@@ -135,10 +135,10 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
- latents: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
**kwargs,
):
@@ -177,7 +177,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
- latents (`torch.FloatTensor`, *optional*):
+ latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
@@ -189,7 +189,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
...
@@ -193,8 +193,8 @@ class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
- latents: Optional[torch.FloatTensor] = None,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ latents: Optional[torch.Tensor] = None,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
tile_size: int = 128,
tile_border: int = 32,
@@ -206,7 +206,7 @@ class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
- image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
+ image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.Tensor`):
`Image`, or tensor representing an image batch which will be upscaled. *
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -228,7 +228,7 @@ class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
- latents (`torch.FloatTensor`, *optional*):
+ latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
...
@@ -207,14 +207,14 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
- image: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
+ image: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None,
steps: int = 5,
decoder_num_inference_steps: int = 25,
super_res_num_inference_steps: int = 7,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
image_embeddings: Optional[torch.Tensor] = None,
- decoder_latents: Optional[torch.FloatTensor] = None,
- super_res_latents: Optional[torch.FloatTensor] = None,
+ decoder_latents: Optional[torch.Tensor] = None,
+ super_res_latents: Optional[torch.Tensor] = None,
decoder_guidance_scale: float = 8.0,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -223,7 +223,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
- image (`List[PIL.Image.Image]` or `torch.FloatTensor`):
+ image (`List[PIL.Image.Image]` or `torch.Tensor`):
The images to use for the image interpolation. Only accepts a list of two PIL Images or If you provide a tensor, it needs to comply with the
configuration of
[this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
@@ -242,9 +242,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
image_embeddings (`torch.Tensor`, *optional*):
Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
can be passed for tasks like image interpolations. `image` can the be left to `None`.
- decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+ decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*):
Pre-generated noisy latents to be used as inputs for the decoder.
- super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+ super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*):
Pre-generated noisy latents to be used as inputs for the decoder.
decoder_guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
@@ -272,19 +272,19 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
raise AssertionError(
f"Expected 'image' List to contain PIL.Image.Image, but passed 'image' contents are {type(image[0])} and {type(image[1])}"
)
- elif isinstance(image, torch.FloatTensor):
+ elif isinstance(image, torch.Tensor):
if image.shape[0] != 2:
raise AssertionError(
- f"Expected 'image' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}"
+ f"Expected 'image' to be torch.Tensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}"
)
elif isinstance(image_embeddings, torch.Tensor):
if image_embeddings.shape[0] != 2:
raise AssertionError(
- f"Expected 'image_embeddings' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
+ f"Expected 'image_embeddings' to be torch.Tensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
)
else:
raise AssertionError(
- f"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or Torch.FloatTensor respectively. Received {type(image)} and {type(image_embeddings)} repsectively"
+ f"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or torch.Tensor respectively. Received {type(image)} and {type(image_embeddings)} repsectively"
)
original_image_embeddings = self._encode_image(
...
@@ -166,10 +166,10 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
- latents: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
wildcard_option_dict: Dict[str, List[str]] = {},
wildcard_files: List[str] = [],
@@ -206,7 +206,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
- latents (`torch.FloatTensor`, *optional*):
+ latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
@@ -218,7 +218,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
...
@@ -336,7 +336,7 @@ def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
data type of the generated embeddings
Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
...
@@ -314,7 +314,7 @@ def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
data type of the generated embeddings
Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
...
@@ -406,7 +406,7 @@ def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
data type of the generated embeddings
Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
...
@@ -126,7 +126,7 @@ def get_karras_sigmas(
return sigmas
- def get_discretized_lognormal_weights(noise_levels: torch.FloatTensor, p_mean: float = -1.1, p_std: float = 2.0):
+ def get_discretized_lognormal_weights(noise_levels: torch.Tensor, p_mean: float = -1.1, p_std: float = 2.0):
"""
Calculates the unnormalized weights for a 1D array of noise level sigma_i based on the discretized lognormal"
" distribution used in the iCT paper (given in Equation 10).
@@ -137,14 +137,14 @@ def get_discretized_lognormal_weights(noise_levels: torch.FloatTensor, p_mean: f
return weights
- def get_loss_weighting_schedule(noise_levels: torch.FloatTensor):
+ def get_loss_weighting_schedule(noise_levels: torch.Tensor):
"""
Calculates the loss weighting schedule lambda given a set of noise levels.
"""
return 1.0 / (noise_levels[1:] - noise_levels[:-1])
- def add_noise(original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor):
+ def add_noise(original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
# Make sure timesteps (Karras sigmas) have the same device and dtype as original_samples
sigmas = timesteps.to(device=original_samples.device, dtype=original_samples.dtype)
while len(sigmas.shape) < len(original_samples.shape):
...
@@ -737,11 +737,11 @@
"class MoleculeGNNOutput(BaseOutput):\n",
" \"\"\"\n",
" Args:\n",
- " sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n",
+ " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n",
" Hidden states output. Output of last layer of model.\n",
" \"\"\"\n",
"\n",
- " sample: torch.FloatTensor\n",
+ " sample: torch.Tensor\n",
"\n",
"\n",
"class MultiLayerPerceptron(nn.Module):\n",
@@ -1354,7 +1354,7 @@
" r\"\"\"\n",
" Args:\n",
" sample: packed torch geometric object\n",
- " timestep (`torch.FloatTensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n",
+ " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n",
" return_dict (`bool`, *optional*, defaults to `True`):\n",
" Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n",
" Returns:\n",
@@ -1404,7 +1404,7 @@
" if not return_dict:\n",
" return (-eps_pos,)\n",
"\n",
- " return MoleculeGNNOutput(sample=torch.FloatTensor(-eps_pos).to(pos.device))"
+ " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
],
"metadata": {
"id": "MCeZA1qQXzoK"
...
@@ -279,8 +279,8 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
**kwargs,
):
@@ -312,8 +312,8 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
@@ -333,10 +333,10 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
@@ -852,7 +852,7 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
data type of the generated embeddings
Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
@@ -906,9 +906,9 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -928,10 +928,10 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
- image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
- `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
- specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
`init`, images must be passed as a list such that each element of the list can be correctly batched for
@@ -963,14 +963,14 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
+ latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -981,7 +981,7 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
...
@@ -181,11 +181,11 @@ class PromptDiffusionControlNetModel(ControlNetModel):
def forward(
self,
- sample: torch.FloatTensor,
+ sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
- controlnet_cond: torch.FloatTensor,
- controlnet_query_cond: torch.FloatTensor,
+ controlnet_cond: torch.Tensor,
+ controlnet_query_cond: torch.Tensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
@@ -194,20 +194,20 @@ class PromptDiffusionControlNetModel(ControlNetModel):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
- ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
"""
The [`~PromptDiffusionControlNetModel`] forward method.
Args:
- sample (`torch.FloatTensor`):
+ sample (`torch.Tensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
- controlnet_cond (`torch.FloatTensor`):
+ controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- controlnet_query_cond (`torch.FloatTensor`):
+ controlnet_query_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
...
@@ -163,11 +163,11 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: Optional[int] = 1,
knn: Optional[int] = 10,
**kwargs,
@@ -199,11 +199,11 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
- latents (`torch.FloatTensor`, *optional*):
+ latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
@@ -213,7 +213,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
...
@@ -20,7 +20,7 @@ def normalize_images(images: List[Image.Image]):
return images
- def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtractor) -> torch.FloatTensor:
+ def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtractor) -> torch.Tensor:
"""
Preprocesses a list of images into a batch of tensors.
@@ -29,7 +29,7 @@ def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtr
A list of images to preprocess.
Returns:
- :obj:`torch.FloatTensor`: A batch of tensors.
+ :obj:`torch.Tensor`: A batch of tensors.
"""
images = [np.array(image) for image in images]
images = [(image + 1.0) / 2.0 for image in images]
...
@@ -17,109 +17,99 @@ MODEL = "base_with_context"
def load_notes_encoder(weights, model):
- model.token_embedder.weight = nn.Parameter(torch.FloatTensor(weights["token_embedder"]["embedding"]))
- model.position_encoding.weight = nn.Parameter(
- torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
- )
+ model.token_embedder.weight = nn.Parameter(torch.Tensor(weights["token_embedder"]["embedding"]))
+ model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
for lyr_num, lyr in enumerate(model.encoders):
ly_weight = weights[f"layers_{lyr_num}"]
- lyr.layer[0].layer_norm.weight = nn.Parameter(
- torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
- )
+ lyr.layer[0].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_attention_layer_norm"]["scale"]))
attention_weights = ly_weight["attention"]
- lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
- lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
- lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
- lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
- lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
- lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
- lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
- lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))
- model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))
+ lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
+ lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
+ lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
+ lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
+ lyr.layer[1].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
+ lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
+ lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
+ lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
+ model.layer_norm.weight = nn.Parameter(torch.Tensor(weights["encoder_norm"]["scale"]))
return model
def load_continuous_encoder(weights, model):
- model.input_proj.weight = nn.Parameter(torch.FloatTensor(weights["input_proj"]["kernel"].T))
- model.position_encoding.weight = nn.Parameter(
- torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
- )
+ model.input_proj.weight = nn.Parameter(torch.Tensor(weights["input_proj"]["kernel"].T))
+ model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
for lyr_num, lyr in enumerate(model.encoders):
ly_weight = weights[f"layers_{lyr_num}"]
attention_weights = ly_weight["attention"]
- lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
- lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
- lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
- lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
- lyr.layer[0].layer_norm.weight = nn.Parameter(
- torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
- )
- lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
- lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
- lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))
- lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
- model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))
+ lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
+ lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
+ lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
+ lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
+ lyr.layer[0].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_attention_layer_norm"]["scale"]))
+ lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
+ lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
+ lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
+ lyr.layer[1].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
+ model.layer_norm.weight = nn.Parameter(torch.Tensor(weights["encoder_norm"]["scale"]))
return model
def load_decoder(weights, model):
- model.conditioning_emb[0].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense0"]["kernel"].T))
- model.conditioning_emb[2].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense1"]["kernel"].T))
- model.position_encoding.weight = nn.Parameter(
- torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
- )
+ model.conditioning_emb[0].weight = nn.Parameter(torch.Tensor(weights["time_emb_dense0"]["kernel"].T))
+ model.conditioning_emb[2].weight = nn.Parameter(torch.Tensor(weights["time_emb_dense1"]["kernel"].T))
+ model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
model.continuous_inputs_projection.weight = nn.Parameter(
- torch.FloatTensor(weights["continuous_inputs_projection"]["kernel"].T)
+ torch.Tensor(weights["continuous_inputs_projection"]["kernel"].T)
)
for lyr_num, lyr in enumerate(model.decoders):
ly_weight = weights[f"layers_{lyr_num}"]
lyr.layer[0].layer_norm.weight = nn.Parameter(
- torch.FloatTensor(ly_weight["pre_self_attention_layer_norm"]["scale"])
+ torch.Tensor(ly_weight["pre_self_attention_layer_norm"]["scale"])
)
lyr.layer[0].FiLMLayer.scale_bias.weight = nn.Parameter(
- torch.FloatTensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T)
+ torch.Tensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T)
)
attention_weights = ly_weight["self_attention"]
- lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
- lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
- lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
- lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
+ lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
+ lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
+ lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
+ lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
attention_weights = ly_weight["MultiHeadDotProductAttention_0"]
- lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
- lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
- lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
- lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
+ lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
+ lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
+ lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
+ lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
lyr.layer[1].layer_norm.weight = nn.Parameter(
- torch.FloatTensor(ly_weight["pre_cross_attention_layer_norm"]["scale"])
+ torch.Tensor(ly_weight["pre_cross_attention_layer_norm"]["scale"])
)
- lyr.layer[2].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
+ lyr.layer[2].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
lyr.layer[2].film.scale_bias.weight = nn.Parameter(
- torch.FloatTensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T)
+ torch.Tensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T)
)
- lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
- lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
- lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))
- model.decoder_norm.weight = nn.Parameter(torch.FloatTensor(weights["decoder_norm"]["scale"]))
- model.spec_out.weight = nn.Parameter(torch.FloatTensor(weights["spec_out_dense"]["kernel"].T))
+ lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
+ lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
+ lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
+ model.decoder_norm.weight = nn.Parameter(torch.Tensor(weights["decoder_norm"]["scale"]))
+ model.spec_out.weight = nn.Parameter(torch.Tensor(weights["spec_out_dense"]["kernel"].T))
return model
...
...@@ -282,15 +282,15 @@ class BasicTransformerBlock(nn.Module): ...@@ -282,15 +282,15 @@ class BasicTransformerBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None, timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None, cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None, class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
if cross_attention_kwargs is not None: if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None: if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
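For callers checking the widened annotations in this block: `torch.FloatTensor` objects are still `torch.Tensor` objects, so inputs built the legacy way continue to satisfy the new hints. A small illustrative check (shapes are arbitrary, not taken from the diff):

    import torch

    hidden_states = torch.FloatTensor(2, 77, 768).normal_()  # legacy construction
    assert isinstance(hidden_states, torch.Tensor)            # satisfies the new `torch.Tensor` hint

    hidden_states = torch.randn(2, 77, 768)                   # idiomatic replacement
    assert isinstance(hidden_states, torch.Tensor)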
...@@ -477,10 +477,10 @@ class TemporalBasicTransformerBlock(nn.Module): ...@@ -477,10 +477,10 @@ class TemporalBasicTransformerBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
num_frames: int, num_frames: int,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks. # Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention # 0. Self-Attention
batch_size = hidden_states.shape[0] batch_size = hidden_states.shape[0]
...@@ -112,9 +112,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -112,9 +112,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self.register_to_config(force_upcast=False) self.register_to_config(force_upcast=False)
@apply_forward_hook @apply_forward_hook
def encode( def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
h = self.encoder(x) h = self.encoder(x)
moments = self.quant_conv(h) moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments) posterior = DiagonalGaussianDistribution(moments)
...@@ -126,11 +124,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -126,11 +124,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
def _decode( def _decode(
self, self,
z: torch.FloatTensor, z: torch.Tensor,
image: Optional[torch.FloatTensor] = None, image: Optional[torch.Tensor] = None,
mask: Optional[torch.FloatTensor] = None, mask: Optional[torch.Tensor] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
z = self.post_quant_conv(z) z = self.post_quant_conv(z)
dec = self.decoder(z, image, mask) dec = self.decoder(z, image, mask)
...@@ -142,12 +140,12 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -142,12 +140,12 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
@apply_forward_hook @apply_forward_hook
def decode( def decode(
self, self,
z: torch.FloatTensor, z: torch.Tensor,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
image: Optional[torch.FloatTensor] = None, image: Optional[torch.Tensor] = None,
mask: Optional[torch.FloatTensor] = None, mask: Optional[torch.Tensor] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
decoded = self._decode(z, image, mask).sample decoded = self._decode(z, image, mask).sample
if not return_dict: if not return_dict:
...@@ -157,16 +155,16 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -157,16 +155,16 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
def forward( def forward(
self, self,
sample: torch.FloatTensor, sample: torch.Tensor,
mask: Optional[torch.FloatTensor] = None, mask: Optional[torch.Tensor] = None,
sample_posterior: bool = False, sample_posterior: bool = False,
return_dict: bool = True, return_dict: bool = True,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
r""" r"""
Args: Args:
sample (`torch.FloatTensor`): Input sample. sample (`torch.Tensor`): Input sample.
mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask. mask (`torch.Tensor`, *optional*, defaults to `None`): Optional inpainting mask.
sample_posterior (`bool`, *optional*, defaults to `False`): sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior. Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
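A minimal usage sketch for the asymmetric VAE signatures above, assuming a locally available checkpoint (the path below is a placeholder, not taken from this diff):

    import torch
    from diffusers import AsymmetricAutoencoderKL

    vae = AsymmetricAutoencoderKL.from_pretrained("path/to/asymmetric-vae")  # placeholder checkpoint
    vae.eval()

    image = torch.randn(1, 3, 512, 512)  # a plain torch.Tensor now matches the hints
    mask = torch.ones(1, 1, 512, 512)    # optional inpainting mask

    with torch.no_grad():
        latents = vae.encode(image).latent_dist.sample()
        decoded = vae.decode(latents, image=image, mask=mask).sample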
...@@ -237,13 +237,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -237,13 +237,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@apply_forward_hook @apply_forward_hook
def encode( def encode(
self, x: torch.FloatTensor, return_dict: bool = True self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
""" """
Encode a batch of images into latents. Encode a batch of images into latents.
Args: Args:
x (`torch.FloatTensor`): Input batch of images. x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
...@@ -268,7 +268,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -268,7 +268,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return AutoencoderKLOutput(latent_dist=posterior) return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict) return self.tiled_decode(z, return_dict=return_dict)
...@@ -281,14 +281,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -281,14 +281,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return DecoderOutput(sample=dec) return DecoderOutput(sample=dec)
@apply_forward_hook @apply_forward_hook
def decode( def decode(self, z: torch.Tensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]:
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
""" """
Decode a batch of images. Decode a batch of images.
Args: Args:
z (`torch.FloatTensor`): Input batch of latent vectors. z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
...@@ -321,7 +319,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -321,7 +319,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b return b
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder. r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
...@@ -331,7 +329,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -331,7 +329,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
output, but they should be much less noticeable. output, but they should be much less noticeable.
Args: Args:
x (`torch.FloatTensor`): Input batch of images. x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
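The blend line in the hunk above linearly cross-fades the overlapping columns of two adjacent tiles so seams are less visible. The same idea on toy tensors (tile shapes here are made up):

    import torch

    def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # Cross-fade the last `blend_extent` columns of `a` into the first columns of `b`.
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    left = torch.zeros(1, 4, 8, 8)
    right = torch.ones(1, 4, 8, 8)
    blended = blend_h(left, right, blend_extent=4)
    # The first four columns of `blended` ramp from 0.0 toward 1.0 across the overlap.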
...@@ -375,12 +373,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -375,12 +373,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return AutoencoderKLOutput(latent_dist=posterior) return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r""" r"""
Decode a batch of images using a tiled decoder. Decode a batch of images using a tiled decoder.
Args: Args:
z (`torch.FloatTensor`): Input batch of latent vectors. z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
...@@ -425,14 +423,14 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -425,14 +423,14 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def forward( def forward(
self, self,
sample: torch.FloatTensor, sample: torch.Tensor,
sample_posterior: bool = False, sample_posterior: bool = False,
return_dict: bool = True, return_dict: bool = True,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]: ) -> Union[DecoderOutput, torch.Tensor]:
r""" r"""
Args: Args:
sample (`torch.FloatTensor`): Input sample. sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`): sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior. Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
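And a corresponding round-trip sketch for the standard AutoencoderKL signatures above (the checkpoint id is the commonly used SD VAE; substitute whatever your pipeline expects):

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    vae.eval()

    sample = torch.randn(1, 3, 512, 512)  # torch.Tensor, matching the updated hints

    with torch.no_grad():
        posterior = vae.encode(sample).latent_dist
        latents = posterior.sample()               # or posterior.mode() for the deterministic mean
        reconstruction = vae.decode(latents).sample

    # latent is (1, 4, 64, 64); reconstruction is (1, 3, 512, 512)
    print(latents.shape, reconstruction.shape)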