Unverified Commit 14a1b86f authored by Vladimir Mandic's avatar Vladimir Mandic Committed by GitHub
Browse files

Several fixes to Flux ControlNet pipelines (#9472)



* fix flux controlnet pipelines

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail.com>
parent 2b443a5d
......@@ -29,7 +29,14 @@ from .controlnet import (
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline
from .flux import (
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
......@@ -128,6 +135,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
]
)
......@@ -143,6 +151,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline),
]
)
......
......@@ -729,7 +729,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]
......@@ -763,7 +763,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]
......@@ -840,12 +840,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.tensor([guidance_scale], device=device)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
guidance = (
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
# controlnet
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
......@@ -863,6 +861,11 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
return_dict=False,
)
guidance = (
torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
......
......@@ -767,7 +767,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]
......@@ -798,7 +798,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]
......
......@@ -899,7 +899,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]
......@@ -933,7 +933,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment