Unverified Commit be4afa0b authored by Mark Van Aken, committed by GitHub

#7535 Update FloatTensor type hints to Tensor (#7883)

* find & replace all FloatTensors to Tensor

* apply formatting

* Update torch.FloatTensor to torch.Tensor in the remaining files

* formatting

* Fix the rest of the places where FloatTensor is used as well as in documentation

* formatting

* Update new file from FloatTensor to Tensor
parent 04f4bd54
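
For context on why the hint changes below matter: `torch.FloatTensor` is the legacy tensor type that describes only float32 CPU tensors, while `torch.Tensor` covers every dtype and device, so the old annotations were misleading for fp16 or CUDA runs. A minimal sketch of the distinction (the `scale_latents` helper is illustrative, not part of this commit):

    import torch

    def scale_latents(latents: torch.Tensor, factor: float) -> torch.Tensor:
        # `torch.Tensor` matches any dtype/device; the legacy `torch.FloatTensor`
        # type only describes float32 CPU tensors.
        return latents * factor

    half = torch.randn(2, 4, dtype=torch.float16)
    print(isinstance(half, torch.Tensor))       # True
    print(isinstance(half, torch.FloatTensor))  # False: the old hint excluded fp16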
@@ -392,7 +392,7 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
                 `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
             callback (`Callable`, *optional*):
                 A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function is called. If not specified, the callback is called at
                 every step.
@@ -529,12 +529,12 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
         num_videos_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         motion_field_strength_x: float = 12,
         motion_field_strength_y: float = 12,
         output_type: Optional[str] = "tensor",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: Optional[int] = 1,
         t0: int = 44,
         t1: int = 47,
@@ -569,7 +569,7 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
@@ -581,7 +581,7 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
                 a plain tuple.
             callback (`Callable`, *optional*):
                 A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function is called. If not specified, the callback is called at
                 every step.
@@ -795,8 +795,8 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -816,10 +816,10 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
...
@@ -581,10 +581,10 @@ class TextToVideoZeroSDXLPipeline(
         do_classifier_free_guidance: bool = True,
         negative_prompt: Optional[str] = None,
         negative_prompt_2: Optional[str] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -610,17 +610,17 @@ class TextToVideoZeroSDXLPipeline(
             negative_prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                 `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+            pooled_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
@@ -861,7 +861,7 @@ class TextToVideoZeroSDXLPipeline(
                 `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
             callback (`Callable`, *optional*):
                 A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function is called. If not specified, the callback is called at
                 every step.
@@ -933,16 +933,16 @@ class TextToVideoZeroSDXLPipeline(
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         frame_ids: Optional[List[int]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        latents: Optional[torch.Tensor] = None,
         motion_field_strength_x: float = 12,
         motion_field_strength_y: float = 12,
         output_type: Optional[str] = "tensor",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
@@ -1002,21 +1002,21 @@ class TextToVideoZeroSDXLPipeline(
             frame_ids (`List[int]`, *optional*):
                 Indexes of the frames that are being generated. This is used when generating longer videos
                 chunk-by-chunk.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+            pooled_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will be generated by sampling using the supplied random `generator`.
@@ -1034,7 +1034,7 @@ class TextToVideoZeroSDXLPipeline(
                 of a plain tuple.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
...
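
As a usage note, the retyped `callback` hook above accepts any function matching the documented signature; a hedged sketch (`pipe` stands in for one of the pipelines in this commit, and its setup is elided):

    import torch

    def on_step(step: int, timestep: int, latents: torch.Tensor) -> None:
        # Under the updated hint, fp16 latents are equally valid here.
        print(f"step={step} t={timestep} latents: {tuple(latents.shape)} {latents.dtype}")

    # pipe(prompt="an astronaut riding a horse", callback=on_step, callback_steps=1)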
@@ -217,9 +217,9 @@ class UnCLIPPipeline(DiffusionPipeline):
         decoder_num_inference_steps: int = 25,
         super_res_num_inference_steps: int = 7,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        prior_latents: Optional[torch.FloatTensor] = None,
-        decoder_latents: Optional[torch.FloatTensor] = None,
-        super_res_latents: Optional[torch.FloatTensor] = None,
+        prior_latents: Optional[torch.Tensor] = None,
+        decoder_latents: Optional[torch.Tensor] = None,
+        super_res_latents: Optional[torch.Tensor] = None,
         text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
         text_attention_mask: Optional[torch.Tensor] = None,
         prior_guidance_scale: float = 4.0,
@@ -248,11 +248,11 @@ class UnCLIPPipeline(DiffusionPipeline):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
+            prior_latents (`torch.Tensor` of shape (batch size, embeddings dimension), *optional*):
                 Pre-generated noisy latents to be used as inputs for the prior.
-            decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+            decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*):
                 Pre-generated noisy latents to be used as inputs for the decoder.
-            super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+            super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*):
                 Pre-generated noisy latents to be used as inputs for the decoder.
             prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                 A higher guidance scale value encourages the model to generate images closely linked to the text
...
@@ -199,13 +199,13 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
         self,
-        image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
+        image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor]] = None,
         num_images_per_prompt: int = 1,
         decoder_num_inference_steps: int = 25,
         super_res_num_inference_steps: int = 7,
         generator: Optional[torch.Generator] = None,
-        decoder_latents: Optional[torch.FloatTensor] = None,
-        super_res_latents: Optional[torch.FloatTensor] = None,
+        decoder_latents: Optional[torch.Tensor] = None,
+        super_res_latents: Optional[torch.Tensor] = None,
         image_embeddings: Optional[torch.Tensor] = None,
         decoder_guidance_scale: float = 8.0,
         output_type: Optional[str] = "pil",
@@ -215,7 +215,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
         The call function to the pipeline for generation.

         Args:
-            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
                 `Image` or tensor representing an image batch to be used as the starting point. If you provide a
                 tensor, it needs to be compatible with the [`CLIPImageProcessor`]
                 [configuration](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
@@ -231,9 +231,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
             generator (`torch.Generator`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+            decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*):
                 Pre-generated noisy latents to be used as inputs for the decoder.
-            super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+            super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*):
                 Pre-generated noisy latents to be used as inputs for the decoder.
             decoder_guidance_scale (`float`, *optional*, defaults to 4.0):
                 A higher guidance scale value encourages the model to generate images closely linked to the text
...
@@ -220,7 +220,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
             input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
                 Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds`
                 must be supplied.
-            input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+            input_embeds (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                 An embedded representation to directly pass to the transformer as a prefix for beam search. One of
                 `input_ids` and `input_embeds` must be supplied.
             device:
...
@@ -739,8 +739,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
         """
         Args:
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
-                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
-                hidden_states
+                When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
@@ -1038,9 +1037,9 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
     def forward(
         self,
-        latent_image_embeds: torch.FloatTensor,
-        image_embeds: torch.FloatTensor,
-        prompt_embeds: torch.FloatTensor,
+        latent_image_embeds: torch.Tensor,
+        image_embeds: torch.Tensor,
+        prompt_embeds: torch.Tensor,
         timestep_img: Union[torch.Tensor, float, int],
         timestep_text: Union[torch.Tensor, float, int],
         data_type: Optional[Union[torch.Tensor, float, int]] = 1,
@@ -1049,11 +1048,11 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
     ):
         """
         Args:
-            latent_image_embeds (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`):
+            latent_image_embeds (`torch.Tensor` of shape `(batch size, latent channels, height, width)`):
                 Latent image representation from the VAE encoder.
-            image_embeds (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`):
+            image_embeds (`torch.Tensor` of shape `(batch size, 1, clip_img_dim)`):
                 CLIP-embedded image representation (unsqueezed in the first dimension).
-            prompt_embeds (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`):
+            prompt_embeds (`torch.Tensor` of shape `(batch size, seq_len, text_dim)`):
                 CLIP-embedded text representation.
             timestep_img (`torch.long` or `float` or `int`):
                 Current denoising step for the image.
...
@@ -304,7 +304,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if isinstance(image, PIL.Image.Image):
             batch_size = 1
         else:
-            # Image must be available and type either PIL.Image.Image or torch.FloatTensor.
+            # Image must be available and type either PIL.Image.Image or torch.Tensor.
             # Not currently supporting something like image_embeds.
             batch_size = image.shape[0]
         multiplier = num_prompts_per_image
@@ -353,8 +353,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         **kwargs,
     ):
@@ -386,8 +386,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -407,10 +407,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -1080,7 +1080,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
     def __call__(
         self,
         prompt: Optional[Union[str, List[str]]] = None,
-        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         data_type: Optional[int] = 1,
@@ -1091,15 +1091,15 @@ class UniDiffuserPipeline(DiffusionPipeline):
         num_prompts_per_image: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_latents: Optional[torch.FloatTensor] = None,
-        vae_latents: Optional[torch.FloatTensor] = None,
-        clip_latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_latents: Optional[torch.Tensor] = None,
+        vae_latents: Optional[torch.Tensor] = None,
+        clip_latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
     ):
         r"""
@@ -1109,7 +1109,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
                 Required for text-conditioned image generation (`text2img`) mode.
-            image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
+            image (`torch.Tensor` or `PIL.Image.Image`, *optional*):
                 `Image` or tensor representing an image batch. Required for image-conditioned text generation
                 (`img2text`) mode.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -1144,29 +1144,29 @@ class UniDiffuserPipeline(DiffusionPipeline):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for joint
                 image-text generation. Can be used to tweak the same generation with different prompts. If not
                 provided, a latents tensor is generated by sampling using the supplied random `generator`. This assumes
                 a full set of VAE, CLIP, and text latents and, if supplied, overrides the value of `prompt_latents`,
                 `vae_latents`, and `clip_latents`.
-            prompt_latents (`torch.FloatTensor`, *optional*):
+            prompt_latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for text
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
-            vae_latents (`torch.FloatTensor`, *optional*):
+            vae_latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
-            clip_latents (`torch.FloatTensor`, *optional*):
+            clip_latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument. Used in text-conditioned
                 image generation (`text2img`) mode.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. Used
                 in text-conditioned image generation (`text2img`) mode.
@@ -1176,7 +1176,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 Whether or not to return a [`~pipelines.ImageTextPipelineOutput`] instead of a plain tuple.
             callback (`Callable`, *optional*):
                 A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function is called. If not specified, the callback is called at
                 every step.
...
@@ -130,7 +130,7 @@ class PaellaVQModel(ModelMixin, ConfigMixin):
         )

     @apply_forward_hook
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
         h = self.in_block(x)
         h = self.down_blocks(h)
@@ -141,8 +141,8 @@ class PaellaVQModel(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def decode(
-        self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+        self, h: torch.Tensor, force_not_quantize: bool = True, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
         if not force_not_quantize:
             quant, _, _ = self.vquantizer(h)
         else:
@@ -155,10 +155,10 @@ class PaellaVQModel(ModelMixin, ConfigMixin):
         return DecoderOutput(sample=dec)

-    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+    def forward(self, sample: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
         """
...
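
The retyped `PaellaVQModel` methods compose into a simple round trip; a hedged sketch (`vq` and the input shape are placeholders, and the `.latents` field name on `VQEncoderOutput` is an assumption, not part of this commit):

    import torch

    def paella_round_trip(vq, x: torch.Tensor) -> torch.Tensor:
        # `vq` is a placeholder PaellaVQModel instance; 256x256 RGB is illustrative.
        h = vq.encode(x).latents                          # VQEncoderOutput field name assumed
        return vq.decode(h, force_not_quantize=True).sample  # DecoderOutput(sample=...) per the diff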
@@ -209,7 +209,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]],
+        image_embeddings: Union[torch.Tensor, List[torch.Tensor]],
         prompt: Union[str, List[str]] = None,
         num_inference_steps: int = 12,
         timesteps: Optional[List[float]] = None,
@@ -217,7 +217,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -228,7 +228,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
         Function invoked when calling the pipeline for generation.

         Args:
-            image_embedding (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+            image_embedding (`torch.Tensor` or `List[torch.Tensor]`):
                 Image Embeddings either extracted from an image or generated by a Prior Model.
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
@@ -252,7 +252,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will be generated by sampling using the supplied random `generator`.
...
@@ -154,11 +154,11 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
         decoder_timesteps: Optional[List[float]] = None,
         decoder_guidance_scale: float = 0.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -176,10 +176,10 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
                 prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
@@ -218,7 +218,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will be generated by sampling using the supplied random `generator`.
...
@@ -54,12 +54,12 @@ class WuerstchenPriorPipelineOutput(BaseOutput):
     Output class for WuerstchenPriorPipeline.

     Args:
-        image_embeddings (`torch.FloatTensor` or `np.ndarray`)
+        image_embeddings (`torch.Tensor` or `np.ndarray`)
             Prior image embeddings for text prompt
     """

-    image_embeddings: Union[torch.FloatTensor, np.ndarray]
+    image_embeddings: Union[torch.Tensor, np.ndarray]


 class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
@@ -136,8 +136,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
         do_classifier_free_guidance,
         prompt=None,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
     ):
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -288,11 +288,11 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
         timesteps: List[float] = None,
         guidance_scale: float = 8.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pt",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -324,10 +324,10 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `decoder_guidance_scale` is less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -336,7 +336,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will be generated by sampling using the supplied random `generator`.
...
@@ -31,19 +31,19 @@ class KarrasVeOutput(BaseOutput):
 Output class for the scheduler's step function output.
 Args:
-prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
 denoising loop.
-derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+derivative (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 Derivative of predicted original image sample (x_0).
-pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 The predicted denoised sample (x_{0}) based on the model output from the current timestep.
 `pred_original_sample` can be used to preview progress or for guidance.
 """
-prev_sample: torch.FloatTensor
-derivative: torch.FloatTensor
-pred_original_sample: Optional[torch.FloatTensor] = None
+prev_sample: torch.Tensor
+derivative: torch.Tensor
+pred_original_sample: Optional[torch.Tensor] = None
 class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
@@ -94,21 +94,21 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 # setable values
 self.num_inference_steps: int = None
 self.timesteps: np.IntTensor = None
-self.schedule: torch.FloatTensor = None # sigma(t_i)
+self.schedule: torch.Tensor = None # sigma(t_i)
-def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
 """
 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
 current timestep.
 Args:
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 The input sample.
 timestep (`int`, *optional*):
 The current timestep in the diffusion chain.
 Returns:
-`torch.FloatTensor`:
+`torch.Tensor`:
 A scaled input sample.
 """
 return sample
@@ -136,14 +136,14 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
 def add_noise_to_input(
-self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
+self, sample: torch.Tensor, sigma: float, generator: Optional[torch.Generator] = None
-) -> Tuple[torch.FloatTensor, float]:
+) -> Tuple[torch.Tensor, float]:
 """
 Explicit Langevin-like "churn" step of adding noise to the sample according to a `gamma_i ≥ 0` to reach a
 higher noise level `sigma_hat = sigma_i + gamma_i*sigma_i`.
 Args:
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 The input sample.
 sigma (`float`):
 generator (`torch.Generator`, *optional*):
@@ -163,10 +163,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 def step(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
 sigma_hat: float,
 sigma_prev: float,
-sample_hat: torch.FloatTensor,
+sample_hat: torch.Tensor,
 return_dict: bool = True,
 ) -> Union[KarrasVeOutput, Tuple]:
 """
@@ -174,11 +174,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 process from the learned model outputs (most often the predicted noise).
 Args:
-model_output (`torch.FloatTensor`):
+model_output (`torch.Tensor`):
 The direct output from learned diffusion model.
 sigma_hat (`float`):
 sigma_prev (`float`):
-sample_hat (`torch.FloatTensor`):
+sample_hat (`torch.Tensor`):
 return_dict (`bool`, *optional*, defaults to `True`):
 Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] or `tuple`.
@@ -202,25 +202,25 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 def step_correct(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
 sigma_hat: float,
 sigma_prev: float,
-sample_hat: torch.FloatTensor,
+sample_hat: torch.Tensor,
-sample_prev: torch.FloatTensor,
+sample_prev: torch.Tensor,
-derivative: torch.FloatTensor,
+derivative: torch.Tensor,
 return_dict: bool = True,
 ) -> Union[KarrasVeOutput, Tuple]:
 """
 Corrects the predicted sample based on the `model_output` of the network.
 Args:
-model_output (`torch.FloatTensor`):
+model_output (`torch.Tensor`):
 The direct output from learned diffusion model.
 sigma_hat (`float`): TODO
 sigma_prev (`float`): TODO
-sample_hat (`torch.FloatTensor`): TODO
+sample_hat (`torch.Tensor`): TODO
-sample_prev (`torch.FloatTensor`): TODO
+sample_prev (`torch.Tensor`): TODO
-derivative (`torch.FloatTensor`): TODO
+derivative (`torch.Tensor`): TODO
 return_dict (`bool`, *optional*, defaults to `True`):
 Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
...
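The `add_noise_to_input` docstring above fully determines the math, so a sketch is easy to give. Assuming i.i.d. Gaussian noise (the real scheduler also multiplies the noise by a configurable `s_noise`, omitted here), moving from noise level `sigma` to `sigma_hat = sigma + gamma * sigma` only requires injecting the variance difference:

```python
import torch
from typing import Optional, Tuple

def add_noise_to_input(
    sample: torch.Tensor, sigma: float, gamma: float,
    generator: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, float]:
    # Target noise level after the Langevin-like "churn".
    sigma_hat = sigma + gamma * sigma
    eps = torch.randn(sample.shape, generator=generator)
    # Add exactly the variance needed to go from sigma^2 to sigma_hat^2.
    sample_hat = sample + (sigma_hat**2 - sigma**2) ** 0.5 * eps
    return sample_hat, sigma_hat
```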
@@ -29,16 +29,16 @@ class AmusedSchedulerOutput(BaseOutput):
 Output class for the scheduler's `step` function output.
 Args:
-prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
 denoising loop.
-pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
 `pred_original_sample` can be used to preview progress or for guidance.
 """
-prev_sample: torch.FloatTensor
-pred_original_sample: torch.FloatTensor = None
+prev_sample: torch.Tensor
+pred_original_sample: torch.Tensor = None
 class AmusedScheduler(SchedulerMixin, ConfigMixin):
@@ -70,7 +70,7 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
 def step(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
 timestep: torch.long,
 sample: torch.LongTensor,
 starting_mask_ratio: int = 1,
...
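Worth noting in this hunk: only the continuous `model_output` widens to `torch.Tensor`; `sample` keeps its `torch.LongTensor` annotation because this scheduler steps over discrete token ids rather than pixels or latents. A toy illustration (codebook size and grid shape are assumptions):

```python
import torch

# Logits over an assumed 8192-entry codebook: floating point, hence `torch.Tensor`.
model_output = torch.randn(1, 8192, 16, 16)
# Grid of token ids: integer-typed, so the `torch.LongTensor` annotation is untouched.
sample = torch.randint(0, 8192, (1, 16, 16))

assert model_output.is_floating_point()
assert not sample.is_floating_point()
```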
@@ -61,12 +61,12 @@ class ConsistencyDecoderSchedulerOutput(BaseOutput):
 Output class for the scheduler's `step` function.
 Args:
-prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
 denoising loop.
 """
-prev_sample: torch.FloatTensor
+prev_sample: torch.Tensor
 class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
@@ -113,28 +113,28 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
 def init_noise_sigma(self):
 return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]
-def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
 """
 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
 current timestep.
 Args:
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 The input sample.
 timestep (`int`, *optional*):
 The current timestep in the diffusion chain.
 Returns:
-`torch.FloatTensor`:
+`torch.Tensor`:
 A scaled input sample.
 """
 return sample * self.c_in[timestep]
 def step(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
-timestep: Union[float, torch.FloatTensor],
+timestep: Union[float, torch.Tensor],
-sample: torch.FloatTensor,
+sample: torch.Tensor,
 generator: Optional[torch.Generator] = None,
 return_dict: bool = True,
 ) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
@@ -143,11 +143,11 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
 process from the learned model outputs (most often the predicted noise).
 Args:
-model_output (`torch.FloatTensor`):
+model_output (`torch.Tensor`):
 The direct output from the learned diffusion model.
 timestep (`float`):
 The current timestep in the diffusion chain.
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 A current instance of a sample created by the diffusion process.
 generator (`torch.Generator`, *optional*):
 A random number generator.
...
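Unlike the no-op `scale_model_input` most schedulers have, this one multiplies by a precomputed `c_in[timestep]`. Here is a sketch of the EDM-style preconditioning this resembles; `sigma_data = 0.5` and the sigma grid are assumptions, not values from this commit:

```python
import torch

sigma_data = 0.5                            # assumed EDM default
sigmas = torch.linspace(0.02, 80.0, 1024)   # assumed per-timestep noise levels
c_in = 1.0 / (sigmas**2 + sigma_data**2).sqrt()

def scale_model_input(sample: torch.Tensor, timestep: int) -> torch.Tensor:
    # Normalizes the input so the network sees roughly unit-variance activations.
    return sample * c_in[timestep]
```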
@@ -33,12 +33,12 @@ class CMStochasticIterativeSchedulerOutput(BaseOutput):
 Output class for the scheduler's `step` function.
 Args:
-prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
 denoising loop.
 """
-prev_sample: torch.FloatTensor
+prev_sample: torch.Tensor
 class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
@@ -126,20 +126,18 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
 """
 self._begin_index = begin_index
-def scale_model_input(
-self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
-) -> torch.FloatTensor:
+def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
 """
 Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`.
 Args:
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 The input sample.
-timestep (`float` or `torch.FloatTensor`):
+timestep (`float` or `torch.Tensor`):
 The current timestep in the diffusion chain.
 Returns:
-`torch.FloatTensor`:
+`torch.Tensor`:
 A scaled input sample.
 """
 # Get sigma corresponding to timestep
@@ -278,7 +276,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
 </Tip>
 Args:
-sigma (`torch.FloatTensor`):
+sigma (`torch.Tensor`):
 The current sigma in the Karras sigma schedule.
 Returns:
@@ -319,9 +317,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
 def step(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
-timestep: Union[float, torch.FloatTensor],
+timestep: Union[float, torch.Tensor],
-sample: torch.FloatTensor,
+sample: torch.Tensor,
 generator: Optional[torch.Generator] = None,
 return_dict: bool = True,
 ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
@@ -330,11 +328,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
 process from the learned model outputs (most often the predicted noise).
 Args:
-model_output (`torch.FloatTensor`):
+model_output (`torch.Tensor`):
 The direct output from the learned diffusion model.
 timestep (`float`):
 The current timestep in the diffusion chain.
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 A current instance of a sample created by the diffusion process.
 generator (`torch.Generator`, *optional*):
 A random number generator.
@@ -417,10 +415,10 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
 # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
 def add_noise(
 self,
-original_samples: torch.FloatTensor,
+original_samples: torch.Tensor,
-noise: torch.FloatTensor,
+noise: torch.Tensor,
-timesteps: torch.FloatTensor,
+timesteps: torch.Tensor,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
 # Make sure sigmas and timesteps have the same device and dtype as original_samples
 sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
 if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
...
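The `add_noise` here is copied from the Euler discrete scheduler, which uses a variance-exploding formulation: the noisy sample is simply `x + sigma_t * eps`. A minimal sketch, assuming the per-batch sigmas have already been gathered from the timesteps:

```python
import torch

def add_noise(
    original_samples: torch.Tensor, noise: torch.Tensor, sigmas: torch.Tensor
) -> torch.Tensor:
    # Broadcast one sigma per batch element over (B, C, H, W).
    sigma = sigmas.reshape(-1, 1, 1, 1).to(original_samples.device, original_samples.dtype)
    return original_samples + sigma * noise
```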
@@ -35,16 +35,16 @@ class DDIMSchedulerOutput(BaseOutput):
 Output class for the scheduler's `step` function output.
 Args:
-prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
 denoising loop.
-pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
 `pred_original_sample` can be used to preview progress or for guidance.
 """
-prev_sample: torch.FloatTensor
-pred_original_sample: Optional[torch.FloatTensor] = None
+prev_sample: torch.Tensor
+pred_original_sample: Optional[torch.Tensor] = None
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -98,11 +98,11 @@ def rescale_zero_terminal_snr(betas):
 Args:
-betas (`torch.FloatTensor`):
+betas (`torch.Tensor`):
 the betas that the scheduler is being initialized with.
 Returns:
-`torch.FloatTensor`: rescaled betas with zero terminal SNR
+`torch.Tensor`: rescaled betas with zero terminal SNR
 """
 # Convert betas to alphas_bar_sqrt
 alphas = 1.0 - betas
@@ -233,19 +233,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 self.num_inference_steps = None
 self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
-def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
 """
 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
 current timestep.
 Args:
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 The input sample.
 timestep (`int`, *optional*):
 The current timestep in the diffusion chain.
 Returns:
-`torch.FloatTensor`:
+`torch.Tensor`:
 A scaled input sample.
 """
 return sample
@@ -261,7 +261,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 return variance
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
 """
 "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
 prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -341,13 +341,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 def step(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
 timestep: int,
-sample: torch.FloatTensor,
+sample: torch.Tensor,
 eta: float = 0.0,
 use_clipped_model_output: bool = False,
 generator=None,
-variance_noise: Optional[torch.FloatTensor] = None,
+variance_noise: Optional[torch.Tensor] = None,
 return_dict: bool = True,
 ) -> Union[DDIMSchedulerOutput, Tuple]:
 """
@@ -355,11 +355,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 process from the learned model outputs (most often the predicted noise).
 Args:
-model_output (`torch.FloatTensor`):
+model_output (`torch.Tensor`):
 The direct output from learned diffusion model.
 timestep (`float`):
 The current discrete timestep in the diffusion chain.
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 A current instance of a sample created by the diffusion process.
 eta (`float`):
 The weight of noise for added noise in diffusion step.
@@ -370,7 +370,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 `use_clipped_model_output` has no effect.
 generator (`torch.Generator`, *optional*):
 A random number generator.
-variance_noise (`torch.FloatTensor`):
+variance_noise (`torch.Tensor`):
 Alternative to generating noise with `generator` by directly providing the noise for the variance
 itself. Useful for methods such as [`CycleDiffusion`].
 return_dict (`bool`, *optional*, defaults to `True`):
@@ -470,10 +470,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
 def add_noise(
 self,
-original_samples: torch.FloatTensor,
+original_samples: torch.Tensor,
-noise: torch.FloatTensor,
+noise: torch.Tensor,
 timesteps: torch.IntTensor,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
 # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
 # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
 # for the subsequent add_noise calls
@@ -495,9 +495,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 return noisy_samples
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
-def get_velocity(
-self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-) -> torch.FloatTensor:
+def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
 # Make sure alphas_cumprod and timestep have same device and dtype as sample
 self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
 alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
...
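For reference, the `add_noise` being retyped above implements the standard DDPM forward-noising rule, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A self-contained sketch of that rule (a free-standing function rather than the scheduler method):

```python
import torch

def add_noise(
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    alphas_cumprod: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    a = alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)  # one alpha_bar per batch element
    return a.sqrt() * original_samples + (1 - a).sqrt() * noise
```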
@@ -33,16 +33,16 @@ class DDIMSchedulerOutput(BaseOutput):
 Output class for the scheduler's `step` function output.
 Args:
-prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
 denoising loop.
-pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
 `pred_original_sample` can be used to preview progress or for guidance.
 """
-prev_sample: torch.FloatTensor
-pred_original_sample: Optional[torch.FloatTensor] = None
+prev_sample: torch.Tensor
+pred_original_sample: Optional[torch.Tensor] = None
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -97,11 +97,11 @@ def rescale_zero_terminal_snr(betas):
 Args:
-betas (`torch.FloatTensor`):
+betas (`torch.Tensor`):
 the betas that the scheduler is being initialized with.
 Returns:
-`torch.FloatTensor`: rescaled betas with zero terminal SNR
+`torch.Tensor`: rescaled betas with zero terminal SNR
 """
 # Convert betas to alphas_bar_sqrt
 alphas = 1.0 - betas
@@ -231,19 +231,19 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
 self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))
 # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
-def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
 """
 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
 current timestep.
 Args:
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 The input sample.
 timestep (`int`, *optional*):
 The current timestep in the diffusion chain.
 Returns:
-`torch.FloatTensor`:
+`torch.Tensor`:
 A scaled input sample.
 """
 return sample
@@ -288,9 +288,9 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
 def step(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
 timestep: int,
-sample: torch.FloatTensor,
+sample: torch.Tensor,
 return_dict: bool = True,
 ) -> Union[DDIMSchedulerOutput, Tuple]:
 """
@@ -298,11 +298,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
 process from the learned model outputs (most often the predicted noise).
 Args:
-model_output (`torch.FloatTensor`):
+model_output (`torch.Tensor`):
 The direct output from learned diffusion model.
 timestep (`float`):
 The current discrete timestep in the diffusion chain.
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 A current instance of a sample created by the diffusion process.
 eta (`float`):
 The weight of noise for added noise in diffusion step.
@@ -311,7 +311,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
 because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
 clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
 `use_clipped_model_output` has no effect.
-variance_noise (`torch.FloatTensor`):
+variance_noise (`torch.Tensor`):
 Alternative to generating noise with `generator` by directly providing the noise for the variance
 itself. Useful for methods such as [`CycleDiffusion`].
 return_dict (`bool`, *optional*, defaults to `True`):
...
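`rescale_zero_terminal_snr` shows up in several of these files, and its docstring only pins down the types. Here is a sketch of the algorithm it implements, following the zero-terminal-SNR recipe from Lin et al., "Common Diffusion Noise Schedules and Sample Steps are Flawed": shift the sqrt(alpha_bar) sequence so it ends at exactly zero, rescale so its first value is unchanged, then convert back to betas.

```python
import torch

def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    alphas_bar_sqrt = (1.0 - betas).cumprod(dim=0).sqrt()
    first, last = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
    # Shift so the sequence ends at 0, then rescale so it still starts at `first`.
    alphas_bar_sqrt = (alphas_bar_sqrt - last) * first / (first - last)
    alphas_bar = alphas_bar_sqrt**2
    # Recover per-step alphas from the cumulative product, then betas.
    alphas = torch.cat([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas
```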
@@ -35,16 +35,16 @@ class DDIMParallelSchedulerOutput(BaseOutput):
 Output class for the scheduler's `step` function output.
 Args:
-prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
 denoising loop.
-pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
 `pred_original_sample` can be used to preview progress or for guidance.
 """
-prev_sample: torch.FloatTensor
-pred_original_sample: Optional[torch.FloatTensor] = None
+prev_sample: torch.Tensor
+pred_original_sample: Optional[torch.Tensor] = None
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
 Args:
-betas (`torch.FloatTensor`):
+betas (`torch.Tensor`):
 the betas that the scheduler is being initialized with.
 Returns:
-`torch.FloatTensor`: rescaled betas with zero terminal SNR
+`torch.Tensor`: rescaled betas with zero terminal SNR
 """
 # Convert betas to alphas_bar_sqrt
 alphas = 1.0 - betas
@@ -241,19 +241,19 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
 # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
-def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
 """
 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
 current timestep.
 Args:
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 The input sample.
 timestep (`int`, *optional*):
 The current timestep in the diffusion chain.
 Returns:
-`torch.FloatTensor`:
+`torch.Tensor`:
 A scaled input sample.
 """
 return sample
@@ -283,7 +283,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 return variance
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
 """
 "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
 prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -364,13 +364,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 def step(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
 timestep: int,
-sample: torch.FloatTensor,
+sample: torch.Tensor,
 eta: float = 0.0,
 use_clipped_model_output: bool = False,
 generator=None,
-variance_noise: Optional[torch.FloatTensor] = None,
+variance_noise: Optional[torch.Tensor] = None,
 return_dict: bool = True,
 ) -> Union[DDIMParallelSchedulerOutput, Tuple]:
 """
@@ -378,9 +378,9 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 process from the learned model outputs (most often the predicted noise).
 Args:
-model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+model_output (`torch.Tensor`): direct output from learned diffusion model.
 timestep (`int`): current discrete timestep in the diffusion chain.
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 current instance of sample being created by diffusion process.
 eta (`float`): weight of noise for added noise in diffusion step.
 use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
@@ -388,7 +388,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
 coincide with the one provided as input and `use_clipped_model_output` will have no effect.
 generator: random number generator.
-variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
+variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
 can directly provide the noise for the variance itself. This is useful for methods such as
 CycleDiffusion. (https://arxiv.org/abs/2210.05559)
 return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class
@@ -486,12 +486,12 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 def batch_step_no_noise(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
 timesteps: List[int],
-sample: torch.FloatTensor,
+sample: torch.Tensor,
 eta: float = 0.0,
 use_clipped_model_output: bool = False,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
 """
 Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
 Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
@@ -501,10 +501,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 process from the learned model outputs (most often the predicted noise).
 Args:
-model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+model_output (`torch.Tensor`): direct output from learned diffusion model.
 timesteps (`List[int]`):
 current discrete timesteps in the diffusion chain. This is now a list of integers.
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 current instance of sample being created by diffusion process.
 eta (`float`): weight of noise for added noise in diffusion step.
 use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
@@ -513,7 +513,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 coincide with the one provided as input and `use_clipped_model_output` will have no effect.
 Returns:
-`torch.FloatTensor`: sample tensor at previous timestep.
+`torch.Tensor`: sample tensor at previous timestep.
 """
 if self.num_inference_steps is None:
@@ -595,10 +595,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
 def add_noise(
 self,
-original_samples: torch.FloatTensor,
+original_samples: torch.Tensor,
-noise: torch.FloatTensor,
+noise: torch.Tensor,
 timesteps: torch.IntTensor,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
 # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
 # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
 # for the subsequent add_noise calls
@@ -620,9 +620,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 return noisy_samples
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
-def get_velocity(
-self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-) -> torch.FloatTensor:
+def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
 # Make sure alphas_cumprod and timestep have same device and dtype as sample
 self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
 alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
...
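The parallel variants differ from their base schedulers mainly through `batch_step_no_noise`, which reverses several timesteps in one call and leaves noise injection to the parallel-sampling driver. A rough, hypothetical driver loop to make the shape of the API concrete; the `unet` call and window logic are illustrative assumptions, not this commit's code:

```python
import torch

def parallel_window_step(scheduler, unet, latents: torch.Tensor, window: list[int]):
    # Evaluate the model at every timestep in the window against the same latents...
    stacked = latents.repeat(len(window), 1, 1, 1)
    t = torch.tensor(window, device=latents.device)
    model_out = unet(stacked, t).sample
    # ...then reverse all of them in one batched call; no noise is added inside it.
    return scheduler.batch_step_no_noise(model_out, window, stacked)
```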
@@ -33,16 +33,16 @@ class DDPMSchedulerOutput(BaseOutput):
 Output class for the scheduler's `step` function output.
 Args:
-prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
 denoising loop.
-pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
 `pred_original_sample` can be used to preview progress or for guidance.
 """
-prev_sample: torch.FloatTensor
-pred_original_sample: Optional[torch.FloatTensor] = None
+prev_sample: torch.Tensor
+pred_original_sample: Optional[torch.Tensor] = None
 def betas_for_alpha_bar(
@@ -96,11 +96,11 @@ def rescale_zero_terminal_snr(betas):
 Args:
-betas (`torch.FloatTensor`):
+betas (`torch.Tensor`):
 the betas that the scheduler is being initialized with.
 Returns:
-`torch.FloatTensor`: rescaled betas with zero terminal SNR
+`torch.Tensor`: rescaled betas with zero terminal SNR
 """
 # Convert betas to alphas_bar_sqrt
 alphas = 1.0 - betas
@@ -231,19 +231,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 self.variance_type = variance_type
-def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
 """
 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
 current timestep.
 Args:
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 The input sample.
 timestep (`int`, *optional*):
 The current timestep in the diffusion chain.
 Returns:
-`torch.FloatTensor`:
+`torch.Tensor`:
 A scaled input sample.
 """
 return sample
@@ -363,7 +363,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 return variance
-def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
 """
 "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
 prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -398,9 +398,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 def step(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
 timestep: int,
-sample: torch.FloatTensor,
+sample: torch.Tensor,
 generator=None,
 return_dict: bool = True,
 ) -> Union[DDPMSchedulerOutput, Tuple]:
@@ -409,11 +409,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 process from the learned model outputs (most often the predicted noise).
 Args:
-model_output (`torch.FloatTensor`):
+model_output (`torch.Tensor`):
 The direct output from learned diffusion model.
 timestep (`float`):
 The current discrete timestep in the diffusion chain.
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 A current instance of a sample created by the diffusion process.
 generator (`torch.Generator`, *optional*):
 A random number generator.
@@ -498,10 +498,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 def add_noise(
 self,
-original_samples: torch.FloatTensor,
+original_samples: torch.Tensor,
-noise: torch.FloatTensor,
+noise: torch.Tensor,
 timesteps: torch.IntTensor,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
 # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
 # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
 # for the subsequent add_noise calls
@@ -522,9 +522,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
 return noisy_samples
-def get_velocity(
-self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-) -> torch.FloatTensor:
+def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
 # Make sure alphas_cumprod and timestep have same device and dtype as sample
 self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
 alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
...
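`get_velocity` (retyped above) computes the v-prediction training target, v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0. A self-contained sketch of that formula, written as a free-standing function:

```python
import torch

def get_velocity(
    sample: torch.Tensor,
    noise: torch.Tensor,
    alphas_cumprod: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    a = alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)  # one alpha_bar per batch element
    return a.sqrt() * noise - (1 - a).sqrt() * sample
```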
@@ -34,16 +34,16 @@ class DDPMParallelSchedulerOutput(BaseOutput):
 Output class for the scheduler's `step` function output.
 Args:
-prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
 denoising loop.
-pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
 The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
 `pred_original_sample` can be used to preview progress or for guidance.
 """
-prev_sample: torch.FloatTensor
-pred_original_sample: Optional[torch.FloatTensor] = None
+prev_sample: torch.Tensor
+pred_original_sample: Optional[torch.Tensor] = None
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -98,11 +98,11 @@ def rescale_zero_terminal_snr(betas):
 Args:
-betas (`torch.FloatTensor`):
+betas (`torch.Tensor`):
 the betas that the scheduler is being initialized with.
 Returns:
-`torch.FloatTensor`: rescaled betas with zero terminal SNR
+`torch.Tensor`: rescaled betas with zero terminal SNR
 """
 # Convert betas to alphas_bar_sqrt
 alphas = 1.0 - betas
@@ -240,19 +240,19 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
 self.variance_type = variance_type
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input
-def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
 """
 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
 current timestep.
 Args:
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 The input sample.
 timestep (`int`, *optional*):
 The current timestep in the diffusion chain.
 Returns:
-`torch.FloatTensor`:
+`torch.Tensor`:
 A scaled input sample.
 """
 return sample
@@ -375,7 +375,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
 return variance
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
 """
 "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
 prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -410,9 +410,9 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
 def step(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
 timestep: int,
-sample: torch.FloatTensor,
+sample: torch.Tensor,
 generator=None,
 return_dict: bool = True,
 ) -> Union[DDPMParallelSchedulerOutput, Tuple]:
@@ -421,9 +421,9 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
 process from the learned model outputs (most often the predicted noise).
 Args:
-model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+model_output (`torch.Tensor`): direct output from learned diffusion model.
 timestep (`int`): current discrete timestep in the diffusion chain.
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 current instance of sample being created by diffusion process.
 generator: random number generator.
 return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
@@ -506,10 +506,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
 def batch_step_no_noise(
 self,
-model_output: torch.FloatTensor,
+model_output: torch.Tensor,
 timesteps: List[int],
-sample: torch.FloatTensor,
+sample: torch.Tensor,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
 """
 Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
 Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
@@ -519,14 +519,14 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
 process from the learned model outputs (most often the predicted noise).
 Args:
-model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+model_output (`torch.Tensor`): direct output from learned diffusion model.
 timesteps (`List[int]`):
 current discrete timesteps in the diffusion chain. This is now a list of integers.
-sample (`torch.FloatTensor`):
+sample (`torch.Tensor`):
 current instance of sample being created by diffusion process.
 Returns:
-`torch.FloatTensor`: sample tensor at previous timestep.
+`torch.Tensor`: sample tensor at previous timestep.
 """
 t = timesteps
 num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
@@ -587,10 +587,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
 def add_noise(
 self,
-original_samples: torch.FloatTensor,
+original_samples: torch.Tensor,
-noise: torch.FloatTensor,
+noise: torch.Tensor,
 timesteps: torch.IntTensor,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
 # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
 # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
 # for the subsequent add_noise calls
@@ -612,9 +612,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
 return noisy_samples
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
-def get_velocity(
-self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-) -> torch.FloatTensor:
+def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
 # Make sure alphas_cumprod and timestep have same device and dtype as sample
 self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
 alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
...
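Finally, the `_threshold_sample` docstrings quoted throughout describe Imagen-style dynamic thresholding. A sketch of what that quote amounts to; the percentile and clamp bounds mirror common diffusers defaults but are assumptions here:

```python
import torch

def threshold_sample(
    x0: torch.Tensor, percentile: float = 0.995, sample_max_value: float = 1.0
) -> torch.Tensor:
    batch = x0.shape[0]
    # s = per-sample percentile of absolute pixel values of the predicted x_0.
    s = torch.quantile(x0.reshape(batch, -1).abs(), percentile, dim=1)
    # When s > 1, clamp to [-s, s] and divide by s; s is floored at 1 so samples
    # already inside [-1, 1] pass through unchanged.
    s = s.clamp(min=1.0, max=sample_max_value).reshape(batch, 1, 1, 1)
    return x0.clamp(-s, s) / s
```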