Unverified Commit e0f33dfc authored by Vinh H. Pham, committed by GitHub

IP-Adapter support for StableDiffusionXLControlNetInpaintPipeline (#6941)



* add ip-adapter support

* support ip image embeds

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 15b125bb
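
For context, here is how the new arguments fit the usual SDXL ControlNet inpainting workflow. This is a minimal sketch, not part of the commit: the checkpoints are examples, the image URLs are placeholders, and it assumes a diffusers version in which this pipeline exposes `IPAdapterMixin.load_ip_adapter` (which also fetches the matching image encoder when the pipeline was created without one).

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# Load the adapter weights; the CLIP vision encoder shipped alongside them
# is loaded too because the pipeline was created without an image_encoder.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
pipe.set_ip_adapter_scale(0.6)

init_image = load_image("https://example.com/init.png")      # image to inpaint
mask_image = load_image("https://example.com/mask.png")      # white = region to repaint
control_image = load_image("https://example.com/canny.png")  # ControlNet conditioning
ip_image = load_image("https://example.com/style.png")       # IP-Adapter reference

result = pipe(
    prompt="a photo of a lion sitting on a park bench",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    ip_adapter_image=ip_image,
    num_inference_steps=30,
).images[0]
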
@@ -19,11 +19,17 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -195,6 +201,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
+        feature_extractor: Optional[CLIPImageProcessor] = None,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
     ):
         super().__init__()
@@ -210,6 +218,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
@@ -497,6 +507,66 @@ class StableDiffusionXLControlNetInpaintPipeline(
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -566,6 +636,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
         negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         pooled_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
@@ -752,6 +824,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
 
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
     def prepare_control_image(
         self,
         image,
@@ -1100,6 +1177,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -1194,6 +1273,10 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+                `ip_adapter_image` input argument.
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
@@ -1326,6 +1409,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
             negative_prompt_2,
             prompt_embeds,
             negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
             controlnet_conditioning_scale,
@@ -1378,6 +1463,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
             clip_skip=self.clip_skip,
         )
 
+        # 3.1 Encode ip_adapter_image
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
+            )
+
         # 4. set timesteps
         def denoising_value_valid(dnv):
             return isinstance(dnv, float) and 0 < dnv < 1
@@ -1649,6 +1740,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                 mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                added_cond_kwargs["image_embeds"] = image_embeds
+
             if num_channels_unet == 9:
                 latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
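
The second commit bullet ("support ip image embeds") is the `ip_adapter_image_embeds` path, which lets callers reuse precomputed embeddings and skip the image encoder on repeat generations. Below is a sketch of that pattern, continuing from the variables in the example above; the `_guidance_scale` assignment is an assumed workaround, since `prepare_ip_adapter_image_embeds` consults `self.do_classifier_free_guidance`, which is otherwise only set inside a pipeline call.

# prepare_ip_adapter_image_embeds reads self.do_classifier_free_guidance,
# so give the pipeline a guidance scale before calling the method directly.
pipe._guidance_scale = 5.0

# Returns a list with one tensor per loaded IP-Adapter; under classifier-free
# guidance each tensor has the negative embeddings concatenated in front.
image_embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=ip_image,
    ip_adapter_image_embeds=None,
    device=pipe.device,
    num_images_per_prompt=1,
)

result = pipe(
    prompt="a photo of a lion sitting on a park bench",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    ip_adapter_image_embeds=image_embeds,
    num_inference_steps=30,
).images[0]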