Unverified commit e8b65bff, authored by hlky and committed by GitHub

refactor StableDiffusionXLControlNetUnion (#10200)

Replaces the `ControlNetUnionInput` / `ControlNetUnionInputProMax` dataclasses with plain `control_image` lists and integer `control_mode` indices.
parent f2d348d9
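For orientation, a minimal before/after sketch of the public API change in this commit. It assumes `pipe` is a loaded `StableDiffusionXLControlNetUnionPipeline` and `controlnet_img` a prepared conditioning image, as in the example docstrings updated below; the "before" import is removed by this commit.

```py
# Before this commit (removed API): conditioning images were wrapped in a
# dataclass keyed by control type.
from diffusers.models.controlnets import ControlNetUnionInput

union_input = ControlNetUnionInput(canny=controlnet_img)
image = pipe(prompt, image=union_input).images[0]

# After this commit: a plain list of images plus integer `control_mode`
# indices (3 = canny/lineart/anime_lineart/mlsd).
image = pipe(prompt, control_image=[controlnet_img], control_mode=[3]).images[0]
```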
......@@ -15,7 +15,7 @@ if is_torch_available():
SparseControlNetModel,
SparseControlNetOutput,
)
from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel
from .controlnet_union import ControlNetUnionModel
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
from .multicontrolnet import MultiControlNetModel
......
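The export change above only removes the input dataclasses; the model class itself stays importable both ways, a quick sketch:

```py
# Still valid after this commit:
from diffusers import ControlNetUnionModel                     # top-level, as in the examples below
from diffusers.models.controlnets import ControlNetUnionModel  # submodule path

# No longer valid: ControlNetUnionInput / ControlNetUnionInputProMax were removed.
```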
......@@ -11,14 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...image_processor import PipelineImageInput
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ..attention_processor import (
......@@ -40,76 +38,6 @@ from ..unets.unet_2d_condition import UNet2DConditionModel
from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
@dataclass
class ControlNetUnionInput:
"""
The image input of [`ControlNetUnionModel`]:
- 0: openpose
- 1: depth
- 2: hed/pidi/scribble/ted
- 3: canny/lineart/anime_lineart/mlsd
- 4: normal
- 5: segment
"""
openpose: Optional[PipelineImageInput] = None
depth: Optional[PipelineImageInput] = None
hed: Optional[PipelineImageInput] = None
canny: Optional[PipelineImageInput] = None
normal: Optional[PipelineImageInput] = None
segment: Optional[PipelineImageInput] = None
def __len__(self) -> int:
return len(vars(self))
def __iter__(self):
return iter(vars(self))
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
setattr(self, key, value)
@dataclass
class ControlNetUnionInputProMax:
"""
The image input of [`ControlNetUnionModel`]:
- 0: openpose
- 1: depth
- 2: hed/pidi/scribble/ted
- 3: canny/lineart/anime_lineart/mlsd
- 4: normal
- 5: segment
- 6: tile
- 7: repaint
"""
openpose: Optional[PipelineImageInput] = None
depth: Optional[PipelineImageInput] = None
hed: Optional[PipelineImageInput] = None
canny: Optional[PipelineImageInput] = None
normal: Optional[PipelineImageInput] = None
segment: Optional[PipelineImageInput] = None
tile: Optional[PipelineImageInput] = None
repaint: Optional[PipelineImageInput] = None
def __len__(self) -> int:
return len(vars(self))
def __iter__(self):
return iter(vars(self))
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
setattr(self, key, value)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -680,8 +608,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
controlnet_cond: List[torch.Tensor],
control_type: torch.Tensor,
control_type_idx: List[int],
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
......@@ -701,11 +630,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
controlnet_cond (`List[torch.Tensor]`):
The conditional input tensors.
control_type (`torch.Tensor`):
A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
type is used.
control_type_idx (`List[int]`):
The indices of the active control types within `control_type`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
......@@ -733,20 +664,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)):
raise ValueError(
"Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if len(controlnet_cond) != self.config.num_control_type:
if isinstance(controlnet_cond, ControlNetUnionInput):
raise ValueError(
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`."
)
elif isinstance(controlnet_cond, ControlNetUnionInputProMax):
raise ValueError(
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`."
)
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
......@@ -830,12 +747,10 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
inputs = []
condition_list = []
for idx, image_type in enumerate(controlnet_cond):
if controlnet_cond[image_type] is None:
continue
condition = self.controlnet_cond_embedding(controlnet_cond[image_type])
for cond, control_idx in zip(controlnet_cond, control_type_idx):
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[idx]
feat_seq = feat_seq + self.task_embedding[control_idx]
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
......
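The hunks above, from the `ControlNetUnionModel` definition, replace the dataclass input with a list of condition tensors plus the indices of the active control types. A minimal caller-side sketch of the new `forward` contract, mirroring the pipeline code added later in this commit; `controlnet`, `sample`, `t`, `prompt_embeds`, `depth_cond`, and the usual SDXL added conditioning (`pooled_embeds`, `add_time_ids`) are assumed to be already prepared:

```py
import torch

control_image = [depth_cond]  # one prepared condition tensor per active control
control_mode = [1]            # 1 = depth, per the removed dataclass docstring

# One-hot vector over all supported control types; inside `forward`, each
# condition is embedded and offset by `task_embedding[control_idx]`.
control_type = [0 for _ in range(controlnet.config.num_control_type)]
for idx in control_mode:
    control_type[idx] = 1
control_type = torch.Tensor(control_type)

down_block_res_samples, mid_block_res_sample = controlnet(
    sample,
    t,
    encoder_hidden_states=prompt_embeds,
    controlnet_cond=control_image,
    control_type=control_type,
    control_type_idx=control_mode,
    conditioning_scale=1.0,
    added_cond_kwargs={"text_embeds": pooled_embeds, "time_ids": add_time_ids},
    return_dict=False,
)
```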
......@@ -40,7 +40,6 @@ from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
)
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
......@@ -82,7 +81,6 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
from diffusers.models.controlnets import ControlNetUnionInputProMax
from diffusers.utils import load_image
import torch
import numpy as np
......@@ -114,11 +112,8 @@ EXAMPLE_DOC_STRING = """
mask_np = np.array(mask)
controlnet_img_np[mask_np > 0] = 0
controlnet_img = Image.fromarray(controlnet_img_np)
union_input = ControlNetUnionInputProMax(
repaint=controlnet_img,
)
# generate image
image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0]
image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0]
image.save("inpaint.png")
```
"""
......@@ -1130,7 +1125,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
......@@ -1158,6 +1153,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
guidance_rescale: float = 0.0,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
......@@ -1345,20 +1341,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
raise ValueError(
"Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if len(control_image_list) != controlnet.config.num_control_type:
if isinstance(control_image_list, ControlNetUnionInput):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
)
elif isinstance(control_image_list, ControlNetUnionInputProMax):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
......@@ -1375,36 +1357,44 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
if not isinstance(control_image, list):
control_image = [control_image]
if not isinstance(control_mode, list):
control_mode = [control_mode]
if len(control_image) != len(control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
num_control_type = controlnet.config.num_control_type
# 1. Check inputs
control_type = []
for image_type in control_image_list:
if control_image_list[image_type]:
self.check_inputs(
prompt,
prompt_2,
control_image_list[image_type],
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
control_type.append(1)
else:
control_type.append(0)
control_type = [0 for _ in range(num_control_type)]
for _image, control_idx in zip(control_image, control_mode):
control_type[control_idx] = 1
self.check_inputs(
prompt,
prompt_2,
_image,
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
control_type = torch.Tensor(control_type)
......@@ -1499,23 +1489,21 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
init_image = init_image.to(dtype=torch.float32)
# 5.2 Prepare control images
for image_type in control_image_list:
if control_image_list[image_type]:
control_image = self.prepare_control_image(
image=control_image_list[image_type],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
control_image_list[image_type] = control_image
for idx, _ in enumerate(control_image):
control_image[idx] = self.prepare_control_image(
image=control_image[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image[idx].shape[-2:]
# 5.3 Prepare mask
mask = self.mask_processor.preprocess(
......@@ -1589,6 +1577,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
original_size = original_size or (height, width)
target_size = target_size or (height, width)
for _image in control_image:
if isinstance(_image, torch.Tensor):
original_size = original_size or _image.shape[-2:]
# 10. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
......@@ -1693,8 +1684,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image_list,
controlnet_cond=control_image,
control_type=control_type,
control_type_idx=control_mode,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
......
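With the dataclasses gone, the control-type indices are conveyed only by `control_mode` integers; the mapping below is a reference sketch reconstructed from the removed `ControlNetUnionInput` / `ControlNetUnionInputProMax` docstrings (the inpaint example above uses `control_mode=[7]` for repaint):

```py
# control_mode indices, per the removed dataclass docstrings. Indices 0-5 match
# ControlNetUnionInput; 6 and 7 exist only for the ProMax checkpoints.
CONTROL_MODE_INDEX = {
    "openpose": 0,
    "depth": 1,
    "hed/pidi/scribble/ted": 2,
    "canny/lineart/anime_lineart/mlsd": 3,
    "normal": 4,
    "segment": 5,
    "tile": 6,     # ProMax only
    "repaint": 7,  # ProMax only
}
```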
......@@ -43,7 +43,6 @@ from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
)
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
......@@ -70,7 +69,6 @@ EXAMPLE_DOC_STRING = """
>>> # !pip install controlnet_aux
>>> from controlnet_aux import LineartAnimeDetector
>>> from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel, AutoencoderKL
>>> from diffusers.models.controlnets import ControlNetUnionInput
>>> from diffusers.utils import load_image
>>> import torch
......@@ -89,17 +87,14 @@ EXAMPLE_DOC_STRING = """
... controlnet=controlnet,
... vae=vae,
... torch_dtype=torch.float16,
... variant="fp16",
... )
>>> pipe.enable_model_cpu_offload()
>>> # prepare image
>>> processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
>>> controlnet_img = processor(image, output_type="pil")
>>> # set ControlNetUnion input
>>> union_input = ControlNetUnionInput(
... canny=controlnet_img,
... )
>>> # generate image
>>> image = pipe(prompt, image=union_input).images[0]
>>> image = pipe(prompt, control_image=[controlnet_img], control_mode=[3], height=1024, width=1024).images[0]
```
"""
......@@ -791,26 +786,6 @@ class StableDiffusionXLControlNetUnionPipeline(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def check_input(
self,
image: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
):
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if not isinstance(image, (ControlNetUnionInput, ControlNetUnionInputProMax)):
raise ValueError(
"Expected type of `image` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if len(image) != controlnet.config.num_control_type:
if isinstance(image, ControlNetUnionInput):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInputProMax`."
)
elif isinstance(image, ControlNetUnionInputProMax):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInput`."
)
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
def prepare_image(
self,
......@@ -970,7 +945,7 @@ class StableDiffusionXLControlNetUnionPipeline(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
......@@ -997,6 +972,7 @@ class StableDiffusionXLControlNetUnionPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
......@@ -1018,10 +994,7 @@ class StableDiffusionXLControlNetUnionPipeline(
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders.
image (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`,
`List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
control_image (`PipelineImageInput`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
......@@ -1168,38 +1141,45 @@ class StableDiffusionXLControlNetUnionPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
self.check_input(image)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
if not isinstance(control_image, list):
control_image = [control_image]
if not isinstance(control_mode, list):
control_mode = [control_mode]
if len(control_image) != len(control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
num_control_type = controlnet.config.num_control_type
# 1. Check inputs
control_type = [0 for _ in range(num_control_type)]
# 1. Check inputs. Raise error if not correct
control_type = []
for image_type in image:
if image[image_type]:
self.check_inputs(
prompt,
prompt_2,
image[image_type],
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
control_type.append(1)
else:
control_type.append(0)
for _image, control_idx in zip(control_image, control_mode):
control_type[control_idx] = 1
self.check_inputs(
prompt,
prompt_2,
_image,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
control_type = torch.Tensor(control_type)
......@@ -1258,20 +1238,19 @@ class StableDiffusionXLControlNetUnionPipeline(
)
# 4. Prepare image
for image_type in image:
if image[image_type]:
image[image_type] = self.prepare_image(
image=image[image_type],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = image[image_type].shape[-2:]
for idx, _ in enumerate(control_image):
control_image[idx] = self.prepare_image(
image=control_image[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image[idx].shape[-2:]
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
......@@ -1312,11 +1291,11 @@ class StableDiffusionXLControlNetUnionPipeline(
)
# 7.2 Prepare added time ids & embeddings
for image_type in image:
if isinstance(image[image_type], torch.Tensor):
original_size = original_size or image[image_type].shape[-2:]
original_size = original_size or (height, width)
target_size = target_size or (height, width)
for _image in control_image:
if isinstance(_image, torch.Tensor):
original_size = original_size or _image.shape[-2:]
add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
......@@ -1424,8 +1403,9 @@ class StableDiffusionXLControlNetUnionPipeline(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
controlnet_cond=control_image,
control_type=control_type,
control_type_idx=control_mode,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
......@@ -1478,7 +1458,6 @@ class StableDiffusionXLControlNetUnionPipeline(
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
image = callback_outputs.pop("image", image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
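Because the validation and conditioning loops above iterate over paired `control_image` / `control_mode` lists, multiple conditions in one call appear to be supported. A hedged sketch, assuming `pose_img` and `depth_img` are prepared conditioning images for the pipeline loaded in the example above:

```py
# Two controls in one call: one image per entry, with matching mode indices
# (0 = openpose, 1 = depth). Mismatched lengths raise a ValueError.
image = pipe(
    prompt,
    control_image=[pose_img, depth_img],
    control_mode=[0, 1],
    height=1024,
    width=1024,
).images[0]
```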
......@@ -43,7 +43,6 @@ from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
)
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
......@@ -74,7 +73,6 @@ EXAMPLE_DOC_STRING = """
ControlNetUnionModel,
AutoencoderKL,
)
from diffusers.models.controlnets import ControlNetUnionInputProMax
from diffusers.utils import load_image
import torch
from PIL import Image
......@@ -95,6 +93,7 @@ EXAMPLE_DOC_STRING = """
controlnet=controlnet,
vae=vae,
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
# `enable_model_cpu_offload` is not recommended due to multiple generations
height = image.height
......@@ -132,14 +131,12 @@ EXAMPLE_DOC_STRING = """
# set ControlNetUnion input
result_images = []
for sub_img, crops_coords in zip(images, crops_coords_list):
union_input = ControlNetUnionInputProMax(
tile=sub_img,
)
new_width, new_height = W, H
out = pipe(
prompt=[prompt] * 1,
image=sub_img,
control_image_list=union_input,
control_image=[sub_img],
control_mode=[6],
width=new_width,
height=new_height,
num_inference_steps=30,
......@@ -1065,7 +1062,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
......@@ -1090,6 +1087,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
......@@ -1119,10 +1117,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`, if passing latents directly, it will not be encoded again.
control_image_list (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`,
`List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`)::
control_image (`PipelineImageInput`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height
......@@ -1291,53 +1286,47 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
raise ValueError(
"Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if len(control_image_list) != controlnet.config.num_control_type:
if isinstance(control_image_list, ControlNetUnionInput):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
)
elif isinstance(control_image_list, ControlNetUnionInputProMax):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
# 1. Check inputs. Raise error if not correct
control_type = []
for image_type in control_image_list:
if control_image_list[image_type]:
self.check_inputs(
prompt,
prompt_2,
control_image_list[image_type],
strength,
num_inference_steps,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
control_type.append(1)
else:
control_type.append(0)
if not isinstance(control_image, list):
control_image = [control_image]
if not isinstance(control_mode, list):
control_mode = [control_mode]
if len(control_image) != len(control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
num_control_type = controlnet.config.num_control_type
# 1. Check inputs
control_type = [0 for _ in range(num_control_type)]
for _image, control_idx in zip(control_image, control_mode):
control_type[control_idx] = 1
self.check_inputs(
prompt,
prompt_2,
_image,
strength,
num_inference_steps,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
control_type = torch.Tensor(control_type)
......@@ -1397,21 +1386,19 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
# 4. Prepare image and controlnet_conditioning_image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
for image_type in control_image_list:
if control_image_list[image_type]:
control_image = self.prepare_control_image(
image=control_image_list[image_type],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
control_image_list[image_type] = control_image
for idx, _ in enumerate(control_image):
control_image[idx] = self.prepare_control_image(
image=control_image[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image[idx].shape[-2:]
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
......@@ -1444,10 +1431,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
)
# 7.2 Prepare added time ids & embeddings
for image_type in control_image_list:
if isinstance(control_image_list[image_type], torch.Tensor):
original_size = original_size or control_image_list[image_type].shape[-2:]
original_size = original_size or (height, width)
target_size = target_size or (height, width)
for _image in control_image:
if isinstance(_image, torch.Tensor):
original_size = original_size or _image.shape[-2:]
if negative_original_size is None:
negative_original_size = original_size
......@@ -1531,8 +1519,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image_list,
controlnet_cond=control_image,
control_type=control_type,
control_type_idx=control_mode,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
......