Unverified Commit 3e8b6321 authored by antoine-scenario, committed by GitHub

Add IP-Adapter to StableDiffusionXLControlNetImg2ImgPipeline (#6293)



* add IP-Adapter to StableDiffusionXLControlNetImg2ImgPipeline

Update src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

fix tests

* fix failing test

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent dd4459ad
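
For context, a minimal usage sketch of the new `ip_adapter_image` argument together with `load_ip_adapter` from the newly added `IPAdapterMixin`. This is not part of the commit; the checkpoint names below are the commonly used public SDXL and IP-Adapter repositories, and the image URLs are placeholders.

```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection

# ViT-bigG image encoder used by the SDXL IP-Adapter weights; loaded explicitly
# because image_encoder is an optional component of the pipeline.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=torch.float16
)
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
).to("cuda")

# load_ip_adapter comes from the IPAdapterMixin added to this pipeline.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

init_image = load_image("https://example.com/init.png")       # img2img input (placeholder URL)
control_image = load_image("https://example.com/canny.png")   # ControlNet conditioning image (placeholder URL)
ip_image = load_image("https://example.com/reference.png")    # reference image for the IP-Adapter (placeholder URL)

result = pipe(
    prompt="a photo of a cat",
    image=init_image,
    control_image=control_image,
    ip_adapter_image=ip_image,  # new argument introduced by this PR
    strength=0.8,
    num_inference_steps=30,
).images[0]
result.save("out.png")
```

With classifier-free guidance enabled, the pipeline concatenates the negative (zero) and positive image embeddings returned by `encode_image` and passes them to the UNet through `added_cond_kwargs["image_embeds"]`, as shown in the diff below.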
@@ -20,13 +20,23 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
 from diffusers.utils.import_utils import is_invisible_watermark_available
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...loaders import (
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -147,7 +157,7 @@ def retrieve_latents(
 
 
 class StableDiffusionXLControlNetImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin
 ):
     r"""
     Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -159,6 +169,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
     - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
     - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
     - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+    - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -197,10 +208,19 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
             watermark output images. If not defined, it will default to True if the package is installed, otherwise no
             watermarker will be used.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
@@ -216,6 +236,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
 
@@ -231,6 +253,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -515,6 +539,31 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1011,6 +1060,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1109,6 +1159,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1262,7 +1313,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
         )
         guess_mode = guess_mode or global_pool_conditions
 
-        # 3. Encode input prompt
+        # 3.1. Encode input prompt
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
@@ -1287,6 +1338,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             clip_skip=self.clip_skip,
         )
 
+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image and controlnet_conditioning_image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
 
@@ -1449,6 +1509,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
                     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
...
@@ -136,6 +136,8 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
             "tokenizer": tokenizer if not skip_first_text_encoder else None,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "image_encoder": None,
+            "feature_extractor": None,
         }
         return components
 
...