Unverified Commit 9f7b2cf2 authored by JuanCarlosPi, committed by GitHub

Add IP-Adapter support to the StableDiffusionControlNetInpaintPipeline (#5887)

* Change pipeline_controlnet_inpaint.py to add IP-Adapter support; the changes mirror those in pipeline_controlnet.py

* Update the StableDiffusionControlNetInpaintPipeline tests by adding image_encoder: None to the pipeline components

* Update src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 895c4b70
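
What this change enables, as a minimal usage sketch. The checkpoint names, image URLs, and generation settings below are illustrative assumptions and not part of this commit; the calls themselves (`load_ip_adapter`, the new `ip_adapter_image` argument) are what the diff adds or builds on.

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image

# Illustrative checkpoints; any SD 1.5 base model plus an inpaint ControlNet works the same way.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter comes from the IPAdapterMixin wired in by this PR; it also populates
# the new `image_encoder` component (a CLIPVisionModelWithProjection).
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# Placeholder URLs; substitute your own images.
init_image = load_image("https://example.com/input.png")       # image to inpaint
mask_image = load_image("https://example.com/mask.png")        # white pixels are repainted
control_image = load_image("https://example.com/control.png")  # ControlNet conditioning image
ip_image = load_image("https://example.com/reference.png")     # IP-Adapter reference image

result = pipe(
    prompt="a photo of a cat",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    ip_adapter_image=ip_image,  # new argument introduced by this PR
    num_inference_steps=30,
).images[0]
result.save("out.png")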
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py:

@@ -21,10 +21,10 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
@@ -241,7 +241,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
 class StableDiffusionControlNetInpaintPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for image inpainting using Stable Diffusion with ControlNet guidance.
@@ -251,6 +251,7 @@ class StableDiffusionControlNetInpaintPipeline(
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

     <Tip>
@@ -288,7 +289,7 @@ class StableDiffusionControlNetInpaintPipeline(
     """

     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -302,6 +303,7 @@ class StableDiffusionControlNetInpaintPipeline(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -334,6 +336,7 @@ class StableDiffusionControlNetInpaintPipeline(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -593,6 +596,20 @@ class StableDiffusionControlNetInpaintPipeline(
         return prompt_embeds, negative_prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        uncond_image_embeds = torch.zeros_like(image_embeds)
+        return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -1053,6 +1070,7 @@ class StableDiffusionControlNetInpaintPipeline(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1131,6 +1149,7 @@ class StableDiffusionControlNetInpaintPipeline(
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -1264,6 +1283,11 @@ class StableDiffusionControlNetInpaintPipeline(
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image
         if isinstance(controlnet, ControlNetModel):
             control_image = self.prepare_control_image(
@@ -1299,7 +1323,7 @@ class StableDiffusionControlNetInpaintPipeline(
         else:
             assert False

-        # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+        # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
         init_image = self.image_processor.preprocess(image, height=height, width=width)
         init_image = init_image.to(dtype=torch.float32)
@@ -1360,7 +1384,10 @@ class StableDiffusionControlNetInpaintPipeline(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 7.1 Create tensor stating which controlnets to keep
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+        # 7.2 Create tensor stating which controlnets to keep
         controlnet_keep = []
         for i in range(len(timesteps)):
             keeps = [
@@ -1423,6 +1450,7 @@ class StableDiffusionControlNetInpaintPipeline(
                         cross_attention_kwargs=self.cross_attention_kwargs,
                         down_block_additional_residuals=down_block_res_samples,
                         mid_block_additional_residual=mid_block_res_sample,
+                        added_cond_kwargs=added_cond_kwargs,
                         return_dict=False,
                     )[0]
Tests for the StableDiffusionControlNetInpaintPipeline (adding the new image_encoder component):

@@ -132,6 +132,7 @@ class ControlNetInpaintPipelineFastTests(
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
@@ -248,6 +249,7 @@ class ControlNetSimpleInpaintPipelineFastTests(ControlNetInpaintPipelineFastTest
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
@@ -342,6 +344,7 @@ class MultiControlNetInpaintPipelineFastTests(
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
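
For context on the new encode_image helper above: under classifier-free guidance the unconditional image embedding is an all-zeros tensor, concatenated in front of the conditional one so the batch layout matches torch.cat([negative_prompt_embeds, prompt_embeds]). A self-contained sketch of that batching; the 1024-dimensional embedding size is an illustrative assumption:

import torch

num_images_per_prompt = 2
image_embeds = torch.randn(1, 1024)  # stand-in for self.image_encoder(image).image_embeds

# One copy of the embedding per generated image, as in encode_image.
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)  # (2, 1024)
uncond_image_embeds = torch.zeros_like(image_embeds)                         # "no image" condition

# Unconditional half first, matching the text-embedding concatenation order,
# so the UNet receives aligned (uncond, cond) batches via added_cond_kwargs.
image_embeds = torch.cat([uncond_image_embeds, image_embeds])                # (4, 1024)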