Unverified Commit 1d033a95 authored by Michael Gartsbein's avatar Michael Gartsbein Committed by GitHub
Browse files

img2img.multiple.controlnets.pipeline (#2833)



* img2img.multiple.controlnets.pipeline

* remove comments

---------
Co-authored-by: mishka <gartsocial@gmail.com>
parent 49609768
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/ # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
import inspect import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -10,6 +10,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer ...@@ -10,6 +10,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import ( from diffusers.utils import (
PIL_INTERPOLATION, PIL_INTERPOLATION,
...@@ -86,7 +87,14 @@ def prepare_image(image): ...@@ -86,7 +87,14 @@ def prepare_image(image):
def prepare_controlnet_conditioning_image( def prepare_controlnet_conditioning_image(
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype controlnet_conditioning_image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance,
): ):
if not isinstance(controlnet_conditioning_image, torch.Tensor): if not isinstance(controlnet_conditioning_image, torch.Tensor):
if isinstance(controlnet_conditioning_image, PIL.Image.Image): if isinstance(controlnet_conditioning_image, PIL.Image.Image):
...@@ -116,6 +124,9 @@ def prepare_controlnet_conditioning_image( ...@@ -116,6 +124,9 @@ def prepare_controlnet_conditioning_image(
controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype) controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
return controlnet_conditioning_image return controlnet_conditioning_image
...@@ -132,7 +143,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -132,7 +143,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
controlnet: ControlNetModel, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor, feature_extractor: CLIPImageProcessor,
...@@ -156,6 +167,9 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -156,6 +167,9 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
...@@ -424,6 +438,42 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -424,6 +438,42 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator extra_step_kwargs["generator"] = generator
return extra_step_kwargs return extra_step_kwargs
def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
    """Validate one controlnet conditioning image against the prompt batch size.

    Args:
        image: A `PIL.Image.Image`, a `torch.Tensor`, a list of PIL images, or a
            list of tensors (one entry per prompt in the batch).
        prompt: Optional `str` or `list` of `str` prompts.
        prompt_embeds: Optional pre-computed prompt embeddings tensor; used for
            the batch size when `prompt` is not given.

    Raises:
        TypeError: If `image` is not one of the accepted types.
        ValueError: If neither `prompt` nor `prompt_embeds` is usable, or the
            image batch size is neither 1 nor the prompt batch size.
    """
    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, torch.Tensor)
    # The `len(image) > 0` guard prevents an IndexError on `image[0]` for an
    # empty list; an empty list now falls through to the TypeError below.
    image_is_pil_list = isinstance(image, list) and len(image) > 0 and isinstance(image[0], PIL.Image.Image)
    image_is_tensor_list = isinstance(image, list) and len(image) > 0 and isinstance(image[0], torch.Tensor)

    if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
        raise TypeError(
            "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
        )

    if image_is_pil:
        image_batch_size = 1
    elif image_is_tensor:
        # Tensor is assumed batched (N, C, H, W) — first dim is the batch.
        image_batch_size = image.shape[0]
    else:
        # Remaining cases are the two list forms; their length is the batch.
        image_batch_size = len(image)

    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]
    else:
        raise ValueError("prompt or prompt_embeds are not valid")

    # A single conditioning image may be broadcast to any prompt batch size;
    # otherwise the two batch sizes must match exactly.
    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
def check_inputs( def check_inputs(
self, self,
prompt, prompt,
...@@ -438,6 +488,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -438,6 +488,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
strength=None, strength=None,
controlnet_guidance_start=None, controlnet_guidance_start=None,
controlnet_guidance_end=None, controlnet_guidance_end=None,
controlnet_conditioning_scale=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -476,58 +527,51 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -476,58 +527,51 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image) # check controlnet condition image
controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
controlnet_conditioning_image[0], PIL.Image.Image
)
controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
controlnet_conditioning_image[0], torch.Tensor
)
if ( if isinstance(self.controlnet, ControlNetModel):
not controlnet_cond_image_is_pil self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
and not controlnet_cond_image_is_tensor elif isinstance(self.controlnet, MultiControlNetModel):
and not controlnet_cond_image_is_pil_list if not isinstance(controlnet_conditioning_image, list):
and not controlnet_cond_image_is_tensor_list raise TypeError("For multiple controlnets: `image` must be type `list`")
):
raise TypeError(
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
)
if controlnet_cond_image_is_pil: if len(controlnet_conditioning_image) != len(self.controlnet.nets):
controlnet_cond_image_batch_size = 1 raise ValueError(
elif controlnet_cond_image_is_tensor: "For multiple controlnets: `image` must have the same length as the number of controlnets."
controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0] )
elif controlnet_cond_image_is_pil_list:
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
elif controlnet_cond_image_is_tensor_list:
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
if prompt is not None and isinstance(prompt, str): for image_ in controlnet_conditioning_image:
prompt_batch_size = 1 self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
elif prompt is not None and isinstance(prompt, list): else:
prompt_batch_size = len(prompt) assert False
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]
if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size: # Check `controlnet_conditioning_scale`
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}" if isinstance(self.controlnet, ControlNetModel):
) if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(self.controlnet, MultiControlNetModel):
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
raise ValueError(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if image.ndim != 3 and image.ndim != 4: if image.ndim != 3 and image.ndim != 4:
raise ValueError("`image` must have 3 or 4 dimensions") raise ValueError("`image` must have 3 or 4 dimensions")
# if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
# raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
if image.ndim == 3: if image.ndim == 3:
image_batch_size = 1 image_batch_size = 1
image_channels, image_height, image_width = image.shape image_channels, image_height, image_width = image.shape
elif image.ndim == 4: elif image.ndim == 4:
image_batch_size, image_channels, image_height, image_width = image.shape image_batch_size, image_channels, image_height, image_width = image.shape
else:
assert False
if image_channels != 3: if image_channels != 3:
raise ValueError("`image` must have 3 channels") raise ValueError("`image` must have 3 channels")
...@@ -659,7 +703,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -659,7 +703,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1, callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: float = 1.0, controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
controlnet_guidance_start: float = 0.0, controlnet_guidance_start: float = 0.0,
controlnet_guidance_end: float = 1.0, controlnet_guidance_end: float = 1.0,
): ):
...@@ -759,7 +803,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -759,7 +803,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
self.check_inputs( self.check_inputs(
prompt, prompt,
image, image,
# mask_image,
controlnet_conditioning_image, controlnet_conditioning_image,
height, height,
width, width,
...@@ -770,6 +813,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -770,6 +813,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
strength, strength,
controlnet_guidance_start, controlnet_guidance_start,
controlnet_guidance_end, controlnet_guidance_end,
controlnet_conditioning_scale,
) )
# 2. Define call parameters # 2. Define call parameters
...@@ -786,6 +830,9 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -786,6 +830,9 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
# 3. Encode input prompt # 3. Encode input prompt
prompt_embeds = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, prompt,
...@@ -797,22 +844,41 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -797,22 +844,41 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
) )
# 4. Prepare mask, image, and controlnet_conditioning_image # 4. Prepare image, and controlnet_conditioning_image
image = prepare_image(image) image = prepare_image(image)
# mask_image = prepare_mask_image(mask_image) # condition image(s)
if isinstance(self.controlnet, ControlNetModel):
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
controlnet_conditioning_image=controlnet_conditioning_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
elif isinstance(self.controlnet, MultiControlNetModel):
controlnet_conditioning_images = []
for image_ in controlnet_conditioning_image:
image_ = prepare_controlnet_conditioning_image(
controlnet_conditioning_image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
controlnet_conditioning_image = prepare_controlnet_conditioning_image( controlnet_conditioning_images.append(image_)
controlnet_conditioning_image,
width,
height,
batch_size * num_images_per_prompt,
num_images_per_prompt,
device,
self.controlnet.dtype,
)
# masked_image = image * (mask_image < 0.5) controlnet_conditioning_image = controlnet_conditioning_images
else:
assert False
# 5. Prepare timesteps # 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
...@@ -830,9 +896,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -830,9 +896,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
generator, generator,
) )
if do_classifier_free_guidance:
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
...@@ -862,15 +925,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -862,15 +925,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
controlnet_cond=controlnet_conditioning_image, controlnet_cond=controlnet_conditioning_image,
conditioning_scale=controlnet_conditioning_scale,
return_dict=False, return_dict=False,
) )
down_block_res_samples = [
down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
]
mid_block_res_sample *= controlnet_conditioning_scale
# predict the noise residual # predict the noise residual
noise_pred = self.unet( noise_pred = self.unet(
latent_model_input, latent_model_input,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment