"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7f51f286a5397cb3e5c5a25693681aa4955e6241"
Unverified Commit 08b453e3 authored by Charchit Sharma, committed by GitHub

IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (#5901)

* adapter for StableDiffusionControlNetImg2ImgPipeline

* fix-copies

* fix-copies

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 2a111bc9
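
For orientation before the diff: this commit wires IP-Adapter support into the ControlNet image-to-image pipeline, so a reference image can steer generation alongside the text prompt and the ControlNet conditioning. A minimal usage sketch of what it enables follows; the checkpoints are the standard public ones, while the image paths are placeholder assumptions, not taken from the commit.

```python
# Minimal usage sketch of the feature this commit adds. Image paths are
# illustrative placeholders; checkpoints are the standard public ones.
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter comes from the IPAdapterMixin this commit mixes in; it also
# loads the CLIP image encoder that fills the new `image_encoder` component.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

init_image = load_image("init.png")        # image-to-image input (placeholder path)
control_image = load_image("canny.png")    # ControlNet conditioning (placeholder path)
ip_image = load_image("style.png")         # IP-Adapter reference (placeholder path)

result = pipe(
    prompt="best quality, high quality",
    image=init_image,
    control_image=control_image,
    ip_adapter_image=ip_image,  # new argument introduced by this commit
    num_inference_steps=30,
    strength=0.6,
).images[0]
```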
@@ -19,10 +19,10 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
@@ -130,7 +130,7 @@ def prepare_image(image):
 class StableDiffusionControlNetImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -140,7 +140,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -166,7 +166,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
     """
 
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -180,6 +180,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -212,6 +213,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -468,6 +470,31 @@ class StableDiffusionControlNetImg2ImgPipeline(
 
         return prompt_embeds, negative_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
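
Note how `encode_image` builds its unconditional branch: it encodes `torch.zeros_like(image)` (or returns `torch.zeros_like(image_embeds)`) so classifier-free guidance can contrast the reference image against "no image prompt". A self-contained sketch of the resulting batching; the 1024-dim width is an illustrative assumption that depends on which CLIP vision tower is actually loaded:

```python
# Standalone sketch of how the image embeds are batched for CFG.
import torch

num_images_per_prompt = 2
image_embeds = torch.randn(1, 1024)  # stand-in for self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)  # the "no image prompt" branch

# Later in __call__: negative embeds come first, matching the prompt_embeds
# ordering, so noise_pred.chunk(2) splits into (uncond, cond) consistently.
image_embeds = torch.cat([uncond_image_embeds, image_embeds])
assert image_embeds.shape == (2 * num_images_per_prompt, 1024)
```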
@@ -861,6 +888,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -922,6 +950,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -1053,6 +1082,11 @@ class StableDiffusionControlNetImg2ImgPipeline(
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
@@ -1111,7 +1145,10 @@ class StableDiffusionControlNetImg2ImgPipeline(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7.1 Create tensor stating which controlnets to keep
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+        # 7.2 Create tensor stating which controlnets to keep
         controlnet_keep = []
         for i in range(len(timesteps)):
             keeps = [
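
The `added_cond_kwargs` dict is the design choice worth noting here: auxiliary conditioning is threaded to the UNet through an optional keyword rather than a new positional argument, so when no IP-Adapter image is supplied the value stays `None` and existing callers are untouched. A toy, self-contained sketch of that pattern; the function and tensor shapes below are invented for illustration, not diffusers APIs:

```python
# Toy illustration of the optional-conditioning-dict pattern used by the UNet.
from typing import Dict, Optional

import torch


def toy_unet_forward(
    sample: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
    out = sample + encoder_hidden_states.mean()
    if added_cond_kwargs is not None and "image_embeds" in added_cond_kwargs:
        # In the real UNet, this is roughly where IP-Adapter attention
        # processors mix the projected image embeddings into cross-attention.
        out = out + added_cond_kwargs["image_embeds"].mean()
    return out


sample = torch.randn(2, 4, 8, 8)
text = torch.randn(2, 77, 32)
image_embeds = torch.randn(2, 1024)

with_ip = toy_unet_forward(sample, text, {"image_embeds": image_embeds})
without_ip = toy_unet_forward(sample, text, None)  # original path, unchanged
```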
@@ -1171,6 +1208,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
                 cross_attention_kwargs=self.cross_attention_kwargs,
                 down_block_additional_residuals=down_block_res_samples,
                 mid_block_additional_residual=mid_block_res_sample,
+                added_cond_kwargs=added_cond_kwargs,
                 return_dict=False,
             )[0]
...
@@ -134,6 +134,7 @@ class ControlNetImg2ImgPipelineFastTests(
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
 
         return components
@@ -273,6 +274,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
 
         return components
...