"doc/git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "4d1d561d28e04beb56f2c75b4ddaaf20d787ba07"
Unverified Commit 9cd37557 authored by YiYi Xu, committed by GitHub

flux controlnet fix (control_modes batch & others) (#9507)



* flux controlnet mode to take into account batch size

* incorporate yiyixuxu's suggestions (cleaner logic) and clean up control mode handling for the multi-ControlNet case

* fix

* fix use_guidance when the controlnet is a FluxMultiControlNetModel and does not have its own config

---------
Co-authored-by: Christopher Beckham <christopher.j.beckham@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 1c6ede93
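Review note: a rough usage sketch (not part of this commit) of the scenario the fix targets — a Union-style Flux ControlNet driven by a single integer `control_mode` with more than one image per prompt, which previously produced a control-mode tensor with the wrong batch size. The checkpoint names, image URL, and sampler settings below are illustrative assumptions.

```python
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

# Assumed checkpoints; any Flux Union ControlNet + base Flux model exercises the same path.
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Union", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

control_image = load_image("https://example.com/canny.png")  # placeholder conditioning image

images = pipe(
    prompt="a photo of a cat",
    control_image=control_image,
    control_mode=0,            # a single int is now expanded to the whole batch
    num_images_per_prompt=2,   # batch size > 1 is the case fixed here
    num_inference_steps=28,
    guidance_scale=3.5,
).images
```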
@@ -502,16 +502,17 @@ class FluxMultiControlNetModel(ModelMixin):
                 control_block_samples = block_samples
                 control_single_block_samples = single_block_samples
             else:
-                control_block_samples = [
-                    control_block_sample + block_sample
-                    for control_block_sample, block_sample in zip(control_block_samples, block_samples)
-                ]
+                if block_samples is not None and control_block_samples is not None:
+                    control_block_samples = [
+                        control_block_sample + block_sample
+                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+                    ]
 
-                control_single_block_samples = [
-                    control_single_block_sample + block_sample
-                    for control_single_block_sample, block_sample in zip(
-                        control_single_block_samples, single_block_samples
-                    )
-                ]
+                if single_block_samples is not None and control_single_block_samples is not None:
+                    control_single_block_samples = [
+                        control_single_block_sample + block_sample
+                        for control_single_block_sample, block_sample in zip(
+                            control_single_block_samples, single_block_samples
+                        )
+                    ]
 
         return control_block_samples, control_single_block_samples
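Review note: the guards above matter because a ControlNet in the multi list can legitimately return `None` for one of the two sample groups; the old unconditional list comprehension then hit `zip(..., None)`. A minimal standalone sketch of the guarded accumulation (not the library code):

```python
import torch

def merge(control_samples, new_samples):
    # Mirror of the guard: only sum when both the running total and the new
    # residuals exist; otherwise keep whatever is already accumulated.
    if new_samples is not None and control_samples is not None:
        control_samples = [prev + new for prev, new in zip(control_samples, new_samples)]
    return control_samples

running = [torch.ones(1, 4), torch.ones(1, 4)]
print(merge(running, [torch.full((1, 4), 2.0)] * 2)[0])  # tensor of 3s
print(merge(running, None)[0])  # unchanged; the old code crashed here with zip(..., None)
```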
@@ -747,10 +747,12 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
                 width_control_image,
             )
 
-            # set control mode
+            # Here we ensure that `control_mode` has the same length as the control_image.
             if control_mode is not None:
+                if not isinstance(control_mode, int):
+                    raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
                 control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
-                control_mode = control_mode.reshape([-1, 1])
+                control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
 
         elif isinstance(self.controlnet, FluxMultiControlNetModel):
             control_images = []
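Review note: sketch of what the new single-ControlNet branch does with `control_mode` — an int becomes a `(batch, 1)` long tensor matching the packed control image's batch dimension (batch size 4 is just an example here):

```python
import torch

control_mode = 2   # e.g. one slot of a Union ControlNet
batch_size = 4     # stands in for control_image.shape[0] after prepare/pack

mode = torch.tensor(control_mode, dtype=torch.long)   # 0-dim scalar tensor
mode = mode.view(-1, 1).expand(batch_size, 1)         # -> shape (4, 1), one row per sample
print(mode.shape)  # torch.Size([4, 1])
print(mode)        # four rows of [2]
```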
@@ -785,16 +787,22 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
             control_image = control_images
 
+            # Here we ensure that `control_mode` has the same length as the control_image.
+            if isinstance(control_mode, list) and len(control_mode) != len(control_image):
+                raise ValueError(
+                    "For Multi-ControlNet, `control_mode` must be a list of the same "
+                    + " length as the number of controlnets (control images) specified"
+                )
+            if not isinstance(control_mode, list):
+                control_mode = [control_mode] * len(control_image)
             # set control mode
-            control_mode_ = []
-            if isinstance(control_mode, list):
-                for cmode in control_mode:
-                    if cmode is None:
-                        control_mode_.append(-1)
-                    else:
-                        control_mode_.append(cmode)
-            control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
-            control_mode = control_mode.reshape([-1, 1])
+            control_modes = []
+            for cmode in control_mode:
+                if cmode is None:
+                    cmode = -1
+                control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
+                control_modes.append(control_mode)
+            control_mode = control_modes
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
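Review note: sketch of the Multi-ControlNet branch above — each entry of `control_mode` (one per ControlNet/control image, with `None` mapped to the -1 sentinel) now becomes its own 1-D long tensor of length batch size, instead of all modes being flattened into a single tensor:

```python
import torch

batch_size = 2             # stands in for control_images[0].shape[0]
control_mode = [0, None]   # one mode per ControlNet; None means "no mode" (-1)

control_modes = []
for cmode in control_mode:
    if cmode is None:
        cmode = -1
    control_modes.append(torch.tensor(cmode).expand(batch_size).to(dtype=torch.long))

print([m.tolist() for m in control_modes])  # [[0, 0], [-1, -1]]
```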
@@ -840,9 +848,12 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                guidance = (
-                    torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
-                )
+                if isinstance(self.controlnet, FluxMultiControlNetModel):
+                    use_guidance = self.controlnet.nets[0].config.guidance_embeds
+                else:
+                    use_guidance = self.controlnet.config.guidance_embeds
+
+                guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
 
                 # controlnet
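Review note: sketch of the `use_guidance` lookup fixed above — `FluxMultiControlNetModel` only wraps a list of nets and has no `guidance_embeds` entry in its own config, so the flag must be read from the first wrapped net. The stand-in objects below merely mimic the relevant config shape:

```python
import torch
from types import SimpleNamespace

# Duck-typed stand-ins for a single ControlNet and the multi wrapper.
single = SimpleNamespace(config=SimpleNamespace(guidance_embeds=True))
multi = SimpleNamespace(nets=[single], config=SimpleNamespace())  # no guidance_embeds here

def resolve_use_guidance(controlnet, is_multi):
    # is_multi plays the role of isinstance(self.controlnet, FluxMultiControlNetModel)
    return controlnet.nets[0].config.guidance_embeds if is_multi else controlnet.config.guidance_embeds

guidance_scale, batch_size = 3.5, 2
use_guidance = resolve_use_guidance(multi, is_multi=True)
guidance = torch.tensor([guidance_scale]) if use_guidance else None
guidance = guidance.expand(batch_size) if guidance is not None else None
print(guidance)  # tensor([3.5000, 3.5000])
```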