Unverified Commit c1e6a32a authored by 王奇勋, committed by GitHub

[Flux] Support Union ControlNet (#9175)



* refactor
---------
Co-authored-by: haofanwang <haofanwang.ai@gmail.com>
parent 77b21628
...@@ -226,6 +226,8 @@
- sections:
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_flux
title: FluxControlNetModel
- local: api/models/controlnet_hunyuandit
title: HunyuanDiT2DControlNetModel
- local: api/models/controlnet_sd3
...@@ -320,6 +322,8 @@
title: Consistency Models
- local: api/pipelines/controlnet
title: ControlNet
- local: api/pipelines/controlnet_flux
title: ControlNet with Flux.1
- local: api/pipelines/controlnet_hunyuandit
title: ControlNet with Hunyuan-DiT
- local: api/pipelines/controlnet_sd3
...
<!--Copyright 2024 The HuggingFace Team and The InstantX Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# FluxControlNetModel
FluxControlNetModel is an implementation of ControlNet for Flux.1.
The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, e.g., edges, depth, segmentation, human pose, etc., with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
## Loading from the original format
By default, [`FluxControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`].
```py
from diffusers import FluxControlNetPipeline
from diffusers.models import FluxControlNetModel, FluxMultiControlNetModel

# Load a single ControlNet and pass it to the pipeline directly
controlnet = FluxControlNetModel.from_pretrained("InstantX/FLUX.1-dev-Controlnet-Canny")
pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet)

# Or wrap one or more ControlNets in FluxMultiControlNetModel to pass several conditions
controlnet = FluxControlNetModel.from_pretrained("InstantX/FLUX.1-dev-Controlnet-Canny")
controlnet = FluxMultiControlNetModel([controlnet])
pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet)
```
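The Union checkpoint added in this PR exposes several condition types through a single model, selected with a mode index at inference time. Below is a minimal loading sketch, assuming the `InstantX/FLUX.1-dev-Controlnet-Union` repository listed in the pipeline docs; wrapping it in [`FluxMultiControlNetModel`] lets the pipeline accept lists of control images and modes.

```py
from diffusers import FluxControlNetPipeline
from diffusers.models import FluxControlNetModel, FluxMultiControlNetModel

# Load the Union ControlNet and wrap it so several control images/modes can be passed at once
controlnet = FluxControlNetModel.from_pretrained("InstantX/FLUX.1-dev-Controlnet-Union")
controlnet = FluxMultiControlNetModel([controlnet])
pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet)
```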
## FluxControlNetModel
[[autodoc]] FluxControlNetModel
## FluxControlNetOutput
[[autodoc]] models.controlnet_flux.FluxControlNetOutput
\ No newline at end of file
<!--Copyright 2024 The HuggingFace Team and The InstantX Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet with Flux.1
FluxControlNetPipeline is an implementation of ControlNet for Flux.1.
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
With a ControlNet model, you can provide an additional control image to condition and control Flux.1 generation. For example, if you provide a depth map, the ControlNet model generates an image that preserves the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, e.g., edges, depth, segmentation, human pose, etc., with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This ControlNet code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for Flux-ControlNet in the table below:
| ControlNet type | Developer | Link |
| -------- | ---------- | ---- |
| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny) |
| Depth | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Depth) |
| Union | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union) |
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
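As a quick illustration of the Union checkpoint added in this PR, the sketch below conditions Flux.1 on a Canny edge map through `control_mode`. The control image URL and the mode index for Canny are placeholder assumptions for illustration; check the checkpoint's model card for the actual mode mapping.

```py
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

base_model = "black-forest-labs/FLUX.1-dev"
controlnet_model = "InstantX/FLUX.1-dev-Controlnet-Union"

controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")

# A Canny edge map used as the control image (placeholder URL, substitute your own)
control_image = load_image("https://example.com/canny_edge_map.png")

image = pipe(
    prompt="A cozy cabin in a snowy forest at dusk, cinematic lighting",
    control_image=control_image,
    control_mode=0,  # assumed index for the Canny mode; see the checkpoint card for the mapping
    controlnet_conditioning_scale=0.5,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_union_canny.png")
```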
## FluxControlNetPipeline
[[autodoc]] FluxControlNetPipeline
- all
- __call__
## FluxPipelineOutput
[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
\ No newline at end of file
...@@ -554,6 +554,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetXSAdapter,
DiTTransformer2DModel,
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
...
...@@ -35,7 +35,7 @@ if is_torch_available():
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
...@@ -88,7 +88,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel,
)
from .controlnet import ControlNetModel
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from .controlnet_sparsectrl import SparseControlNetModel
...
...@@ -54,6 +54,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
):
super().__init__()
self.out_channels = in_channels
...@@ -101,6 +102,10 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
for _ in range(len(self.single_transformer_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.union = num_mode is not None
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.gradient_checkpointing = False
...@@ -173,8 +178,8 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
def from_transformer(
cls,
transformer,
num_layers: int = 4,
num_single_layers: int = 10,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
load_weights_from_transformer=True,
...@@ -205,6 +210,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor = None,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
...@@ -221,6 +227,12 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
controlnet_mode (`torch.Tensor`):
The mode tensor of shape `(batch_size, 1)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
...@@ -272,6 +284,15 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if self.union:
# ControlNet-Union: a learned embedding tells the model which condition type is supplied
if controlnet_mode is None:
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
# embed the mode index and prepend it to the text token sequence
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
# txt_ids gains one leading position to account for the prepended mode token
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
...@@ -367,7 +388,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
#
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
controlnet_single_block_samples = (
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
...@@ -384,3 +404,114 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
class FluxMultiControlNetModel(ModelMixin):
r"""
`FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
compatible with `FluxControlNetModel`.
Args:
controlnets (`List[FluxControlNetModel]`):
Provides additional conditioning to the transformer during the denoising process. You must set multiple
`FluxControlNetModel` as a list.
"""
def __init__(self, controlnets):
super().__init__()
self.nets = nn.ModuleList(controlnets)
def forward(
self,
hidden_states: torch.FloatTensor,
controlnet_cond: List[torch.Tensor],
controlnet_mode: List[torch.Tensor],
conditioning_scale: List[float],
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[FluxControlNetOutput, Tuple]:
# ControlNet-Union with multiple conditions
# only one ControlNet is loaded, saving memory
if len(self.nets) == 1 and self.nets[0].union:
controlnet = self.nets[0]
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
# Regular Multi-ControlNets
# all ControlNets are loaded into memory
else:
for i, (image, mode, scale, controlnet) in enumerate(
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
return control_block_samples, control_single_block_samples
...@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
...@@ -27,7 +27,7 @@ from transformers import (
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
...@@ -61,7 +61,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import FluxControlNetPipeline
>>> from diffusers import FluxControlNetModel
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
>>> pipe = FluxControlNetPipeline.from_pretrained(
... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
...@@ -195,7 +195,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
controlnet: Union[
FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
],
):
super().__init__()
...@@ -571,6 +573,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_image: PipelineImageInput = None,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
...@@ -611,6 +614,20 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
images must be passed as a list such that each element of the list can be correctly batched for input
to a single ControlNet.
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
the corresponding scale as a list.
control_mode (`int` or `List[int]`, *optional*, defaults to `None`):
The control mode when applying ControlNet-Union.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
...@@ -730,6 +747,55 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
width_control_image,
)
# set control mode
if control_mode is not None:
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
for control_image_ in control_image:
control_image_ = self.prepare_image(
image=control_image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
)
height, width = control_image_.shape[-2:]
# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
control_images.append(control_image_)
control_image = control_images
# set control mode
control_mode_ = []
if isinstance(control_mode, list):
for cmode in control_mode:
if cmode is None:
control_mode_.append(-1)
else:
control_mode_.append(cmode)
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
...@@ -785,6 +851,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=controlnet_conditioning_scale,
timestep=timestep / 1000,
guidance=guidance,
...
...@@ -197,6 +197,21 @@ class FluxControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class FluxMultiControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class FluxTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
...