Unverified Commit aa82df52 authored by Sayak Paul, committed by GitHub

[IP Adapters] introduce `ip_adapter_image_embeds` in the SD pipeline call (#6868)



* add: support for passing ip adapter image embeddings

* debugging

* make feature_extractor unloading conditioned on safety_checker

* better condition

* type annotation

* index to look into value slices

* more debugging

* debugging

* serialize embeddings dict

* better conditioning

* remove unnecessary prints.

* Update src/diffusers/loaders/ip_adapter.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* make fix-copies and styling.

* styling and further copy fixing.

* fix: check_inputs call in controlnet sdxl img2img pipeline

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent a11b0f83
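Usage note (not part of the commit itself): the new argument lets callers compute IP-Adapter image embeddings once and reuse them across calls instead of re-encoding the reference image every time. A minimal sketch, assuming an SD 1.5 checkpoint and the `h94/IP-Adapter` weights; the image URL is illustrative:

    import torch
    from diffusers import StableDiffusionPipeline
    from diffusers.utils import load_image

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

    reference_image = load_image("https://example.com/reference.png")  # illustrative URL

    # First call: the pipeline encodes the image as before.
    out = pipe(prompt="a photo of a cat", ip_adapter_image=reference_image)

    # Compute the embeddings once with the pipeline's own helper (using the signature
    # this commit gives it in the SD pipeline)...
    image_embeds = pipe.prepare_ip_adapter_image_embeds(
        reference_image, None, pipe.device, num_images_per_prompt=1
    )

    # ...then skip the image encoder entirely on later calls.
    out = pipe(prompt="a photo of a dog", ip_adapter_image_embeds=image_embeds)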
@@ -386,7 +386,6 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
...
@@ -438,33 +438,38 @@ class StableDiffusionLDM3DPipeline(
         return image_embeds, uncond_image_embeds

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
-    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
-        if not isinstance(ip_adapter_image, list):
-            ip_adapter_image = [ip_adapter_image]
-
-        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
-            raise ValueError(
-                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-            )
-
-        image_embeds = []
-        for single_ip_adapter_image, image_proj_layer in zip(
-            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
-        ):
-            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-            single_image_embeds, single_negative_image_embeds = self.encode_image(
-                single_ip_adapter_image, device, 1, output_hidden_state
-            )
-            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-
-            if self.do_classifier_free_guidance:
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                single_image_embeds = single_image_embeds.to(device)
-
-            image_embeds.append(single_image_embeds)
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, do_classifier_free_guidance, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
         return image_embeds

     def run_safety_checker(self, image, device, dtype):
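Note on the pass-through branch above: when `ip_adapter_image_embeds` is supplied, it is returned as-is, so precomputed embeddings must already have the layout the denoising loop expects — one tensor per loaded IP-Adapter, with the unconditional embeddings concatenated ahead of the conditional ones when classifier-free guidance is enabled, mirroring the `torch.cat([single_negative_image_embeds, single_image_embeds])` in the generated path. A sketch of assembling one entry by hand; the shapes are illustrative, not prescribed by the commit:

    import torch

    # Illustrative embedding shapes: (num_images_per_prompt, sequence, dim).
    positive = torch.randn(1, 1, 1024)
    negative = torch.zeros_like(positive)  # a simple choice for the unconditional half

    # With classifier-free guidance on, the negative half comes first along dim 0.
    single_image_embeds = torch.cat([negative, positive])

    ip_adapter_image_embeds = [single_image_embeds]  # one list entry per IP-Adapter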
@@ -510,6 +515,8 @@ class StableDiffusionLDM3DPipeline(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
@@ -553,6 +560,11 @@ class StableDiffusionLDM3DPipeline(
                 f" {negative_prompt_embeds.shape}."
             )

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -587,6 +599,7 @@ class StableDiffusionLDM3DPipeline(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -633,6 +646,9 @@ class StableDiffusionLDM3DPipeline(
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+                `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -665,7 +681,15 @@ class StableDiffusionLDM3DPipeline(
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+            prompt,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
         )

         # 2. Define call parameters
@@ -682,9 +706,13 @@ class StableDiffusionLDM3DPipeline(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

-        if ip_adapter_image is not None:
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, device, batch_size * num_images_per_prompt
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                do_classifier_free_guidance,
+                device,
+                batch_size * num_images_per_prompt,
             )

         # 3. Encode input prompt
...
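The commit message mentions serializing the embeddings dict; since the new argument is just a list of tensors, a plain `torch.save`/`torch.load` round trip is one way to carry embeddings across sessions. A sketch, reusing `image_embeds` and `pipe` from the example above:

    import torch

    torch.save(image_embeds, "ip_adapter_embeds.pt")

    # Later, in a fresh process: reload and feed the new argument directly.
    restored = torch.load("ip_adapter_embeds.pt")
    out = pipe(prompt="a photo of a cat", ip_adapter_image_embeds=restored)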
@@ -397,32 +397,38 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
-    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
-        if not isinstance(ip_adapter_image, list):
-            ip_adapter_image = [ip_adapter_image]
-
-        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
-            raise ValueError(
-                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-            )
-
-        image_embeds = []
-        for single_ip_adapter_image, image_proj_layer in zip(
-            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
-        ):
-            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-            single_image_embeds, single_negative_image_embeds = self.encode_image(
-                single_ip_adapter_image, device, 1, output_hidden_state
-            )
-            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-
-            if self.do_classifier_free_guidance:
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                single_image_embeds = single_image_embeds.to(device)
-
-            image_embeds.append(single_image_embeds)
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
         return image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
@@ -493,6 +499,8 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
@@ -536,6 +544,11 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
                 f" {negative_prompt_embeds.shape}."
             )

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -592,6 +605,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -643,6 +657,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+                `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -680,7 +697,15 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+            prompt,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
         )

         # 2. Define call parameters
@@ -697,9 +722,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

-        if ip_adapter_image is not None:
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, device, batch_size * num_images_per_prompt
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
             )

         # 3. Encode input prompt
...
@@ -361,7 +361,6 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, IPAdapterMixin):
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
...
@@ -462,7 +462,6 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin,
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
...
@@ -550,32 +550,38 @@ class StableDiffusionXLPipeline(
         return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
-    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
-        if not isinstance(ip_adapter_image, list):
-            ip_adapter_image = [ip_adapter_image]
-
-        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
-            raise ValueError(
-                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-            )
-
-        image_embeds = []
-        for single_ip_adapter_image, image_proj_layer in zip(
-            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
-        ):
-            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-            single_image_embeds, single_negative_image_embeds = self.encode_image(
-                single_ip_adapter_image, device, 1, output_hidden_state
-            )
-            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-
-            if self.do_classifier_free_guidance:
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                single_image_embeds = single_image_embeds.to(device)
-
-            image_embeds.append(single_image_embeds)
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
         return image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -609,6 +615,8 @@ class StableDiffusionXLPipeline(
         negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
@@ -675,6 +683,11 @@ class StableDiffusionXLPipeline(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -905,6 +918,7 @@ class StableDiffusionXLPipeline(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -994,6 +1008,9 @@ class StableDiffusionXLPipeline(
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+                `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1092,6 +1109,8 @@ class StableDiffusionXLPipeline(
             negative_prompt_embeds,
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             callback_on_step_end_tensor_inputs,
         )
@@ -1191,9 +1210,9 @@ class StableDiffusionXLPipeline(
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

-        if ip_adapter_image is not None:
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, device, batch_size * num_images_per_prompt
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
             )

         # 8. Denoising loop
@@ -1236,7 +1255,7 @@ class StableDiffusionXLPipeline(
                 # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-                if ip_adapter_image is not None:
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                     added_cond_kwargs["image_embeds"] = image_embeds
                 noise_pred = self.unet(
                     latent_model_input,
...
@@ -575,6 +575,8 @@ class StableDiffusionXLImg2ImgPipeline(
         negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if strength < 0 or strength > 1:
@@ -637,6 +639,11 @@ class StableDiffusionXLImg2ImgPipeline(
                 f" {negative_prompt_embeds.shape}."
             )

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
     def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
         # get the original timestep using init_timestep
         if denoising_start is None:
@@ -767,32 +774,38 @@ class StableDiffusionXLImg2ImgPipeline(
         return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
-    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
-        if not isinstance(ip_adapter_image, list):
-            ip_adapter_image = [ip_adapter_image]
-
-        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
-            raise ValueError(
-                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-            )
-
-        image_embeds = []
-        for single_ip_adapter_image, image_proj_layer in zip(
-            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
-        ):
-            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-            single_image_embeds, single_negative_image_embeds = self.encode_image(
-                single_ip_adapter_image, device, 1, output_hidden_state
-            )
-            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-
-            if self.do_classifier_free_guidance:
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                single_image_embeds = single_image_embeds.to(device)
-
-            image_embeds.append(single_image_embeds)
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
         return image_embeds

     def _get_add_time_ids(
@@ -1047,6 +1060,7 @@ class StableDiffusionXLImg2ImgPipeline(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1145,6 +1159,9 @@ class StableDiffusionXLImg2ImgPipeline(
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+                `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1245,6 +1262,8 @@ class StableDiffusionXLImg2ImgPipeline(
             negative_prompt_2,
             prompt_embeds,
             negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             callback_on_step_end_tensor_inputs,
         )
@@ -1365,9 +1384,9 @@ class StableDiffusionXLImg2ImgPipeline(
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device)

-        if ip_adapter_image is not None:
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, device, batch_size * num_images_per_prompt
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
             )

         # 9. Denoising loop
@@ -1416,7 +1435,7 @@ class StableDiffusionXLImg2ImgPipeline(
                 # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-                if ip_adapter_image is not None:
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                     added_cond_kwargs["image_embeds"] = image_embeds
                 noise_pred = self.unet(
                     latent_model_input,
...
@@ -488,32 +488,38 @@ class StableDiffusionXLInpaintPipeline(
         return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
-    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
-        if not isinstance(ip_adapter_image, list):
-            ip_adapter_image = [ip_adapter_image]
-
-        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
-            raise ValueError(
-                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-            )
-
-        image_embeds = []
-        for single_ip_adapter_image, image_proj_layer in zip(
-            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
-        ):
-            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-            single_image_embeds, single_negative_image_embeds = self.encode_image(
-                single_ip_adapter_image, device, 1, output_hidden_state
-            )
-            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-
-            if self.do_classifier_free_guidance:
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                single_image_embeds = single_image_embeds.to(device)
-
-            image_embeds.append(single_image_embeds)
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
         return image_embeds

     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
@@ -784,6 +790,8 @@ class StableDiffusionXLInpaintPipeline(
         negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         padding_mask_crop=None,
     ):
@@ -856,6 +864,11 @@ class StableDiffusionXLInpaintPipeline(
             if output_type != "pil":
                 raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
     def prepare_latents(
         self,
         batch_size,
@@ -1288,6 +1301,7 @@ class StableDiffusionXLInpaintPipeline(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1397,6 +1411,9 @@ class StableDiffusionXLInpaintPipeline(
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+                `ip_adapter_image` input argument.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -1512,6 +1529,8 @@ class StableDiffusionXLInpaintPipeline(
             negative_prompt_2,
             prompt_embeds,
             negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             callback_on_step_end_tensor_inputs,
             padding_mask_crop,
         )
@@ -1713,9 +1732,9 @@ class StableDiffusionXLInpaintPipeline(
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device)

-        if ip_adapter_image is not None:
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, device, batch_size * num_images_per_prompt
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
             )

         # 11. Denoising loop
@@ -1766,7 +1785,7 @@ class StableDiffusionXLInpaintPipeline(
                 # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-                if ip_adapter_image is not None:
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                     added_cond_kwargs["image_embeds"] = image_embeds
                 noise_pred = self.unet(
                     latent_model_input,
...
@@ -564,32 +564,38 @@ class StableDiffusionXLAdapterPipeline(
         return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
-    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
-        if not isinstance(ip_adapter_image, list):
-            ip_adapter_image = [ip_adapter_image]
-
-        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
-            raise ValueError(
-                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-            )
-
-        image_embeds = []
-        for single_ip_adapter_image, image_proj_layer in zip(
-            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
-        ):
-            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-            single_image_embeds, single_negative_image_embeds = self.encode_image(
-                single_ip_adapter_image, device, 1, output_hidden_state
-            )
-            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-
-            if self.do_classifier_free_guidance:
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                single_image_embeds = single_image_embeds.to(device)
-
-            image_embeds.append(single_image_embeds)
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
         return image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -624,6 +630,8 @@ class StableDiffusionXLAdapterPipeline(
         negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
@@ -690,6 +698,11 @@ class StableDiffusionXLAdapterPipeline(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -867,6 +880,7 @@ class StableDiffusionXLAdapterPipeline(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -959,6 +973,9 @@ class StableDiffusionXLAdapterPipeline(
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+                `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1060,6 +1077,8 @@ class StableDiffusionXLAdapterPipeline(
             negative_prompt_embeds,
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
         )

         self._guidance_scale = guidance_scale
@@ -1096,9 +1115,9 @@ class StableDiffusionXLAdapterPipeline(
         )

         # 3.2 Encode ip_adapter_image
-        if ip_adapter_image is not None:
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, device, batch_size * num_images_per_prompt
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
             )

         # 4. Prepare timesteps
@@ -1199,7 +1218,7 @@ class StableDiffusionXLAdapterPipeline(
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

-                if ip_adapter_image is not None:
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                     added_cond_kwargs["image_embeds"] = image_embeds

                 # predict the noise residual
...
@@ -418,7 +418,6 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
...
@@ -495,7 +495,6 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
...
@@ -447,7 +447,6 @@ class TextToVideoZeroPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
                     callback(step_idx, t, latents)
         return latents.clone().detach()

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
...
@@ -510,7 +510,6 @@ class TextToVideoZeroSDXLPipeline(
         latents = latents * self.scheduler.init_noise_sigma
         return latents

-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
...