Unverified Commit 5105b5a8 authored by Daniel Regado, committed by GitHub

`MultiControlNetUnionModel` on SDXL (#10747)



* SDXL with MultiControlNetUnionModel



---------
Co-authored-by: hlky <hlky@hlky.ac>
parent ca6330dc
......@@ -51,6 +51,7 @@ if is_torch_available():
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
......@@ -122,6 +123,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel,
HunyuanDiT2DMultiControlNetModel,
MultiControlNetModel,
MultiControlNetUnionModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SparseControlNetModel,
......
......@@ -18,6 +18,7 @@ if is_torch_available():
from .controlnet_union import ControlNetUnionModel
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
from .multicontrolnet import MultiControlNetModel
from .multicontrolnet_union import MultiControlNetUnionModel
if is_flax_available():
from .controlnet_flax import FlaxControlNetModel
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from ...models.controlnets.controlnet import ControlNetOutput
from ...models.controlnets.controlnet_union import ControlNetUnionModel
from ...models.modeling_utils import ModelMixin
from ...utils import logging
logger = logging.get_logger(__name__)
class MultiControlNetUnionModel(ModelMixin):
r"""
Multiple `ControlNetUnionModel` wrapper class for Multi-ControlNet-Union.
This module is a wrapper for multiple instances of the `ControlNetUnionModel`. The `forward()` API is designed to
be compatible with `ControlNetUnionModel`.
Args:
controlnets (`List[ControlNetUnionModel]`):
Provides additional conditioning to the UNet during the denoising process. You must pass multiple
`ControlNetUnionModel` instances as a list.
"""
def __init__(self, controlnets: Union[List[ControlNetUnionModel], Tuple[ControlNetUnionModel]]):
super().__init__()
self.nets = nn.ModuleList(controlnets)
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: List[torch.Tensor],
control_type: List[torch.Tensor],
control_type_idx: List[List[int]],
conditioning_scale: List[float],
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
):
down_samples, mid_sample = controlnet(
sample=sample,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=image,
control_type=ctype,
control_type_idx=ctype_idx,
conditioning_scale=scale,
class_labels=class_labels,
timestep_cond=timestep_cond,
attention_mask=attention_mask,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
guess_mode=guess_mode,
return_dict=return_dict,
)
# merge samples: sum the down-block residuals element-wise across nets and accumulate the mid-block residual
if i == 0:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
return down_block_res_samples, mid_block_res_sample
# Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained with ControlNet->ControlNetUnion
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = None,
safe_serialization: bool = True,
variant: Optional[str] = None,
):
"""
Save a model and its configuration file to a directory so that it can be re-loaded using the
[`~models.controlnets.multicontrolnet_union.MultiControlNetUnionModel.from_pretrained`] class method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training (such as
on TPUs) when this function needs to be called on all processes; in that case, set `is_main_process=True`
only on the main process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training (such as on TPUs)
when `torch.save` needs to be replaced with another method. Can be configured with the environment
variable `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
for idx, controlnet in enumerate(self.nets):
suffix = "" if idx == 0 else f"_{idx}"
controlnet.save_pretrained(
save_directory + suffix,
is_main_process=is_main_process,
save_function=save_function,
safe_serialization=safe_serialization,
variant=variant,
)
@classmethod
# Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained with ControlNet->ControlNetUnion
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a pretrained MultiControlNetUnion model from multiple pre-trained controlnet models.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
the model, you should first set it back in training mode with `model.train()`.
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_path (`os.PathLike`):
A path to a *directory* containing model weights saved using
[`~models.controlnets.multicontrolnet_union.MultiControlNetUnionModel.save_pretrained`], e.g.,
`./my_model_directory/controlnet`.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be specified for each
parameter/buffer name; once a given module name is included, all of its submodules will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
max_memory (`Dict`, *optional*):
A dictionary mapping device identifiers to their maximum memory. Defaults to the maximum memory available
for each GPU and the available CPU RAM if unset.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
"""
idx = 0
controlnets = []
# load controlnet and append to list until no controlnet directory exists anymore
# first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_pretrained`
# second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
model_path_to_load = pretrained_model_path
while os.path.isdir(model_path_to_load):
controlnet = ControlNetUnionModel.from_pretrained(model_path_to_load, **kwargs)
controlnets.append(controlnet)
idx += 1
model_path_to_load = pretrained_model_path + f"_{idx}"
logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
if len(controlnets) == 0:
raise ValueError(
f"No ControlNetUnions found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
)
return cls(controlnets)
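
For orientation, here is a minimal sketch (not part of this diff) of how the wrapper above composes and round-trips several `ControlNetUnionModel` instances; the repository ids are placeholders:

```python
# Minimal sketch, assuming two already-trained union ControlNets (ids are placeholders).
from diffusers import ControlNetUnionModel
from diffusers.models import MultiControlNetUnionModel

controlnet_a = ControlNetUnionModel.from_pretrained("org/controlnet-union-a")  # placeholder id
controlnet_b = ControlNetUnionModel.from_pretrained("org/controlnet-union-b")  # placeholder id
multi = MultiControlNetUnionModel([controlnet_a, controlnet_b])

# save_pretrained writes one sub-directory per net: <dir>, <dir>_1, <dir>_2, ...
multi.save_pretrained("./my_model_directory/controlnet")
# from_pretrained walks the same naming scheme until no further directory is found.
multi = MultiControlNetUnionModel.from_pretrained("./my_model_directory/controlnet")
```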
......@@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
......@@ -38,7 +37,13 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
from ...models import (
AutoencoderKL,
ControlNetUnionModel,
ImageProjection,
MultiControlNetUnionModel,
UNet2DConditionModel,
)
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
......@@ -244,7 +249,9 @@ class StableDiffusionXLControlNetUnionPipeline(
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
controlnet: ControlNetUnionModel,
controlnet: Union[
ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
],
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
......@@ -253,8 +260,8 @@ class StableDiffusionXLControlNetUnionPipeline(
):
super().__init__()
if not isinstance(controlnet, ControlNetUnionModel):
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetUnionModel(controlnet)
self.register_modules(
vae=vae,
......@@ -664,6 +671,7 @@ class StableDiffusionXLControlNetUnionPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
control_mode=None,
callback_on_step_end_tensor_inputs=None,
):
if callback_on_step_end_tensor_inputs is not None and not all(
......@@ -721,46 +729,102 @@ class StableDiffusionXLControlNetUnionPipeline(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetUnionModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)
# Check `image`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
)
if (
isinstance(self.controlnet, ControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetModel)
):
self.check_image(image, prompt, prompt_embeds)
elif (
isinstance(self.controlnet, ControlNetUnionModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
):
self.check_image(image, prompt, prompt_embeds)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if isinstance(controlnet, ControlNetUnionModel):
for image_ in image:
self.check_image(image_, prompt, prompt_embeds)
elif isinstance(controlnet, MultiControlNetUnionModel):
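# for multiple union ControlNets, `image` is expected to be a nested list: one list of condition images per ControlNet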
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")
elif not all(isinstance(i, list) for i in image):
raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
elif len(image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
)
for images_ in image:
for image_ in images_:
self.check_image(image_, prompt, prompt_embeds)
else:
assert False
if not isinstance(control_guidance_start, (tuple, list)):
control_guidance_start = [control_guidance_start]
if not isinstance(control_guidance_end, (tuple, list)):
control_guidance_end = [control_guidance_end]
# Check `controlnet_conditioning_scale`
# TODO Update for https://github.com/huggingface/diffusers/pull/10723
if isinstance(controlnet, ControlNetUnionModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(controlnet, MultiControlNetUnionModel):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
raise ValueError(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False
if len(control_guidance_start) != len(control_guidance_end):
raise ValueError(
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
)
if isinstance(controlnet, MultiControlNetUnionModel):
if len(control_guidance_start) != len(self.controlnet.nets):
raise ValueError(
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
)
for start, end in zip(control_guidance_start, control_guidance_end):
if start >= end:
raise ValueError(
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
f"control_guidance_start: {start} cannot be larger or equal to control guidance end: {end}."
)
if start < 0.0:
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
raise ValueError(f"control_guidance_start: {start} can't be smaller than 0.")
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
raise ValueError(f"control_guidance_end: {end} can't be larger than 1.0.")
# Check `control_mode`
if isinstance(controlnet, ControlNetUnionModel):
if max(control_mode) >= controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
elif isinstance(controlnet, MultiControlNetUnionModel):
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
if max(_control_mode) >= _controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
else:
assert False
# Equal number of `image` and `control_mode` elements
if isinstance(controlnet, ControlNetUnionModel):
if len(image) != len(control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
elif isinstance(controlnet, MultiControlNetUnionModel):
if not all(isinstance(i, list) for i in control_mode):
raise ValueError(
"For multiple controlnets: elements of control_mode must be lists representing conditioning mode."
)
elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
else:
assert False
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
......@@ -936,7 +1000,7 @@ class StableDiffusionXLControlNetUnionPipeline(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
control_image: PipelineImageInput = None,
control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
......@@ -963,7 +1027,7 @@ class StableDiffusionXLControlNetUnionPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
......@@ -985,7 +1049,7 @@ class StableDiffusionXLControlNetUnionPipeline(
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders.
control_image (`PipelineImageInput`):
control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
......@@ -1077,6 +1141,11 @@ class StableDiffusionXLControlNetUnionPipeline(
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
control_mode (`int` or `List[int]` or `List[List[int]]`, *optional*):
The control condition types for the ControlNet. See the ControlNet's model card for information on the
available control modes. If multiple ControlNets are specified in `__init__`, `control_mode` should be a
list in which each ControlNet has its corresponding control mode list. It should reflect the order of
conditions in `control_image`.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
......@@ -1137,6 +1206,12 @@ class StableDiffusionXLControlNetUnionPipeline(
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
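# when both are scalars, broadcast them to one start/end value per ControlNet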
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
if not isinstance(control_image, list):
control_image = [control_image]
......@@ -1146,35 +1221,36 @@ class StableDiffusionXLControlNetUnionPipeline(
if not isinstance(control_mode, list):
control_mode = [control_mode]
if len(control_image) != len(control_mode):
raise ValueError("Expected len(control_image) == len(control_type)")
num_control_type = controlnet.config.num_control_type
if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
# 1. Check inputs
control_type = [0 for _ in range(num_control_type)]
# 1. Check inputs. Raise error if not correct
for _image, control_idx in zip(control_image, control_mode):
control_type[control_idx] = 1
self.check_inputs(
prompt,
prompt_2,
_image,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
self.check_inputs(
prompt,
prompt_2,
control_image,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
control_mode,
callback_on_step_end_tensor_inputs,
)
control_type = torch.Tensor(control_type)
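# build a one-hot vector over each net's supported condition types: indices listed in `control_mode` are set to 1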
if isinstance(controlnet, ControlNetUnionModel):
control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
elif isinstance(controlnet, MultiControlNetUnionModel):
control_type = [
torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
]
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
......@@ -1192,7 +1268,11 @@ class StableDiffusionXLControlNetUnionPipeline(
device = self._execution_device
global_pool_conditions = controlnet.config.global_pool_conditions
global_pool_conditions = (
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetUnionModel)
else controlnet.nets[0].config.global_pool_conditions
)
guess_mode = guess_mode or global_pool_conditions
# 3.1 Encode input prompt
......@@ -1231,19 +1311,54 @@ class StableDiffusionXLControlNetUnionPipeline(
)
# 4. Prepare image
for idx, _ in enumerate(control_image):
control_image[idx] = self.prepare_image(
image=control_image[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image[idx].shape[-2:]
if isinstance(controlnet, ControlNetUnionModel):
control_images = []
for image_ in control_image:
image_ = self.prepare_image(
image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
control_images.append(image_)
control_image = control_images
height, width = control_image[0].shape[-2:]
elif isinstance(controlnet, MultiControlNetUnionModel):
control_images = []
for control_image_ in control_image:
images = []
for image_ in control_image_:
image_ = self.prepare_image(
image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
images.append(image_)
control_images.append(images)
control_image = control_images
height, width = control_image[0][0].shape[-2:]
else:
assert False
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
......@@ -1278,10 +1393,11 @@ class StableDiffusionXLControlNetUnionPipeline(
# 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
controlnet_keep.append(
1.0
- float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
)
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetUnionModel) else keeps)
# 7.2 Prepare added time ids & embeddings
original_size = original_size or (height, width)
......@@ -1346,11 +1462,20 @@ class StableDiffusionXLControlNetUnionPipeline(
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
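# broadcast the one-hot control_type across the batch; the factor of 2 matches the duplicated negative/positive batch used for classifier-free guidance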
control_type = (
control_type.reshape(1, -1)
.to(device, dtype=prompt_embeds.dtype)
.repeat(batch_size * num_images_per_prompt * 2, 1)
)
if isinstance(controlnet, ControlNetUnionModel):
control_type = (
control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
.repeat(batch_size * num_images_per_prompt * 2, 1)
)
if isinstance(controlnet, MultiControlNetUnionModel):
control_type = [
_control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
.repeat(batch_size * num_images_per_prompt * 2, 1)
for _control_type in control_type
]
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
......
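
For the pipeline side, a hedged end-to-end sketch of the multi-union path this PR enables (model ids, conditioning images, and mode indices are placeholders; each checkpoint documents its own condition-type indices):

```python
# Hedged sketch: two union ControlNets guiding one SDXL generation (all ids and paths are placeholders).
import torch
from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionPipeline
from diffusers.utils import load_image

controlnet_a = ControlNetUnionModel.from_pretrained("org/controlnet-union-sdxl-a", torch_dtype=torch.float16)
controlnet_b = ControlNetUnionModel.from_pretrained("org/controlnet-union-sdxl-b", torch_dtype=torch.float16)

pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=[controlnet_a, controlnet_b],  # the list is wrapped into MultiControlNetUnionModel in __init__
    torch_dtype=torch.float16,
).to("cuda")

pose = load_image("pose.png")    # placeholder conditioning images
depth = load_image("depth.png")

image = pipe(
    prompt="an astronaut riding a horse on the moon",
    control_image=[[pose], [depth]],           # one list of condition images per ControlNet
    control_mode=[[0], [1]],                   # matching condition-type indices per ControlNet
    controlnet_conditioning_scale=[0.7, 0.5],  # one scale per ControlNet
).images[0]
```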