Unverified commit cc051309 authored by 林金鹏, committed by GitHub

Support SD3 controlnet inpainting (#9099)



* add controlnet inpainting pipeline

* [SD3] add controlnet inpaint example

* update example and fix code style

* fix code style with ruff

* Update controlnet_sd3.md : add control inpaint pipeline

* Update docs/source/en/api/pipelines/controlnet_sd3.md
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update docs/source/en/api/pipelines/controlnet_sd3.md
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update docs/source/en/api/pipelines/controlnet_sd3.md
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update __init__.py : add sd3 control pipelines

* Update pipeline : add new param doc & check input reference.

* fix typo

* make style & make quality

* add unittest for sd3 controlnet inpaint

---------
Co-authored-by: 鹏徙 <linjinpeng.ljp@alibaba-inc.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
parent 15eb77bc
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

@@ -22,7 +22,16 @@ The abstract from the paper is:

*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
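
The "zero convolutions" mentioned in the abstract are ordinary convolution layers whose weights and biases start at zero, so the ControlNet branch contributes nothing to the frozen base model until it is trained. A minimal illustrative sketch (not code from this PR):

```python
import torch
import torch.nn as nn


def zero_conv(channels: int) -> nn.Conv2d:
    # A "zero convolution": a regular 1x1 Conv2d whose parameters are zero-initialized,
    # so its output is exactly zero before fine-tuning and the base model is undisturbed.
    conv = nn.Conv2d(channels, channels, kernel_size=1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv


block = zero_conv(16)
x = torch.randn(1, 16, 8, 8)
assert torch.all(block(x) == 0)  # contributes nothing until its weights move away from zero
```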
This controlnet code is mainly implemented by [The InstantX Team](https://huggingface.co/InstantX). The inpainting-related code was developed by [The Alimama Creative Team](https://huggingface.co/alimama-creative). You can find pre-trained checkpoints for SD3-ControlNet in the table below:

| ControlNet type | Developer | Link |
| -------- | ---------- | ---- |
| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Canny) |
| Pose | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Pose) |
| Tile | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Tile) |
| Inpainting | [The Alimama Creative Team](https://huggingface.co/alimama-creative) | [Link](https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting) |
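
The commit message mentions adding an inpaint example to the docs, but that part of the diff is collapsed here, so the following is only a hedged sketch of how the new pipeline could be called. The parameter names (`control_image`, `control_mask`, `controlnet_conditioning_scale`, `guidance_scale`) come from the unit test added in this commit; the SD3 base checkpoint ID, the step count, and the image/mask sources are assumptions, not taken from this diff.

```python
import torch
from diffusers import StableDiffusion3ControlNetInpaintingPipeline
from diffusers.models import SD3ControlNetModel
from diffusers.utils import load_image

# Inpainting ControlNet from the table above (trained by the Alimama Creative Team).
controlnet = SD3ControlNetModel.from_pretrained(
    "alimama-creative/SD3-Controlnet-Inpainting", torch_dtype=torch.float16
)

# Assumed SD3 base checkpoint; swap in whichever SD3 weights you use.
pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# Placeholder sources: the image to repaint and a mask marking the region to fill.
image = load_image("path/or/url/to/image.png")
mask = load_image("path/or/url/to/mask.png")

result = pipe(
    prompt="A painting of a squirrel eating a burger",
    control_image=image,
    control_mask=mask,
    controlnet_conditioning_scale=0.95,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
result.save("sd3_controlnet_inpaint.png")
```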
<Tip>

@@ -35,5 +44,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)

- all
- __call__
## StableDiffusion3ControlNetInpaintingPipeline
[[autodoc]] pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet_inpainting.StableDiffusion3ControlNetInpaintingPipeline
- all
- __call__

## StableDiffusion3PipelineOutput

[[autodoc]] pipelines.stable_diffusion_3.pipeline_output.StableDiffusion3PipelineOutput
@@ -308,6 +308,7 @@ else:
            "StableCascadeCombinedPipeline",
            "StableCascadeDecoderPipeline",
            "StableCascadePriorPipeline",
            "StableDiffusion3ControlNetInpaintingPipeline",
            "StableDiffusion3ControlNetPipeline",
            "StableDiffusion3Img2ImgPipeline",
            "StableDiffusion3InpaintPipeline",
...
@@ -55,6 +55,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        extra_conditioning_channels: int = 0,
    ):
        super().__init__()
        default_out_channels = in_channels
@@ -98,7 +99,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels + extra_conditioning_channels,
            embed_dim=self.inner_dim,
            pos_embed_type=None,
        )
...
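For orientation on the `extra_conditioning_channels` change above: the extra channel (the inpainting mask) simply widens the ControlNet's patch-embedding input. A small hedged sketch using the dummy configuration from the unit test added in this commit; the `pos_embed.proj` attribute path is an assumption about how `PatchEmbed` exposes its projection layer.

```python
from diffusers.models import SD3ControlNetModel

# Dummy-sized ControlNet mirroring the new unit test, with one extra mask channel.
controlnet = SD3ControlNetModel(
    sample_size=32,
    patch_size=1,
    in_channels=8,
    num_layers=1,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,
    pooled_projection_dim=64,
    out_channels=8,
    extra_conditioning_channels=1,
)

# The patch embedding now projects in_channels + extra_conditioning_channels
# (8 latent channels + 1 mask channel) instead of in_channels alone.
print(controlnet.pos_embed.proj.in_channels)  # expected: 9
```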
@@ -173,6 +173,7 @@ else:
    _import_structure["controlnet_sd3"].extend(
        [
            "StableDiffusion3ControlNetPipeline",
            "StableDiffusion3ControlNetInpaintingPipeline",
        ]
    )
    _import_structure["deepfloyd_if"] = [
@@ -465,9 +466,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from .controlnet_hunyuandit import (
            HunyuanDiTControlNetPipeline,
        )
        from .controlnet_sd3 import StableDiffusion3ControlNetInpaintingPipeline, StableDiffusion3ControlNetPipeline
        from .controlnet_xs import (
            StableDiffusionControlNetXSPipeline,
            StableDiffusionXLControlNetXSPipeline,
        )
...
@@ -23,6 +23,9 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"]
    _import_structure["pipeline_stable_diffusion_3_controlnet_inpainting"] = [
        "StableDiffusion3ControlNetInpaintingPipeline"
    ]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
@@ -33,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline
        from .pipeline_stable_diffusion_3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline

    try:
        if not (is_transformers_available() and is_flax_available()):
...
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    SD3Transformer2DModel,
    StableDiffusion3ControlNetInpaintingPipeline,
)
from diffusers.models import SD3ControlNetModel
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    torch_device,
)
from diffusers.utils.torch_utils import randn_tensor

from ..test_pipelines_common import PipelineTesterMixin


enable_full_determinism()

class StableDiffusion3ControlInpaintNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    pipeline_class = StableDiffusion3ControlNetInpaintingPipeline
    params = frozenset(
        [
            "prompt",
            "height",
            "width",
            "guidance_scale",
            "negative_prompt",
            "prompt_embeds",
            "negative_prompt_embeds",
        ]
    )
    batch_params = frozenset(["prompt", "negative_prompt"])
    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = SD3Transformer2DModel(
            sample_size=32,
            patch_size=1,
            in_channels=8,
            num_layers=4,
            attention_head_dim=8,
            num_attention_heads=4,
            joint_attention_dim=32,
            caption_projection_dim=32,
            pooled_projection_dim=64,
            out_channels=8,
        )

        torch.manual_seed(0)
        controlnet = SD3ControlNetModel(
            sample_size=32,
            patch_size=1,
            in_channels=8,
            num_layers=1,
            attention_head_dim=8,
            num_attention_heads=4,
            joint_attention_dim=32,
            caption_projection_dim=32,
            pooled_projection_dim=64,
            out_channels=8,
            extra_conditioning_channels=1,
        )

        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )

        torch.manual_seed(0)
        text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)

        torch.manual_seed(0)
        text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)

        torch.manual_seed(0)
        text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=8,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "text_encoder_3": text_encoder_3,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "tokenizer_3": tokenizer_3,
            "transformer": transformer,
            "vae": vae,
            "controlnet": controlnet,
        }
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        control_image = randn_tensor(
            (1, 3, 32, 32),
            generator=generator,
            device=torch.device(device),
            dtype=torch.float16,
        )

        control_mask = randn_tensor(
            (1, 1, 32, 32),
            generator=generator,
            device=torch.device(device),
            dtype=torch.float16,
        )

        controlnet_conditioning_scale = 0.95

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 7.0,
            "output_type": "np",
            "control_image": control_image,
            "control_mask": control_mask,
            "controlnet_conditioning_scale": controlnet_conditioning_scale,
        }

        return inputs
    def test_controlnet_inpaint_sd3(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusion3ControlNetInpaintingPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device, dtype=torch.float16)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output = sd_pipe(**inputs)
        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)

        expected_slice = np.array(
            [0.51708984, 0.7421875, 0.4580078, 0.6435547, 0.65625, 0.43603516, 0.5151367, 0.65722656, 0.60839844]
        )

        assert (
            np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"

    @unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
    def test_xformers_attention_forwardGenerator_pass(self):
        pass