Unverified Commit 76a62ac9 authored by Álvaro Somoza's avatar Álvaro Somoza Committed by GitHub
Browse files

[ControlnetUnion] Multiple Fixes (#11888)



fixes

---------
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent 1c6ab9e9
...@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
condition = self.controlnet_cond_embedding(cond) condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3)) feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[control_idx] feat_seq = feat_seq + self.task_embedding[control_idx]
if from_multi: if from_multi or len(control_type_idx) == 1:
inputs.append(feat_seq.unsqueeze(1)) inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition) condition_list.append(condition)
else: else:
...@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale): for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
alpha = self.spatial_ch_projs(x[:, idx]) alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1) alpha = alpha.unsqueeze(-1).unsqueeze(-1)
if from_multi: if from_multi or len(control_type_idx) == 1:
controlnet_cond_fuser += condition + alpha controlnet_cond_fuser += condition + alpha
else: else:
controlnet_cond_fuser += condition + alpha * scale controlnet_cond_fuser += condition + alpha * scale
...@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# 6. scaling # 6. scaling
if guess_mode and not self.config.global_pool_conditions: if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
if from_multi: if from_multi or len(control_type_idx) == 1:
scales = scales * conditioning_scale[0] scales = scales * conditioning_scale[0]
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
elif from_multi: elif from_multi or len(control_type_idx) == 1:
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples] down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0] mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
......
...@@ -1452,17 +1452,21 @@ class StableDiffusionXLControlNetUnionPipeline( ...@@ -1452,17 +1452,21 @@ class StableDiffusionXLControlNetUnionPipeline(
is_controlnet_compiled = is_compiled_module(self.controlnet) is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
control_type_repeat_factor = (
batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
)
if isinstance(controlnet, ControlNetUnionModel): if isinstance(controlnet, ControlNetUnionModel):
control_type = ( control_type = (
control_type.reshape(1, -1) control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype) .to(self._execution_device, dtype=prompt_embeds.dtype)
.repeat(batch_size * num_images_per_prompt * 2, 1) .repeat(control_type_repeat_factor, 1)
) )
if isinstance(controlnet, MultiControlNetUnionModel): elif isinstance(controlnet, MultiControlNetUnionModel):
control_type = [ control_type = [
_control_type.reshape(1, -1) _control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype) .to(self._execution_device, dtype=prompt_embeds.dtype)
.repeat(batch_size * num_images_per_prompt * 2, 1) .repeat(control_type_repeat_factor, 1)
for _control_type in control_type for _control_type in control_type
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment