Unverified Commit 841504bb authored by Inigo Goiri, committed by GitHub

Add support to pass image embeddings to the WAN I2V pipeline. (#11175)



* Add support to pass image embeddings to the pipeline.



---------
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent fc7a867a
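
For context, a minimal usage sketch of the new argument (not part of the commit itself; the checkpoint id, image path, and prompt below are placeholders). `encode_image` is the pipeline helper that the diff falls back to when `image_embeds` is not supplied:

import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import load_image

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("first_frame.png")  # placeholder input image

# Precompute the image embeddings with the pipeline's own helper (the same
# call the pipeline makes internally when `image_embeds` is None).
image_embeds = pipe.encode_image(image, "cuda")

# Per the new `check_inputs`, pass either `image` or `image_embeds`, not both.
video = pipe(
    prompt="a cat surfing a wave",
    image_embeds=image_embeds,
    height=480,
    width=832,
).frames[0]
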
@@ -321,9 +321,19 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         width,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
-        if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+        if image is not None and image_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        if image is None and image_embeds is None:
+            raise ValueError(
+                "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+            )
+        if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
             raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -463,6 +473,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "np",
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -512,6 +523,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting).
+                If not provided, negative text embeddings are generated from the `negative_prompt` input argument.
+            image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+                image embeddings are generated from the `image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -556,6 +573,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             width,
             prompt_embeds,
             negative_prompt_embeds,
+            image_embeds,
             callback_on_step_end_tensor_inputs,
         )
@@ -599,6 +617,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
-        image_embeds = self.encode_image(image, device)
+        if image_embeds is None:
+            image_embeds = self.encode_image(image, device)
         image_embeds = image_embeds.repeat(batch_size, 1, 1)
         image_embeds = image_embeds.to(transformer_dtype)
...
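
The `if image_embeds is None` fallback keeps the change backward compatible: callers that pass only `image` hit the same `encode_image` path as before, while callers that precompute embeddings skip the image encoder on every run. A short sketch of that amortization (illustrative; reuses the hypothetical `pipe` and `image` from the sketch above):

# Encode once, generate many times; the image encoder runs only on this line.
image_embeds = pipe.encode_image(image, "cuda")

prompts = ["a cat surfing a wave", "a cat surfing a wave at sunset"]
videos = [
    pipe(prompt=p, image_embeds=image_embeds, height=480, width=832).frames[0]
    for p in prompts
]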