Unverified Commit 37a5f1b3 authored by hlky's avatar hlky Committed by GitHub
Browse files

Experimental per control type scale for ControlNet Union (#10723)

* ControlNet Union scale

* fix

* universal interface

* from_multi

* from_multi
parent 501d9de7
......@@ -605,12 +605,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
controlnet_cond: List[torch.Tensor],
control_type: torch.Tensor,
control_type_idx: List[int],
conditioning_scale: float = 1.0,
conditioning_scale: Union[float, List[float]] = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
from_multi: bool = False,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
......@@ -647,6 +648,8 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
from_multi (`bool`, defaults to `False`):
Use standard scaling when called from `MultiControlNetUnionModel`.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
......@@ -658,6 +661,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
if isinstance(conditioning_scale, float):
conditioning_scale = [conditioning_scale] * len(controlnet_cond)
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
......@@ -742,12 +748,16 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
inputs = []
condition_list = []
for cond, control_idx in zip(controlnet_cond, control_type_idx):
for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[control_idx]
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
if from_multi:
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
else:
inputs.append(feat_seq.unsqueeze(1) * scale)
condition_list.append(condition * scale)
condition = sample
feat_seq = torch.mean(condition, dim=(2, 3))
......@@ -759,10 +769,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
x = layer(x)
controlnet_cond_fuser = sample * 0.0
for idx, condition in enumerate(condition_list[:-1]):
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
controlnet_cond_fuser += condition + alpha
if from_multi:
controlnet_cond_fuser += condition + alpha
else:
controlnet_cond_fuser += condition + alpha * scale
sample = sample + controlnet_cond_fuser
......@@ -806,12 +819,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
if from_multi:
scales = scales * conditioning_scale[0]
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
elif from_multi:
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
if self.config.global_pool_conditions:
down_block_res_samples = [
......
......@@ -47,9 +47,12 @@ class MultiControlNetUnionModel(ModelMixin):
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
down_block_res_samples, mid_block_res_sample = None, None
for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
):
if scale == 0.0:
continue
down_samples, mid_sample = controlnet(
sample=sample,
timestep=timestep,
......@@ -63,12 +66,13 @@ class MultiControlNetUnionModel(ModelMixin):
attention_mask=attention_mask,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
from_multi=True,
guess_mode=guess_mode,
return_dict=return_dict,
)
# merge samples
if i == 0:
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
......
......@@ -757,15 +757,9 @@ class StableDiffusionXLControlNetUnionPipeline(
for images_ in image:
for image_ in images_:
self.check_image(image_, prompt, prompt_embeds)
else:
assert False
# Check `controlnet_conditioning_scale`
# TODO Update for https://github.com/huggingface/diffusers/pull/10723
if isinstance(controlnet, ControlNetUnionModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(controlnet, MultiControlNetUnionModel):
if isinstance(controlnet, MultiControlNetUnionModel):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
......@@ -776,8 +770,6 @@ class StableDiffusionXLControlNetUnionPipeline(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False
if len(control_guidance_start) != len(control_guidance_end):
raise ValueError(
......@@ -808,8 +800,6 @@ class StableDiffusionXLControlNetUnionPipeline(
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
if max(_control_mode) >= _controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
else:
assert False
# Equal number of `image` and `control_mode` elements
if isinstance(controlnet, ControlNetUnionModel):
......@@ -823,8 +813,6 @@ class StableDiffusionXLControlNetUnionPipeline(
elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
else:
assert False
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
......@@ -1201,28 +1189,33 @@ class StableDiffusionXLControlNetUnionPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if not isinstance(control_image, list):
control_image = [control_image]
else:
control_image = control_image.copy()
if not isinstance(control_mode, list):
control_mode = [control_mode]
if isinstance(controlnet, MultiControlNetUnionModel):
control_image = [[item] for item in control_image]
control_mode = [[item] for item in control_mode]
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else 1
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
if not isinstance(control_image, list):
control_image = [control_image]
else:
control_image = control_image.copy()
if not isinstance(control_mode, list):
control_mode = [control_mode]
if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
if isinstance(controlnet_conditioning_scale, float):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
# 1. Check inputs
self.check_inputs(
......@@ -1357,9 +1350,6 @@ class StableDiffusionXLControlNetUnionPipeline(
control_image = control_images
height, width = control_image[0][0].shape[-2:]
else:
assert False
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
......@@ -1397,7 +1387,7 @@ class StableDiffusionXLControlNetUnionPipeline(
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetUnionModel) else keeps)
controlnet_keep.append(keeps)
# 7.2 Prepare added time ids & embeddings
original_size = original_size or (height, width)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment