"vscode:/vscode.git/clone" did not exist on "062f59a8f6e2b7c96aedab0e5810ae75364492cf"
Unverified Commit a4f9c3cb authored by Ishan Modi's avatar Ishan Modi Committed by GitHub
Browse files

[Feature] Added Xlab Controlnet support (#11249)

update
parent 4b60f4b6
......@@ -800,17 +800,20 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
)
height, width = control_image.shape[-2:]
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# xlab controlnet has a input_hint_block and instantx controlnet does not
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
if self.controlnet.input_hint_block is None:
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
if control_mode is not None:
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
......@@ -819,7 +822,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
for control_image_ in control_image:
# xlab controlnet has a input_hint_block and instantx controlnet does not
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
for i, control_image_ in enumerate(control_image):
control_image_ = self.prepare_image(
image=control_image_,
width=width,
......@@ -831,17 +836,18 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
)
height, width = control_image_.shape[-2:]
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
if self.controlnet.nets[0].input_hint_block is None:
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
control_images.append(control_image_)
......@@ -955,6 +961,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
controlnet_blocks_repeat=controlnet_blocks_repeat,
)[0]
latents_dtype = latents.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment