Unverified Commit aa14f090 authored by Álvaro Somoza, committed by GitHub

[ControlnetUnion] Propagate #11888 to img2img (#11929)

img2img fixes
parent c5d6e0b5
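
For context, a minimal usage sketch of the img2img pipeline after this change. The model ids, file names, and control-mode indices below are illustrative placeholders and not part of this commit; the meaning of each control-mode index depends on the specific checkpoint's model card.

import torch
from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionImg2ImgPipeline
from diffusers.utils import load_image

# Two ControlNetUnion models; checkpoint ids are placeholders.
controlnet_a = ControlNetUnionModel.from_pretrained("org/controlnet-union-a", torch_dtype=torch.float16)
controlnet_b = ControlNetUnionModel.from_pretrained("org/controlnet-union-b", torch_dtype=torch.float16)

pipe = StableDiffusionXLControlNetUnionImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=[controlnet_a, controlnet_b],  # a list is now wrapped into MultiControlNetUnionModel
    torch_dtype=torch.float16,
).to("cuda")

init_image = load_image("init.png")
depth_image = load_image("depth.png")
canny_image = load_image("canny.png")

# With multiple ControlNets, control_image and control_mode are nested: one inner list per ControlNet.
result = pipe(
    prompt="a photo of a living room",
    image=init_image,
    control_image=[[depth_image], [canny_image]],
    control_mode=[[1], [3]],  # control type indices; assumed values for illustration
    strength=0.8,
).images[0]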
@@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import PIL.Image
 import torch
-import torch.nn.functional as F
 from transformers import (
     CLIPImageProcessor,
     CLIPTextModel,
@@ -38,7 +37,13 @@ from ...loaders import (
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
+from ...models import (
+    AutoencoderKL,
+    ControlNetUnionModel,
+    ImageProjection,
+    MultiControlNetUnionModel,
+    UNet2DConditionModel,
+)
 from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
@@ -262,7 +267,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         tokenizer: CLIPTokenizer,
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetUnionModel,
+        controlnet: Union[
+            ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
+        ],
         scheduler: KarrasDiffusionSchedulers,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
@@ -272,8 +279,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
     ):
         super().__init__()

-        if not isinstance(controlnet, ControlNetUnionModel):
-            raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetUnionModel(controlnet)

         self.register_modules(
             vae=vae,
@@ -649,6 +656,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        control_mode=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if strength < 0 or strength > 1:
@@ -722,28 +730,44 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, MultiControlNetUnionModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The conditionings will be fixed across the prompts."
+                )
+
         # Check `image`
-        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
-            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
-        )
-        if (
-            isinstance(self.controlnet, ControlNetModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetModel)
-        ):
-            self.check_image(image, prompt, prompt_embeds)
-        elif (
-            isinstance(self.controlnet, ControlNetUnionModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
-        ):
-            self.check_image(image, prompt, prompt_embeds)
-        else:
-            assert False
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+            elif not all(isinstance(i, list) for i in image):
+                raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for images_ in image:
+                for image_ in images_:
+                    self.check_image(image_, prompt, prompt_embeds)

         if not isinstance(control_guidance_start, (tuple, list)):
             control_guidance_start = [control_guidance_start]

+        if isinstance(controlnet, MultiControlNetUnionModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
         if not isinstance(control_guidance_end, (tuple, list)):
             control_guidance_end = [control_guidance_end]
@@ -762,6 +786,15 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

+        # Check `control_mode`
+        if isinstance(controlnet, ControlNetUnionModel):
+            if max(control_mode) >= controlnet.config.num_control_type:
+                raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
+                if max(_control_mode) >= _controlnet.config.num_control_type:
+                    raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
+
         if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
             raise ValueError(
                 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -1049,7 +1082,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
-        control_image: PipelineImageInput = None,
+        control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         strength: float = 0.8,
@@ -1074,7 +1107,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
-        control_mode: Optional[Union[int, List[int]]] = None,
+        control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Tuple[int, int] = None,
@@ -1104,13 +1137,13 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                 `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The initial image will be used as the starting point for the image generation process. Can also accept
                 image latents as `image`, if passing latents directly, it will not be encoded again.
-            control_image (`PipelineImageInput`):
-                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
-                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
-                be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
-                and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
-                init, images must be passed as a list such that each element of the list can be correctly batched for
-                input to a single controlnet.
+            control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
             height (`int`, *optional*, defaults to the size of control_image):
                 The height in pixels of the generated image. Anything below 512 pixels won't work well for
                 [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -1184,16 +1217,21 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
-                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
-                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
-                corresponding scale as a list.
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+                the corresponding scale as a list.
             guess_mode (`bool`, *optional*, defaults to `False`):
                 In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
                 you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
             control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
-                The percentage of total steps at which the controlnet starts applying.
+                The percentage of total steps at which the ControlNet starts applying.
             control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
-                The percentage of total steps at which the controlnet stops applying.
+                The percentage of total steps at which the ControlNet stops applying.
+            control_mode (`int` or `List[int]` or `List[List[int]]`, *optional*):
+                The control condition types for the ControlNet. See the ControlNet's model card for information on the
+                available control modes. If multiple ControlNets are specified in `init`, `control_mode` should be a
+                list where each ControlNet has its corresponding control mode list. Should reflect the order of
+                conditions in `control_image`.
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                 `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1273,12 +1311,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

-        # align format for control guidance
-        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
-            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
-        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
-            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
         if not isinstance(control_image, list):
             control_image = [control_image]
         else:
@@ -1287,37 +1319,56 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         if not isinstance(control_mode, list):
             control_mode = [control_mode]

-        if len(control_image) != len(control_mode):
-            raise ValueError("Expected len(control_image) == len(control_type)")
+        if isinstance(controlnet, MultiControlNetUnionModel):
+            control_image = [[item] for item in control_image]
+            control_mode = [[item] for item in control_mode]

-        num_control_type = controlnet.config.num_control_type
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        if isinstance(controlnet_conditioning_scale, float):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult

         # 1. Check inputs
-        control_type = [0 for _ in range(num_control_type)]
-        for _image, control_idx in zip(control_image, control_mode):
-            control_type[control_idx] = 1
-            self.check_inputs(
-                prompt,
-                prompt_2,
-                _image,
-                strength,
-                num_inference_steps,
-                callback_steps,
-                negative_prompt,
-                negative_prompt_2,
-                prompt_embeds,
-                negative_prompt_embeds,
-                pooled_prompt_embeds,
-                negative_pooled_prompt_embeds,
-                ip_adapter_image,
-                ip_adapter_image_embeds,
-                controlnet_conditioning_scale,
-                control_guidance_start,
-                control_guidance_end,
-                callback_on_step_end_tensor_inputs,
-            )
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            control_image,
+            strength,
+            num_inference_steps,
+            callback_steps,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            controlnet_conditioning_scale,
+            control_guidance_start,
+            control_guidance_end,
+            control_mode,
+            callback_on_step_end_tensor_inputs,
+        )

-        control_type = torch.Tensor(control_type)
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_type = [
+                torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
+                for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
+            ]

         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
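
The scatter_ call above turns the requested control_mode indices into a one-hot vector over the union model's control types. A small illustration; the num_control_type value and mode indices are assumed for the example and would come from the checkpoint's config in the pipeline.

import torch

num_control_type = 6   # assumed for illustration; read from controlnet.config.num_control_type in the pipeline
control_mode = [1, 3]  # two requested condition types; index meanings depend on the model card

control_type = torch.zeros(num_control_type).scatter_(0, torch.tensor(control_mode), 1)
print(control_type)    # tensor([0., 1., 0., 1., 0., 0.])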
@@ -1334,7 +1385,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         device = self._execution_device

-        global_pool_conditions = controlnet.config.global_pool_conditions
+        global_pool_conditions = (
+            controlnet.config.global_pool_conditions
+            if isinstance(controlnet, ControlNetUnionModel)
+            else controlnet.nets[0].config.global_pool_conditions
+        )
         guess_mode = guess_mode or global_pool_conditions

         # 3.1. Encode input prompt
@@ -1372,22 +1427,55 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
             self.do_classifier_free_guidance,
         )

-        # 4. Prepare image and controlnet_conditioning_image
+        # 4.1 Prepare image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

-        for idx, _ in enumerate(control_image):
-            control_image[idx] = self.prepare_control_image(
-                image=control_image[idx],
-                width=width,
-                height=height,
-                batch_size=batch_size * num_images_per_prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                device=device,
-                dtype=controlnet.dtype,
-                do_classifier_free_guidance=self.do_classifier_free_guidance,
-                guess_mode=guess_mode,
-            )
-            height, width = control_image[idx].shape[-2:]
+        # 4.2 Prepare control images
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_images = []
+
+            for image_ in control_image:
+                image_ = self.prepare_control_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                control_images.append(image_)
+
+            control_image = control_images
+            height, width = control_image[0].shape[-2:]
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_images = []
+
+            for control_image_ in control_image:
+                images = []
+
+                for image_ in control_image_:
+                    image_ = self.prepare_control_image(
+                        image=image_,
+                        width=width,
+                        height=height,
+                        batch_size=batch_size * num_images_per_prompt,
+                        num_images_per_prompt=num_images_per_prompt,
+                        device=device,
+                        dtype=controlnet.dtype,
+                        do_classifier_free_guidance=self.do_classifier_free_guidance,
+                        guess_mode=guess_mode,
+                    )
+
+                    images.append(image_)
+
+                control_images.append(images)
+
+            control_image = control_images
+            height, width = control_image[0][0].shape[-2:]

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
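
After preparation, control_image for a MultiControlNetUnionModel is a list with one inner list of conditioning tensors per ControlNet, and the output resolution is read from the first tensor. A sketch of the resulting structure with dummy tensors (shapes assumed for illustration):

import torch

# One inner list of conditioning tensors per ControlNet; batch dimension doubled for CFG.
prepared = [
    [torch.randn(2, 3, 1024, 1024)],  # conditions for ControlNet 0
    [torch.randn(2, 3, 1024, 1024)],  # conditions for ControlNet 1
]
height, width = prepared[0][0].shape[-2:]
print(height, width)  # 1024 1024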
@@ -1414,10 +1502,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         # 7.1 Create tensor stating which controlnets to keep
         controlnet_keep = []
         for i in range(len(timesteps)):
-            controlnet_keep.append(
-                1.0
-                - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
-            )
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps)

         # 7.2 Prepare added time ids & embeddings
         original_size = original_size or (height, width)
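
With per-ControlNet start/end values, controlnet_keep now stores one keep flag per ControlNet for every timestep, so each net can run over its own guidance window. A quick numeric sketch with assumed step count and windows:

timesteps = list(range(10))          # assume 10 denoising steps
control_guidance_start = [0.0, 0.5]  # net 0 active from the start, net 1 from halfway
control_guidance_end = [1.0, 1.0]

controlnet_keep = []
for i in range(len(timesteps)):
    keeps = [
        1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps)

print(controlnet_keep[0])  # [1.0, 0.0] -> only the first ControlNet applies early on
print(controlnet_keep[9])  # [1.0, 1.0] -> both apply in the final steps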
@@ -1460,12 +1549,25 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         prompt_embeds = prompt_embeds.to(device)
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device)
-        control_type = (
-            control_type.reshape(1, -1)
-            .to(device, dtype=prompt_embeds.dtype)
-            .repeat(batch_size * num_images_per_prompt * 2, 1)
-        )
+
+        control_type_repeat_factor = (
+            batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
+        )

+        if isinstance(controlnet, ControlNetUnionModel):
+            control_type = (
+                control_type.reshape(1, -1)
+                .to(self._execution_device, dtype=prompt_embeds.dtype)
+                .repeat(control_type_repeat_factor, 1)
+            )
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_type = [
+                _control_type.reshape(1, -1)
+                .to(self._execution_device, dtype=prompt_embeds.dtype)
+                .repeat(control_type_repeat_factor, 1)
+                for _control_type in control_type
+            ]
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
...