"vscode:/vscode.git/clone" did not exist on "b59815bc2b6fbc68c2fa26833aea210061391bb9"
Unverified Commit 14a1b86f authored by Vladimir Mandic's avatar Vladimir Mandic Committed by GitHub
Browse files

Several fixes to Flux ControlNet pipelines (#9472)



* fix flux controlnet pipelines

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail.com>
parent 2b443a5d
...@@ -29,7 +29,14 @@ from .controlnet import ( ...@@ -29,7 +29,14 @@ from .controlnet import (
StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetPipeline,
) )
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline from .flux import (
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
)
from .hunyuandit import HunyuanDiTPipeline from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import ( from .kandinsky import (
KandinskyCombinedPipeline, KandinskyCombinedPipeline,
...@@ -128,6 +135,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( ...@@ -128,6 +135,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline), ("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
] ]
) )
...@@ -143,6 +151,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict( ...@@ -143,6 +151,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline), ("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline),
] ]
) )
......
...@@ -729,7 +729,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -729,7 +729,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
batch_size=batch_size * num_images_per_prompt, batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
dtype=dtype, dtype=self.vae.dtype,
) )
height, width = control_image.shape[-2:] height, width = control_image.shape[-2:]
...@@ -763,7 +763,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -763,7 +763,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
batch_size=batch_size * num_images_per_prompt, batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
dtype=dtype, dtype=self.vae.dtype,
) )
height, width = control_image_.shape[-2:] height, width = control_image_.shape[-2:]
...@@ -840,12 +840,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -840,12 +840,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype) timestep = t.expand(latents.shape[0]).to(latents.dtype)
# handle guidance guidance = (
if self.transformer.config.guidance_embeds: torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
guidance = torch.tensor([guidance_scale], device=device) )
guidance = guidance.expand(latents.shape[0]) guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
else:
guidance = None
# controlnet # controlnet
controlnet_block_samples, controlnet_single_block_samples = self.controlnet( controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
...@@ -863,6 +861,11 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -863,6 +861,11 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
return_dict=False, return_dict=False,
) )
guidance = (
torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
noise_pred = self.transformer( noise_pred = self.transformer(
hidden_states=latents, hidden_states=latents,
timestep=timestep / 1000, timestep=timestep / 1000,
......
...@@ -767,7 +767,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -767,7 +767,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt, batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
dtype=dtype, dtype=self.vae.dtype,
) )
height, width = control_image.shape[-2:] height, width = control_image.shape[-2:]
...@@ -798,7 +798,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -798,7 +798,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt, batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
dtype=dtype, dtype=self.vae.dtype,
) )
height, width = control_image_.shape[-2:] height, width = control_image_.shape[-2:]
......
...@@ -899,7 +899,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -899,7 +899,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt, batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
dtype=dtype, dtype=self.vae.dtype,
) )
height, width = control_image.shape[-2:] height, width = control_image.shape[-2:]
...@@ -933,7 +933,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -933,7 +933,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
batch_size=batch_size * num_images_per_prompt, batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
dtype=dtype, dtype=self.vae.dtype,
) )
height, width = control_image_.shape[-2:] height, width = control_image_.shape[-2:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment