Unverified Commit 62906682 authored by timegate's avatar timegate Committed by GitHub
Browse files

Add multiple conditions to StableDiffusionControlNetInpaintPipeline (#3125)

* try multi controlnet inpaint

* multi controlnet inpaint

* multi controlnet inpaint
parent 73cc4310
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/ # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
import inspect import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -11,6 +11,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer ...@@ -11,6 +11,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import ( from diffusers.utils import (
PIL_INTERPOLATION, PIL_INTERPOLATION,
...@@ -184,7 +185,14 @@ def prepare_mask_image(mask_image): ...@@ -184,7 +185,14 @@ def prepare_mask_image(mask_image):
def prepare_controlnet_conditioning_image( def prepare_controlnet_conditioning_image(
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype controlnet_conditioning_image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance,
): ):
if not isinstance(controlnet_conditioning_image, torch.Tensor): if not isinstance(controlnet_conditioning_image, torch.Tensor):
if isinstance(controlnet_conditioning_image, PIL.Image.Image): if isinstance(controlnet_conditioning_image, PIL.Image.Image):
...@@ -214,6 +222,9 @@ def prepare_controlnet_conditioning_image( ...@@ -214,6 +222,9 @@ def prepare_controlnet_conditioning_image(
controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype) controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
return controlnet_conditioning_image return controlnet_conditioning_image
...@@ -230,7 +241,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -230,7 +241,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
controlnet: ControlNetModel, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor, feature_extractor: CLIPImageProcessor,
...@@ -254,6 +265,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -254,6 +265,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
...@@ -264,6 +278,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -264,6 +278,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
...@@ -522,6 +537,42 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -522,6 +537,42 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator extra_step_kwargs["generator"] = generator
return extra_step_kwargs return extra_step_kwargs
def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
raise TypeError(
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
)
if image_is_pil:
image_batch_size = 1
elif image_is_tensor:
image_batch_size = image.shape[0]
elif image_is_pil_list:
image_batch_size = len(image)
elif image_is_tensor_list:
image_batch_size = len(image)
else:
raise ValueError("controlnet condition image is not valid")
if prompt is not None and isinstance(prompt, str):
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]
else:
raise ValueError("prompt or prompt_embeds are not valid")
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)
def check_inputs( def check_inputs(
self, self,
prompt, prompt,
...@@ -534,6 +585,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -534,6 +585,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
controlnet_conditioning_scale=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -572,45 +624,35 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -572,45 +624,35 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image) # check controlnet condition image
controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor) if isinstance(self.controlnet, ControlNetModel):
controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance( self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
controlnet_conditioning_image[0], PIL.Image.Image elif isinstance(self.controlnet, MultiControlNetModel):
) if not isinstance(controlnet_conditioning_image, list):
controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance( raise TypeError("For multiple controlnets: `image` must be type `list`")
controlnet_conditioning_image[0], torch.Tensor if len(controlnet_conditioning_image) != len(self.controlnet.nets):
) raise ValueError(
"For multiple controlnets: `image` must have the same length as the number of controlnets."
if ( )
not controlnet_cond_image_is_pil for image_ in controlnet_conditioning_image:
and not controlnet_cond_image_is_tensor self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
and not controlnet_cond_image_is_pil_list else:
and not controlnet_cond_image_is_tensor_list assert False
):
raise TypeError( # Check `controlnet_conditioning_scale`
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" if isinstance(self.controlnet, ControlNetModel):
) if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
if controlnet_cond_image_is_pil: elif isinstance(self.controlnet, MultiControlNetModel):
controlnet_cond_image_batch_size = 1 if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
elif controlnet_cond_image_is_tensor: self.controlnet.nets
controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0] ):
elif controlnet_cond_image_is_pil_list: raise ValueError(
controlnet_cond_image_batch_size = len(controlnet_conditioning_image) "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
elif controlnet_cond_image_is_tensor_list: " the same length as the number of controlnets"
controlnet_cond_image_batch_size = len(controlnet_conditioning_image) )
else:
if prompt is not None and isinstance(prompt, str): assert False
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]
if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
)
if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor): if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor") raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
...@@ -630,6 +672,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -630,6 +672,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
image_channels, image_height, image_width = image.shape image_channels, image_height, image_width = image.shape
elif image.ndim == 4: elif image.ndim == 4:
image_batch_size, image_channels, image_height, image_width = image.shape image_batch_size, image_channels, image_height, image_width = image.shape
else:
assert False
if mask_image.ndim == 2: if mask_image.ndim == 2:
mask_image_batch_size = 1 mask_image_batch_size = 1
...@@ -797,7 +841,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -797,7 +841,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1, callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: float = 1.0, controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -897,6 +941,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -897,6 +941,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
negative_prompt, negative_prompt,
prompt_embeds, prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds,
controlnet_conditioning_scale,
) )
# 2. Define call parameters # 2. Define call parameters
...@@ -913,6 +958,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -913,6 +958,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
# 3. Encode input prompt # 3. Encode input prompt
prompt_embeds = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, prompt,
...@@ -929,15 +977,37 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -929,15 +977,37 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
mask_image = prepare_mask_image(mask_image) mask_image = prepare_mask_image(mask_image)
controlnet_conditioning_image = prepare_controlnet_conditioning_image( # condition image(s)
controlnet_conditioning_image, if isinstance(self.controlnet, ControlNetModel):
width, controlnet_conditioning_image = prepare_controlnet_conditioning_image(
height, controlnet_conditioning_image=controlnet_conditioning_image,
batch_size * num_images_per_prompt, width=width,
num_images_per_prompt, height=height,
device, batch_size=batch_size * num_images_per_prompt,
self.controlnet.dtype, num_images_per_prompt=num_images_per_prompt,
) device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
elif isinstance(self.controlnet, MultiControlNetModel):
controlnet_conditioning_images = []
for image_ in controlnet_conditioning_image:
image_ = prepare_controlnet_conditioning_image(
controlnet_conditioning_image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
controlnet_conditioning_images.append(image_)
controlnet_conditioning_image = controlnet_conditioning_images
else:
assert False
masked_image = image * (mask_image < 0.5) masked_image = image * (mask_image < 0.5)
...@@ -979,9 +1049,6 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -979,9 +1049,6 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
do_classifier_free_guidance, do_classifier_free_guidance,
) )
if do_classifier_free_guidance:
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
...@@ -1007,15 +1074,10 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline): ...@@ -1007,15 +1074,10 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
controlnet_cond=controlnet_conditioning_image, controlnet_cond=controlnet_conditioning_image,
conditioning_scale=controlnet_conditioning_scale,
return_dict=False, return_dict=False,
) )
down_block_res_samples = [
down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
]
mid_block_res_sample *= controlnet_conditioning_scale
# predict the noise residual # predict the noise residual
noise_pred = self.unet( noise_pred = self.unet(
inpainting_latent_model_input, inpainting_latent_model_input,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment