Unverified Commit bb1b0fa1 authored by Subho Ghosh, committed by GitHub

Feature flux controlnet img2img and inpaint pipeline (#9408)

* Implemented FLUX ControlNet support for the Img2Img pipeline
parent 8fcfb2a4
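This commit wires two new pipelines, FluxControlNetImg2ImgPipeline and FluxControlNetInpaintPipeline, into the public API, the docs, the lazy-import tables, the dummy objects, and the test suite. As orientation, a minimal usage sketch for the img2img variant follows; the ControlNet checkpoint id and image URLs are placeholders, and the call arguments simply mirror the ones exercised in the tests added here, so treat it as an illustration rather than the canonical example.

```python
import torch
from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetModel
from diffusers.utils import load_image

# Placeholder ControlNet repo id -- substitute a real FLUX ControlNet checkpoint.
controlnet = FluxControlNetModel.from_pretrained(
    "your-org/flux-controlnet-canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

init_image = load_image("https://example.com/init.png")        # image to transform (placeholder URL)
control_image = load_image("https://example.com/control.png")  # ControlNet conditioning (placeholder URL)

image = pipe(
    prompt="A photo of a cat",
    image=init_image,
    control_image=control_image,
    strength=0.8,                        # how strongly the init image is re-noised and repainted
    num_inference_steps=28,
    guidance_scale=5.0,
    controlnet_conditioning_scale=1.0,
).images[0]
image.save("flux_controlnet_img2img.png")
```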
...@@ -175,3 +175,16 @@ image.save("flux-fp8-dev.png")
[[autodoc]] FluxInpaintPipeline
- all
- __call__

## FluxControlNetInpaintPipeline

[[autodoc]] FluxControlNetInpaintPipeline
- all
- __call__

## FluxControlNetImg2ImgPipeline

[[autodoc]] FluxControlNetImg2ImgPipeline
- all
- __call__
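A corresponding sketch for the inpainting variant documented above, under the same assumptions (placeholder ControlNet checkpoint and image URLs; argument names taken from the inpaint test file added in this commit):

```python
import torch
from diffusers import FluxControlNetInpaintPipeline, FluxControlNetModel
from diffusers.utils import load_image

# Placeholder ControlNet repo id -- substitute a real FLUX ControlNet checkpoint.
controlnet = FluxControlNetModel.from_pretrained(
    "your-org/flux-controlnet-canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

init_image = load_image("https://example.com/init.png")    # placeholder URLs
mask_image = load_image("https://example.com/mask.png")    # white = repaint, black = keep
control_image = load_image("https://example.com/control.png")

image = pipe(
    prompt="A painting of a squirrel eating a burger",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    strength=0.8,
    num_inference_steps=28,
    guidance_scale=5.0,
    controlnet_conditioning_scale=1.0,
).images[0]
image.save("flux_controlnet_inpaint.png")
```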
...@@ -259,6 +259,8 @@ else:
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CycleDiffusionPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxControlNetPipeline", "FluxControlNetPipeline",
"FluxImg2ImgPipeline", "FluxImg2ImgPipeline",
"FluxInpaintPipeline", "FluxInpaintPipeline",
...@@ -708,6 +710,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -708,6 +710,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogVideoXPipeline, CogVideoXPipeline,
CogVideoXVideoToVideoPipeline, CogVideoXVideoToVideoPipeline,
CycleDiffusionPipeline, CycleDiffusionPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
...
...@@ -127,6 +127,8 @@ else:
]
_import_structure["flux"] = [
"FluxControlNetPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxImg2ImgPipeline", "FluxImg2ImgPipeline",
"FluxInpaintPipeline", "FluxInpaintPipeline",
"FluxPipeline", "FluxPipeline",
...@@ -505,7 +507,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -505,7 +507,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VersatileDiffusionTextToImagePipeline, VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline, VQDiffusionPipeline,
) )
from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline from .flux import (
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
...
...@@ -24,6 +24,8 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_flux"] = ["FluxPipeline"]
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
...@@ -35,6 +37,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -35,6 +37,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else: else:
from .pipeline_flux import FluxPipeline from .pipeline_flux import FluxPipeline
from .pipeline_flux_controlnet import FluxControlNetPipeline from .pipeline_flux_controlnet import FluxControlNetPipeline
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
else:
...
...@@ -332,6 +332,36 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxControlNetInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
import gc
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxControlNetImg2ImgPipeline,
FluxControlNetModel,
FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxControlNetImg2ImgPipeline
params = frozenset(
[
"prompt",
"image",
"control_image",
"height",
"width",
"strength",
"guidance_scale",
"controlnet_conditioning_scale",
"prompt_embeds",
"pooled_prompt_embeds",
]
)
batch_params = frozenset(["prompt", "image", "control_image"])
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=4,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
torch.manual_seed(0)
controlnet = FluxControlNetModel(
in_channels=4,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
"controlnet": controlnet,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
image = torch.randn(1, 3, 32, 32).to(device)
control_image = torch.randn(1, 3, 32, 32).to(device)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"control_image": control_image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"controlnet_conditioning_scale": 1.0,
"strength": 0.8,
"height": 32,
"width": 32,
"max_sequence_length": 48,
"output_type": "np",
}
return inputs
def test_flux_controlnet_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
assert max_diff > 1e-6
def test_flux_controlnet_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
prompt,
prompt_2=None,
device=torch_device,
max_sequence_length=inputs["max_sequence_length"],
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(
pipe.transformer
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
@slow
@require_torch_gpu
class FluxControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetImg2ImgPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
image = torch.randn(1, 3, 64, 64).to(device)
control_image = torch.randn(1, 3, 64, 64).to(device)
return {
"prompt": "A photo of a cat",
"image": image,
"control_image": control_image,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"controlnet_conditioning_scale": 1.0,
"strength": 0.8,
"output_type": "np",
"generator": generator,
}
@unittest.skip("We cannot run inference on this model with the current CI hardware")
def test_flux_controlnet_img2img_inference(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
[0.36132812, 0.30004883, 0.25830078],
[0.36669922, 0.31103516, 0.23754883],
[0.34814453, 0.29248047, 0.23583984],
[0.35791016, 0.30981445, 0.23999023],
[0.36328125, 0.31274414, 0.2607422],
[0.37304688, 0.32177734, 0.26171875],
[0.3671875, 0.31933594, 0.25756836],
[0.36035156, 0.31103516, 0.2578125],
[0.3857422, 0.33789062, 0.27563477],
[0.3701172, 0.31982422, 0.265625],
],
dtype=np.float32,
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4
...
import random
import unittest
import numpy as np
import torch
from transformers import (
AutoTokenizer,
CLIPTextConfig,
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
)
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxControlNetInpaintPipeline,
FluxControlNetModel,
FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
)
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class FluxControlNetInpaintPipelineTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxControlNetInpaintPipeline
params = frozenset(
[
"prompt",
"height",
"width",
"guidance_scale",
"prompt_embeds",
"pooled_prompt_embeds",
"image",
"mask_image",
"control_image",
"strength",
"num_inference_steps",
"controlnet_conditioning_scale",
]
)
batch_params = frozenset(["prompt", "image", "mask_image", "control_image"])
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=8,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=2,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
torch.manual_seed(0)
controlnet = FluxControlNetModel(
patch_size=1,
in_channels=8,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
"controlnet": controlnet,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
mask_image = torch.ones((1, 1, 32, 32)).to(device)
control_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"mask_image": mask_image,
"control_image": control_image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 32,
"width": 32,
"max_sequence_length": 48,
"strength": 0.8,
"output_type": "np",
}
return inputs
def test_flux_controlnet_inpaint_with_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_images_per_prompt"] = 2
output = pipe(**inputs)
images = output.images
assert images.shape == (2, 32, 32, 3)
def test_flux_controlnet_inpaint_with_controlnet_conditioning_scale(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output_default = pipe(**inputs)
image_default = output_default.images
inputs["controlnet_conditioning_scale"] = 0.5
output_scaled = pipe(**inputs)
image_scaled = output_scaled.images
# Ensure that changing the controlnet_conditioning_scale produces a different output
assert not np.allclose(image_default, image_scaled, atol=0.01)
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)