"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "910d8c842a3baefb0ea31cca6b605f1a72f046f4"
Unverified Commit 68bd6934 authored by Daniel Regado, committed by GitHub

IP-Adapter support for `StableDiffusion3ControlNetPipeline` (#10363)



* IP-Adapter support for `StableDiffusion3ControlNetPipeline`

* Update src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
Co-authored-by: hlky <hlky@hlky.ac>

---------
Co-authored-by: hlky <hlky@hlky.ac>
parent f4fdb3a0
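
Below is a minimal usage sketch of the feature this commit adds. The checkpoint IDs, image URLs, and the adapter repo are illustrative assumptions, not taken from this diff; only `load_ip_adapter` (available via the new `SD3IPAdapterMixin` base) and the new `ip_adapter_image` argument come from the change itself.

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Assumed checkpoints, for illustration only.
controlnet = SD3ControlNetModel.from_pretrained(
    "InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16
)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# New in this commit: the pipeline inherits SD3IPAdapterMixin,
# so an IP-Adapter can be loaded alongside the ControlNet.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")  # repo id is an assumption

control_image = load_image("https://example.com/canny_edges.png")  # placeholder URL
ip_image = load_image("https://example.com/style_reference.png")   # placeholder URL

image = pipe(
    prompt="a photo of a dog in a field",
    control_image=control_image,
    ip_adapter_image=ip_image,  # new argument added by this commit
).images[0]
image.save("out.png")
```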
@@ -17,14 +17,16 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
@@ -138,7 +140,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetPipeline(
+    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -174,10 +178,14 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             Provides additional conditioning to the `unet` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained Vision Model for IP Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor for IP Adapter.
     """

-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
     def __init__(
@@ -194,6 +202,8 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()

         if isinstance(controlnet, (list, tuple)):
@@ -223,6 +233,8 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             transformer=transformer,
             scheduler=scheduler,
             controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -727,6 +739,84 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
     def interrupt(self):
         return self._interrupt

+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
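
For callers that pass precomputed embeddings instead of an image, the docstring above requires the negative embedding to be stacked ahead of the positive one along the batch dimension whenever classifier-free guidance is on; the `chunk(2)` in the method then splits them back apart. A standalone sketch of that layout, with purely illustrative shapes (not taken from this diff):

```python
import torch

# Illustrative dimensions only: 1 image, 64 feature tokens, 1280-dim features.
positive_embeds = torch.randn(1, 64, 1280)
negative_embeds = torch.zeros_like(positive_embeds)  # unconditional branch for CFG

# With do_classifier_free_guidance=True, negative comes first so that
# ip_adapter_image_embeds.chunk(2) yields (negative, positive).
ip_adapter_image_embeds = torch.cat([negative_embeds, positive_embeds], dim=0)

neg, pos = ip_adapter_image_embeds.chunk(2)
assert torch.equal(neg, negative_embeds) and torch.equal(pos, positive_embeds)
```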
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
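
The warning above also spells out the workaround; as a short sketch, assuming `pipe` is an already-constructed pipeline:

```python
# Keep the image encoder out of sequential CPU offload, since models using
# torch.nn.MultiheadAttention may fail when offloaded this way.
pipe._exclude_from_cpu_offload.append("image_encoder")
pipe.enable_sequential_cpu_offload()
```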
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -754,6 +844,8 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -843,6 +935,12 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1040,7 +1138,22 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                 # SD35 official 8b controlnet does not use encoder_hidden_states
                 controlnet_encoder_hidden_states = None

-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
...
@@ -870,7 +870,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
             Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
-            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`torch.Tensor`, *optional*):
                 Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
                 emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
...
@@ -150,6 +150,8 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
"transformer": transformer, "transformer": transformer,
"vae": vae, "vae": vae,
"controlnet": controlnet, "controlnet": controlnet,
"image_encoder": None,
"feature_extractor": None,
} }
def get_dummy_inputs(self, device, seed=0): def get_dummy_inputs(self, device, seed=0):
...