Unverified commit 3e9a28a8, authored by YiYi Xu and committed by GitHub
Browse files

[authored by @Anghellia] Add support of Xlabs Controlnets #9638 (#9687)



* Add support of Xlabs Controlnets


---------
Co-authored-by: Anzhella Pankratova <son0shad@gmail.com>
parent 2ffbb88f
...@@ -23,7 +23,7 @@ from ..loaders import PeftAdapterMixin ...@@ -23,7 +23,7 @@ from ..loaders import PeftAdapterMixin
from ..models.attention_processor import AttentionProcessor from ..models.attention_processor import AttentionProcessor
from ..models.modeling_utils import ModelMixin from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, zero_module from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from .modeling_outputs import Transformer2DModelOutput from .modeling_outputs import Transformer2DModelOutput
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
...@@ -55,6 +55,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -55,6 +55,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
guidance_embeds: bool = False, guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56], axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None, num_mode: int = None,
conditioning_embedding_channels: int = None,
): ):
super().__init__() super().__init__()
self.out_channels = in_channels self.out_channels = in_channels
...@@ -106,7 +107,14 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -106,7 +107,14 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
if self.union: if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) if conditioning_embedding_channels is not None:
self.input_hint_block = ControlNetConditioningEmbedding(
conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
)
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
else:
self.input_hint_block = None
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.gradient_checkpointing = False self.gradient_checkpointing = False
...@@ -269,6 +277,16 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -269,6 +277,16 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
) )
hidden_states = self.x_embedder(hidden_states) hidden_states = self.x_embedder(hidden_states)
if self.input_hint_block is not None:
controlnet_cond = self.input_hint_block(controlnet_cond)
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
height = height_pw // self.config.patch_size
width = width_pw // self.config.patch_size
controlnet_cond = controlnet_cond.reshape(
batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
)
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
# add # add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
......
...@@ -402,6 +402,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig ...@@ -402,6 +402,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
controlnet_block_samples=None, controlnet_block_samples=None,
controlnet_single_block_samples=None, controlnet_single_block_samples=None,
return_dict: bool = True, return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]: ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
""" """
The [`FluxTransformer2DModel`] forward method. The [`FluxTransformer2DModel`] forward method.
...@@ -508,7 +509,13 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig ...@@ -508,7 +509,13 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
if controlnet_block_samples is not None: if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control)) interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] # For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
......
...@@ -754,19 +754,22 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -754,19 +754,22 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
) )
height, width = control_image.shape[-2:] height, width = control_image.shape[-2:]
# vae encode # xlab controlnet has a input_hint_block and instantx controlnet does not
control_image = self.vae.encode(control_image).latent_dist.sample() controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor if self.controlnet.input_hint_block is None:
# vae encode
# pack control_image = self.vae.encode(control_image).latent_dist.sample()
height_control_image, width_control_image = control_image.shape[2:] control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
control_image = self._pack_latents(
control_image, # pack
batch_size * num_images_per_prompt, height_control_image, width_control_image = control_image.shape[2:]
num_channels_latents, control_image = self._pack_latents(
height_control_image, control_image,
width_control_image, batch_size * num_images_per_prompt,
) num_channels_latents,
height_control_image,
width_control_image,
)
# Here we ensure that `control_mode` has the same length as the control_image. # Here we ensure that `control_mode` has the same length as the control_image.
if control_mode is not None: if control_mode is not None:
...@@ -777,8 +780,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -777,8 +780,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
elif isinstance(self.controlnet, FluxMultiControlNetModel): elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = [] control_images = []
# xlab controlnet has a input_hint_block and instantx controlnet does not
for control_image_ in control_image: controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
for i, control_image_ in enumerate(control_image):
control_image_ = self.prepare_image( control_image_ = self.prepare_image(
image=control_image_, image=control_image_,
width=width, width=width,
...@@ -790,20 +794,20 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -790,20 +794,20 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
) )
height, width = control_image_.shape[-2:] height, width = control_image_.shape[-2:]
# vae encode if self.controlnet.nets[0].input_hint_block is None:
control_image_ = self.vae.encode(control_image_).latent_dist.sample() # vae encode
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image_.shape[2:] # pack
control_image_ = self._pack_latents( height_control_image, width_control_image = control_image_.shape[2:]
control_image_, control_image_ = self._pack_latents(
batch_size * num_images_per_prompt, control_image_,
num_channels_latents, batch_size * num_images_per_prompt,
height_control_image, num_channels_latents,
width_control_image, height_control_image,
) width_control_image,
)
control_images.append(control_image_) control_images.append(control_image_)
control_image = control_images control_image = control_images
...@@ -927,6 +931,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -927,6 +931,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
img_ids=latent_image_ids, img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs, joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False, return_dict=False,
controlnet_blocks_repeat=controlnet_blocks_repeat,
)[0] )[0]
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment