Unverified Commit e8b65bff authored by hlky, committed by GitHub

refactor StableDiffusionXLControlNetUnion (#10200)

mode
parent f2d348d9
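
In short: the `ControlNetUnionInput`/`ControlNetUnionInputProMax` dataclasses are removed, and the SDXL ControlNet-Union pipelines instead take a plain `control_image` list plus a parallel `control_mode` list of task indices. A minimal before/after sketch (the checkpoint names, prompt, and image path are illustrative, not taken from this commit):

```py
import torch
from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionPipeline
from diffusers.utils import load_image

# illustrative checkpoints, not part of this commit
controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
canny_img = load_image("canny.png")  # any preprocessed control image

# Before this commit: a fixed-size dataclass keyed by control type.
# from diffusers.models.controlnets import ControlNetUnionInput
# image = pipe("a cat", image=ControlNetUnionInput(canny=canny_img)).images[0]

# After this commit: a list of images plus matching control-mode indices
# (3 = canny/lineart/anime_lineart/mlsd).
image = pipe("a cat", control_image=[canny_img], control_mode=[3]).images[0]
```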
@@ -15,7 +15,7 @@ if is_torch_available():
         SparseControlNetModel,
         SparseControlNetOutput,
     )
-    from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel
+    from .controlnet_union import ControlNetUnionModel
     from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
     from .multicontrolnet import MultiControlNetModel
...
@@ -11,14 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...image_processor import PipelineImageInput
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import logging
 from ..attention_processor import (
@@ -40,76 +38,6 @@ from ..unets.unet_2d_condition import UNet2DConditionModel
 from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
 
 
-@dataclass
-class ControlNetUnionInput:
-    """
-    The image input of [`ControlNetUnionModel`]:
-    - 0: openpose
-    - 1: depth
-    - 2: hed/pidi/scribble/ted
-    - 3: canny/lineart/anime_lineart/mlsd
-    - 4: normal
-    - 5: segment
-    """
-
-    openpose: Optional[PipelineImageInput] = None
-    depth: Optional[PipelineImageInput] = None
-    hed: Optional[PipelineImageInput] = None
-    canny: Optional[PipelineImageInput] = None
-    normal: Optional[PipelineImageInput] = None
-    segment: Optional[PipelineImageInput] = None
-
-    def __len__(self) -> int:
-        return len(vars(self))
-
-    def __iter__(self):
-        return iter(vars(self))
-
-    def __getitem__(self, key):
-        return getattr(self, key)
-
-    def __setitem__(self, key, value):
-        setattr(self, key, value)
-
-
-@dataclass
-class ControlNetUnionInputProMax:
-    """
-    The image input of [`ControlNetUnionModel`]:
-    - 0: openpose
-    - 1: depth
-    - 2: hed/pidi/scribble/ted
-    - 3: canny/lineart/anime_lineart/mlsd
-    - 4: normal
-    - 5: segment
-    - 6: tile
-    - 7: repaint
-    """
-
-    openpose: Optional[PipelineImageInput] = None
-    depth: Optional[PipelineImageInput] = None
-    hed: Optional[PipelineImageInput] = None
-    canny: Optional[PipelineImageInput] = None
-    normal: Optional[PipelineImageInput] = None
-    segment: Optional[PipelineImageInput] = None
-    tile: Optional[PipelineImageInput] = None
-    repaint: Optional[PipelineImageInput] = None
-
-    def __len__(self) -> int:
-        return len(vars(self))
-
-    def __iter__(self):
-        return iter(vars(self))
-
-    def __getitem__(self, key):
-        return getattr(self, key)
-
-    def __setitem__(self, key, value):
-        setattr(self, key, value)
-
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -680,8 +608,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
+        controlnet_cond: List[torch.Tensor],
         control_type: torch.Tensor,
+        control_type_idx: List[int],
         conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
@@ -701,11 +630,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The number of timesteps to denoise an input.
             encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states.
-            controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
+            controlnet_cond (`List[torch.Tensor]`):
                 The conditional input tensors.
             control_type (`torch.Tensor`):
                 A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
                 type is used.
+            control_type_idx (`List[int]`):
+                The indices of `control_type`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
@@ -733,20 +664,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
             returned where the first element is the sample tensor.
         """
-        if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(controlnet_cond) != self.config.num_control_type:
-            if isinstance(controlnet_cond, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(controlnet_cond, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`."
-                )
-
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order
@@ -830,12 +747,10 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         inputs = []
         condition_list = []
 
-        for idx, image_type in enumerate(controlnet_cond):
-            if controlnet_cond[image_type] is None:
-                continue
-            condition = self.controlnet_cond_embedding(controlnet_cond[image_type])
+        for cond, control_idx in zip(controlnet_cond, control_type_idx):
+            condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
-            feat_seq = feat_seq + self.task_embedding[idx]
+            feat_seq = feat_seq + self.task_embedding[control_idx]
             inputs.append(feat_seq.unsqueeze(1))
             condition_list.append(condition)
...
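
The index-to-task mapping that lived in the removed dataclass docstrings is still the contract: those integers are now passed directly as `control_mode` (and forwarded to the model as `control_type_idx`). A small reference sketch, recovered from the removed docstrings (the dict name is illustrative, not a diffusers API):

```py
# Task indices for ControlNetUnion checkpoints; pass these as `control_mode`.
CONTROL_MODE = {
    "openpose": 0,
    "depth": 1,
    "hed": 2,      # hed/pidi/scribble/ted
    "canny": 3,    # canny/lineart/anime_lineart/mlsd
    "normal": 4,
    "segment": 5,
    "tile": 6,     # ProMax checkpoints only
    "repaint": 7,  # ProMax checkpoints only
}
```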
@@ -40,7 +40,6 @@ from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -82,7 +81,6 @@ EXAMPLE_DOC_STRING = """
     Examples:
         ```py
         from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
-        from diffusers.models.controlnets import ControlNetUnionInputProMax
         from diffusers.utils import load_image
         import torch
         import numpy as np
@@ -114,11 +112,8 @@ EXAMPLE_DOC_STRING = """
         mask_np = np.array(mask)
         controlnet_img_np[mask_np > 0] = 0
         controlnet_img = Image.fromarray(controlnet_img_np)
-        union_input = ControlNetUnionInputProMax(
-            repaint=controlnet_img,
-        )
         # generate image
-        image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0]
+        image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0]
         image.save("inpaint.png")
         ```
 """
@@ -1130,7 +1125,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
         mask_image: PipelineImageInput = None,
-        control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         padding_mask_crop: Optional[int] = None,
@@ -1158,6 +1153,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        control_mode: Optional[Union[int, List[int]]] = None,
         guidance_rescale: float = 0.0,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1345,20 +1341,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
-        if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(control_image_list) != controlnet.config.num_control_type:
-            if isinstance(control_image_list, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(control_image_list, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
-                )
-
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -1375,14 +1357,25 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
 
+        if not isinstance(control_image, list):
+            control_image = [control_image]
+
+        if not isinstance(control_mode, list):
+            control_mode = [control_mode]
+
+        if len(control_image) != len(control_mode):
+            raise ValueError("Expected len(control_image) == len(control_type)")
+
+        num_control_type = controlnet.config.num_control_type
+
         # 1. Check inputs
-        control_type = []
-        for image_type in control_image_list:
-            if control_image_list[image_type]:
-                self.check_inputs(
-                    prompt,
-                    prompt_2,
-                    control_image_list[image_type],
+        control_type = [0 for _ in range(num_control_type)]
+        for _image, control_idx in zip(control_image, control_mode):
+            control_type[control_idx] = 1
+            self.check_inputs(
+                prompt,
+                prompt_2,
+                _image,
                 mask_image,
                 strength,
                 num_inference_steps,
@@ -1499,10 +1489,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
             init_image = init_image.to(dtype=torch.float32)
 
         # 5.2 Prepare control images
-        for image_type in control_image_list:
-            if control_image_list[image_type]:
-                control_image = self.prepare_control_image(
-                    image=control_image_list[image_type],
+        for idx, _ in enumerate(control_image):
+            control_image[idx] = self.prepare_control_image(
+                image=control_image[idx],
                 width=width,
                 height=height,
                 batch_size=batch_size * num_images_per_prompt,
@@ -1514,8 +1503,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
-                height, width = control_image.shape[-2:]
-                control_image_list[image_type] = control_image
+            height, width = control_image[idx].shape[-2:]
 
         # 5.3 Prepare mask
         mask = self.mask_processor.preprocess(
@@ -1589,6 +1577,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         original_size = original_size or (height, width)
         target_size = target_size or (height, width)
 
+        for _image in control_image:
+            if isinstance(_image, torch.Tensor):
+                original_size = original_size or _image.shape[-2:]
+
         # 10. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
@@ -1693,8 +1684,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                         control_model_input,
                         t,
                         encoder_hidden_states=controlnet_prompt_embeds,
-                        controlnet_cond=control_image_list,
+                        controlnet_cond=control_image,
                         control_type=control_type,
+                        control_type_idx=control_mode,
                         conditioning_scale=cond_scale,
                         guess_mode=guess_mode,
                         added_cond_kwargs=controlnet_added_cond_kwargs,
...
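
Each refactored pipeline builds the one-hot `control_type` tensor from `control_mode` the same way, by flipping on the slot for each requested task. A standalone sketch of that bookkeeping, assuming a ProMax-style checkpoint with `num_control_type == 8`:

```py
import torch

control_mode = [7]      # repaint, as in the inpaint example above
num_control_type = 8    # from controlnet.config.num_control_type

control_type = [0 for _ in range(num_control_type)]
for control_idx in control_mode:
    control_type[control_idx] = 1          # one-hot flag per supported task
control_type = torch.Tensor(control_type)
# tensor([0., 0., 0., 0., 0., 0., 0., 1.])
```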
@@ -43,7 +43,6 @@ from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -70,7 +69,6 @@ EXAMPLE_DOC_STRING = """
         >>> # !pip install controlnet_aux
         >>> from controlnet_aux import LineartAnimeDetector
         >>> from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel, AutoencoderKL
-        >>> from diffusers.models.controlnets import ControlNetUnionInput
         >>> from diffusers.utils import load_image
         >>> import torch
@@ -89,17 +87,14 @@ EXAMPLE_DOC_STRING = """
         ...     controlnet=controlnet,
         ...     vae=vae,
         ...     torch_dtype=torch.float16,
+        ...     variant="fp16",
         ... )
         >>> pipe.enable_model_cpu_offload()
         >>> # prepare image
         >>> processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
         >>> controlnet_img = processor(image, output_type="pil")
-        >>> # set ControlNetUnion input
-        >>> union_input = ControlNetUnionInput(
-        ...     canny=controlnet_img,
-        ... )
         >>> # generate image
-        >>> image = pipe(prompt, image=union_input).images[0]
+        >>> image = pipe(prompt, control_image=[controlnet_img], control_mode=[3], height=1024, width=1024).images[0]
         ```
 """
@@ -791,26 +786,6 @@ class StableDiffusionXLControlNetUnionPipeline(
                     f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )
 
-    def check_input(
-        self,
-        image: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
-    ):
-        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
-
-        if not isinstance(image, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `image` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(image) != controlnet.config.num_control_type:
-            if isinstance(image, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(image, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInput`."
-                )
-
     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_image(
         self,
@@ -970,7 +945,7 @@ class StableDiffusionXLControlNetUnionPipeline(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -997,6 +972,7 @@ class StableDiffusionXLControlNetUnionPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        control_mode: Optional[Union[int, List[int]]] = None,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Tuple[int, int] = None,
@@ -1018,10 +994,7 @@ class StableDiffusionXLControlNetUnionPipeline(
             prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                 used in both text-encoders.
-            image (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
-                In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
-                `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`,
-                `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+            control_image (`PipelineImageInput`):
                 The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                 specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                 as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
@@ -1168,22 +1141,32 @@ class StableDiffusionXLControlNetUnionPipeline(
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
-        self.check_input(image)
-
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
 
+        if not isinstance(control_image, list):
+            control_image = [control_image]
+
+        if not isinstance(control_mode, list):
+            control_mode = [control_mode]
+
+        if len(control_image) != len(control_mode):
+            raise ValueError("Expected len(control_image) == len(control_type)")
+
+        num_control_type = controlnet.config.num_control_type
+
+        # 1. Check inputs
+        control_type = [0 for _ in range(num_control_type)]
         # 1. Check inputs. Raise error if not correct
-        control_type = []
-        for image_type in image:
-            if image[image_type]:
-                self.check_inputs(
-                    prompt,
-                    prompt_2,
-                    image[image_type],
+        for _image, control_idx in zip(control_image, control_mode):
+            control_type[control_idx] = 1
+            self.check_inputs(
+                prompt,
+                prompt_2,
+                _image,
                 negative_prompt,
                 negative_prompt_2,
                 prompt_embeds,
@@ -1197,9 +1180,6 @@ class StableDiffusionXLControlNetUnionPipeline(
                 control_guidance_end,
                 callback_on_step_end_tensor_inputs,
             )
-                control_type.append(1)
-            else:
-                control_type.append(0)
 
         control_type = torch.Tensor(control_type)
@@ -1258,10 +1238,9 @@ class StableDiffusionXLControlNetUnionPipeline(
         )
 
         # 4. Prepare image
-        for image_type in image:
-            if image[image_type]:
-                image[image_type] = self.prepare_image(
-                    image=image[image_type],
+        for idx, _ in enumerate(control_image):
+            control_image[idx] = self.prepare_image(
+                image=control_image[idx],
                 width=width,
                 height=height,
                 batch_size=batch_size * num_images_per_prompt,
@@ -1271,7 +1250,7 @@ class StableDiffusionXLControlNetUnionPipeline(
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
-                height, width = image[image_type].shape[-2:]
+            height, width = control_image[idx].shape[-2:]
 
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
@@ -1312,11 +1291,11 @@ class StableDiffusionXLControlNetUnionPipeline(
         )
 
         # 7.2 Prepare added time ids & embeddings
-        for image_type in image:
-            if isinstance(image[image_type], torch.Tensor):
-                original_size = original_size or image[image_type].shape[-2:]
+        original_size = original_size or (height, width)
         target_size = target_size or (height, width)
 
+        for _image in control_image:
+            if isinstance(_image, torch.Tensor):
+                original_size = original_size or _image.shape[-2:]
+
         add_text_embeds = pooled_prompt_embeds
         if self.text_encoder_2 is None:
             text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
@@ -1424,8 +1403,9 @@ class StableDiffusionXLControlNetUnionPipeline(
                         control_model_input,
                         t,
                         encoder_hidden_states=controlnet_prompt_embeds,
-                        controlnet_cond=image,
+                        controlnet_cond=control_image,
                         control_type=control_type,
+                        control_type_idx=control_mode,
                         conditioning_scale=cond_scale,
                         guess_mode=guess_mode,
                         added_cond_kwargs=controlnet_added_cond_kwargs,
@@ -1478,7 +1458,6 @@ class StableDiffusionXLControlNetUnionPipeline(
                 )
                 add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                 negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
-                image = callback_outputs.pop("image", image)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
...
@@ -43,7 +43,6 @@ from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -74,7 +73,6 @@ EXAMPLE_DOC_STRING = """
             ControlNetUnionModel,
             AutoencoderKL,
         )
-        from diffusers.models.controlnets import ControlNetUnionInputProMax
         from diffusers.utils import load_image
         import torch
         from PIL import Image
@@ -95,6 +93,7 @@ EXAMPLE_DOC_STRING = """
             controlnet=controlnet,
             vae=vae,
             torch_dtype=torch.float16,
+            variant="fp16",
         ).to("cuda")
         # `enable_model_cpu_offload` is not recommended due to multiple generations
         height = image.height
@@ -132,14 +131,12 @@ EXAMPLE_DOC_STRING = """
         # set ControlNetUnion input
         result_images = []
         for sub_img, crops_coords in zip(images, crops_coords_list):
-            union_input = ControlNetUnionInputProMax(
-                tile=sub_img,
-            )
             new_width, new_height = W, H
             out = pipe(
                 prompt=[prompt] * 1,
                 image=sub_img,
-                control_image_list=union_input,
+                control_image=[sub_img],
+                control_mode=[6],
                 width=new_width,
                 height=new_height,
                 num_inference_steps=30,
@@ -1065,7 +1062,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
-        control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         strength: float = 0.8,
@@ -1090,6 +1087,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        control_mode: Optional[Union[int, List[int]]] = None,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Tuple[int, int] = None,
@@ -1119,10 +1117,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                 `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The initial image will be used as the starting point for the image generation process. Can also accept
                 image latents as `image`, if passing latents directly, it will not be encoded again.
-            control_image_list (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
-                In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
-                `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`,
-                `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`)::
+            control_image (`PipelineImageInput`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                 the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
                 be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
@@ -1291,34 +1286,31 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
-        if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(control_image_list) != controlnet.config.num_control_type:
-            if isinstance(control_image_list, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(control_image_list, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
-                )
-
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
 
-        # 1. Check inputs. Raise error if not correct
-        control_type = []
-        for image_type in control_image_list:
-            if control_image_list[image_type]:
+        if not isinstance(control_image, list):
+            control_image = [control_image]
+
+        if not isinstance(control_mode, list):
+            control_mode = [control_mode]
+
+        if len(control_image) != len(control_mode):
+            raise ValueError("Expected len(control_image) == len(control_type)")
+
+        num_control_type = controlnet.config.num_control_type
+
+        # 1. Check inputs
+        control_type = [0 for _ in range(num_control_type)]
+        for _image, control_idx in zip(control_image, control_mode):
+            control_type[control_idx] = 1
             self.check_inputs(
                 prompt,
                 prompt_2,
-                control_image_list[image_type],
+                _image,
                 strength,
                 num_inference_steps,
                 callback_steps,
@@ -1335,9 +1327,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                 control_guidance_end,
                 callback_on_step_end_tensor_inputs,
             )
-                control_type.append(1)
-            else:
-                control_type.append(0)
 
         control_type = torch.Tensor(control_type)
@@ -1397,10 +1386,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         # 4. Prepare image and controlnet_conditioning_image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
 
-        for image_type in control_image_list:
-            if control_image_list[image_type]:
-                control_image = self.prepare_control_image(
-                    image=control_image_list[image_type],
+        for idx, _ in enumerate(control_image):
+            control_image[idx] = self.prepare_control_image(
+                image=control_image[idx],
                 width=width,
                 height=height,
                 batch_size=batch_size * num_images_per_prompt,
@@ -1410,8 +1398,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
-                height, width = control_image.shape[-2:]
-                control_image_list[image_type] = control_image
+            height, width = control_image[idx].shape[-2:]
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1444,10 +1431,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         )
 
         # 7.2 Prepare added time ids & embeddings
-        for image_type in control_image_list:
-            if isinstance(control_image_list[image_type], torch.Tensor):
-                original_size = original_size or control_image_list[image_type].shape[-2:]
+        original_size = original_size or (height, width)
         target_size = target_size or (height, width)
 
+        for _image in control_image:
+            if isinstance(_image, torch.Tensor):
+                original_size = original_size or _image.shape[-2:]
+
         if negative_original_size is None:
             negative_original_size = original_size
@@ -1531,8 +1519,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                         control_model_input,
                         t,
                         encoder_hidden_states=controlnet_prompt_embeds,
-                        controlnet_cond=control_image_list,
+                        controlnet_cond=control_image,
                         control_type=control_type,
+                        control_type_idx=control_mode,
                         conditioning_scale=cond_scale,
                         guess_mode=guess_mode,
                         added_cond_kwargs=controlnet_added_cond_kwargs,
...
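
For completeness, a shape-level sketch of calling the refactored `ControlNetUnionModel.forward` directly with the new arguments. Here `controlnet` is assumed to be a loaded `ControlNetUnionModel`, and the embedding tensors are random stand-ins for real SDXL text and pooled embeddings, so the output is meaningless but the call signature matches the diff:

```py
import torch

sample = torch.randn(1, 4, 128, 128)      # noisy latents for a 1024px image
prompt_embeds = torch.randn(1, 77, 2048)  # stand-in SDXL text embeddings
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 1280),  # stand-in pooled embeddings
    "time_ids": torch.zeros(1, 6),        # SDXL micro-conditioning ids
}
cond = torch.randn(1, 3, 1024, 1024)      # one preprocessed control image
control_type = torch.zeros(1, 8)
control_type[:, 3] = 1                    # flag the canny slot

down_res, mid_res = controlnet(
    sample,
    torch.tensor([10]),                   # timestep
    encoder_hidden_states=prompt_embeds,
    controlnet_cond=[cond],               # one tensor per active task
    control_type=control_type,            # one-hot task flags
    control_type_idx=[3],                 # task-embedding index per cond
    added_cond_kwargs=added_cond_kwargs,
    return_dict=False,
)
```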