Unverified Commit 37a5f1b3 authored by hlky, committed by GitHub

Experimental per control type scale for ControlNet Union (#10723)

* ControlNet Union scale

* fix

* universal interface

* from_multi

* from_multi
parent 501d9de7
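
In effect, `controlnet_conditioning_scale` may now be a list with one entry per control type when a single `ControlNetUnionModel` drives the pipeline. A minimal usage sketch, assuming an SDXL union checkpoint; the model IDs, image files, and mode indices below are illustrative, not part of this commit:

```python
import torch
from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionPipeline
from diffusers.utils import load_image

# Illustrative checkpoints; any SDXL base + ControlNet Union pair works the same way.
controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

depth = load_image("depth.png")  # hypothetical conditioning images
pose = load_image("pose.png")

image = pipe(
    prompt="a person standing in a sunlit room",
    control_image=[pose, depth],
    control_mode=[0, 1],  # one mode index per image, per the checkpoint's mapping
    # Experimental: one scale per control type instead of a single float.
    controlnet_conditioning_scale=[0.8, 0.4],
).images[0]
```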
@@ -605,12 +605,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         controlnet_cond: List[torch.Tensor],
         control_type: torch.Tensor,
         control_type_idx: List[int],
-        conditioning_scale: float = 1.0,
+        conditioning_scale: Union[float, List[float]] = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        from_multi: bool = False,
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
@@ -647,6 +648,8 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 Additional conditions for the Stable Diffusion XL UNet.
             cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+            from_multi (`bool`, defaults to `False`):
+                Use standard scaling when called from `MultiControlNetUnionModel`.
             guess_mode (`bool`, defaults to `False`):
                 In this mode, the ControlNet encoder tries its best to recognize the content of the input even if
                 you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
@@ -658,6 +661,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
                 returned where the first element is the sample tensor.
         """
+        if isinstance(conditioning_scale, float):
+            conditioning_scale = [conditioning_scale] * len(controlnet_cond)
+
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order
@@ -742,12 +748,16 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         inputs = []
         condition_list = []

-        for cond, control_idx in zip(controlnet_cond, control_type_idx):
+        for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
             condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
             feat_seq = feat_seq + self.task_embedding[control_idx]
-            inputs.append(feat_seq.unsqueeze(1))
-            condition_list.append(condition)
+            if from_multi:
+                inputs.append(feat_seq.unsqueeze(1))
+                condition_list.append(condition)
+            else:
+                inputs.append(feat_seq.unsqueeze(1) * scale)
+                condition_list.append(condition * scale)

         condition = sample
         feat_seq = torch.mean(condition, dim=(2, 3))
@@ -759,10 +769,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             x = layer(x)

         controlnet_cond_fuser = sample * 0.0
-        for idx, condition in enumerate(condition_list[:-1]):
+        for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-            controlnet_cond_fuser += condition + alpha
+            if from_multi:
+                controlnet_cond_fuser += condition + alpha
+            else:
+                controlnet_cond_fuser += condition + alpha * scale

         sample = sample + controlnet_cond_fuser
@@ -806,12 +819,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-            scales = scales * conditioning_scale
+            if from_multi:
+                scales = scales * conditioning_scale[0]
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
-        else:
-            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-            mid_block_res_sample = mid_block_res_sample * conditioning_scale
+        elif from_multi:
+            down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
+            mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]

         if self.config.global_pool_conditions:
             down_block_res_samples = [
......
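
The behavioral core of the hunks above: with `from_multi=False` (a direct call on the union model), each per-type scale multiplies that control type's condition embedding before fusion, while `from_multi=True` preserves the old behavior and scales the output residuals instead. A toy sketch of the fusion step under the two paths; the function and shapes are simplified stand-ins, not the real module:

```python
import torch

def fuse_conditions(conditions, scales, from_multi):
    """Toy fusion: sum per-control-type condition maps into one tensor.

    conditions: list of (B, C, H, W) embeddings, one per control type
    scales:     list of floats, one per control type
    """
    fused = torch.zeros_like(conditions[0])
    for cond, scale in zip(conditions, scales):
        if from_multi:
            fused = fused + cond          # old path: scale the outputs later
        else:
            fused = fused + cond * scale  # new path: per-type scale at fusion time
    return fused

conds = [torch.randn(1, 4, 8, 8) for _ in range(2)]
per_type = fuse_conditions(conds, [1.0, 0.25], from_multi=False)  # second type damped
uniform = fuse_conditions(conds, [1.0, 0.25], from_multi=True)    # scales ignored here
```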
@@ -47,9 +47,12 @@ class MultiControlNetUnionModel(ModelMixin):
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
+        down_block_res_samples, mid_block_res_sample = None, None
         for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
             zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
         ):
+            if scale == 0.0:
+                continue
             down_samples, mid_sample = controlnet(
                 sample=sample,
                 timestep=timestep,
@@ -63,12 +66,13 @@ class MultiControlNetUnionModel(ModelMixin):
                 attention_mask=attention_mask,
                 added_cond_kwargs=added_cond_kwargs,
                 cross_attention_kwargs=cross_attention_kwargs,
+                from_multi=True,
                 guess_mode=guess_mode,
                 return_dict=return_dict,
             )

             # merge samples
-            if i == 0:
+            if down_block_res_samples is None and mid_block_res_sample is None:
                 down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
             else:
                 down_block_res_samples = [
......
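
With the accumulators seeded to `None`, the multi wrapper can skip a net whose scale is exactly `0.0` without ever invoking it, and the first *active* net seeds the merge instead of whichever net sits at index 0. A toy sketch of that control flow; plain tensors stand in for the down/mid residual lists:

```python
import torch

def merge_residuals(per_net_outputs, scales):
    """Mimic the merge loop: skip zero-scale nets, seed with the first active one."""
    merged = None
    for out, scale in zip(per_net_outputs, scales):
        if scale == 0.0:
            continue  # inactive net: neither run nor merged
        merged = out if merged is None else merged + out
    return merged

outs = [torch.ones(2, 4), torch.full((2, 4), 3.0)]
print(merge_residuals(outs, [0.0, 1.0]))  # only the second net contributes
```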
@@ -757,15 +757,9 @@ class StableDiffusionXLControlNetUnionPipeline(
             for images_ in image:
                 for image_ in images_:
                     self.check_image(image_, prompt, prompt_embeds)
-        else:
-            assert False

         # Check `controlnet_conditioning_scale`
-        # TODO Update for https://github.com/huggingface/diffusers/pull/10723
-        if isinstance(controlnet, ControlNetUnionModel):
-            if not isinstance(controlnet_conditioning_scale, float):
-                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
-        elif isinstance(controlnet, MultiControlNetUnionModel):
+        if isinstance(controlnet, MultiControlNetUnionModel):
             if isinstance(controlnet_conditioning_scale, list):
                 if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                     raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
@@ -776,8 +770,6 @@ class StableDiffusionXLControlNetUnionPipeline(
                     "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                     " the same length as the number of controlnets"
                 )
-        else:
-            assert False

         if len(control_guidance_start) != len(control_guidance_end):
             raise ValueError(
@@ -808,8 +800,6 @@ class StableDiffusionXLControlNetUnionPipeline(
             for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
                 if max(_control_mode) >= _controlnet.config.num_control_type:
                     raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
-        else:
-            assert False

         # Equal number of `image` and `control_mode` elements
         if isinstance(controlnet, ControlNetUnionModel):
@@ -823,8 +813,6 @@ class StableDiffusionXLControlNetUnionPipeline(
             elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
                 raise ValueError("Expected len(control_image) == len(control_mode)")
-        else:
-            assert False

         if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
             raise ValueError(
@@ -1201,28 +1189,33 @@ class StableDiffusionXLControlNetUnionPipeline(
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

+        if not isinstance(control_image, list):
+            control_image = [control_image]
+        else:
+            control_image = control_image.copy()
+
+        if not isinstance(control_mode, list):
+            control_mode = [control_mode]
+
+        if isinstance(controlnet, MultiControlNetUnionModel):
+            control_image = [[item] for item in control_image]
+            control_mode = [[item] for item in control_mode]
+
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
         elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
-            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else 1
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
             control_guidance_start, control_guidance_end = (
                 mult * [control_guidance_start],
                 mult * [control_guidance_end],
             )

-        if not isinstance(control_image, list):
-            control_image = [control_image]
-        else:
-            control_image = control_image.copy()
-
-        if not isinstance(control_mode, list):
-            control_mode = [control_mode]
-
-        if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float):
-            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+        if isinstance(controlnet_conditioning_scale, float):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult

         # 1. Check inputs
         self.check_inputs(
@@ -1357,9 +1350,6 @@ class StableDiffusionXLControlNetUnionPipeline(
             control_image = control_images
             height, width = control_image[0][0].shape[-2:]
-        else:
-            assert False

         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, timesteps, sigmas
@@ -1397,7 +1387,7 @@ class StableDiffusionXLControlNetUnionPipeline(
                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
-            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetUnionModel) else keeps)
+            controlnet_keep.append(keeps)

         # 7.2 Prepare added time ids & embeddings
         original_size = original_size or (height, width)
......
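
Pipeline-side, a scalar `controlnet_conditioning_scale` is now broadcast before `check_inputs`: one entry per net for `MultiControlNetUnionModel`, one per control mode for a single union model, mirroring how `control_guidance_start`/`control_guidance_end` are expanded. A sketch of that broadcast rule, with the `isinstance` check on the multi wrapper reduced to a boolean flag for illustration:

```python
from typing import List, Union

def broadcast_scale(
    scale: Union[float, List[float]], is_multi: bool, num_nets: int, num_modes: int
) -> List[float]:
    """Expand a scalar scale: per net in the multi case, per control mode otherwise."""
    if isinstance(scale, float):
        return [scale] * (num_nets if is_multi else num_modes)
    return scale

assert broadcast_scale(0.5, is_multi=False, num_nets=1, num_modes=3) == [0.5] * 3
assert broadcast_scale(0.5, is_multi=True, num_nets=2, num_modes=3) == [0.5] * 2
assert broadcast_scale([0.8, 0.4], is_multi=False, num_nets=1, num_modes=2) == [0.8, 0.4]
```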