Unverified Commit 31912484 authored by Daniel Regado's avatar Daniel Regado Committed by GitHub

[WIP] SD3.5 IP-Adapter Pipeline Integration (#9987)




* Added support for single IPAdapter on SD3.5 pipeline



---------
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 648d968c
@@ -238,6 +238,8 @@
title: Textual Inversion
- local: api/loaders/unet
title: UNet
- local: api/loaders/transformer_sd3
title: SD3Transformer2D
- local: api/loaders/peft
title: PEFT
title: Loaders
@@ -86,6 +86,8 @@ An attention processor is a class for applying different types of attention mech
[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0
## SD3IPAdapterJointAttnProcessor2_0
[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0
## JointAttnProcessor2_0
[[autodoc]] models.attention_processor.JointAttnProcessor2_0
@@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
## SD3IPAdapterMixin
[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
- all
- is_ip_adapter_active
## IPAdapterMaskProcessor
[[autodoc]] image_processor.IPAdapterMaskProcessor
\ No newline at end of file
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# SD3Transformer2D
This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder, or into both a text encoder and the [`SD3Transformer2DModel`], use the [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
<Tip>
To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
</Tip>
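The following is a minimal sketch of loading IP-Adapter weights directly into a bare [`SD3Transformer2DModel`], assuming a locally downloaded `ip-adapter.safetensors` checkpoint whose keys are grouped under the `image_proj` and `ip_adapter` prefixes (in most cases you would instead call the pipeline-level [`~loaders.SD3IPAdapterMixin.load_ip_adapter`]):

```python
import torch
from safetensors import safe_open
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", subfolder="transformer", torch_dtype=torch.float16
)

# Regroup the flat checkpoint keys into the {"image_proj": ..., "ip_adapter": ...} layout
# expected by `_load_ip_adapter_weights` ("ip-adapter.safetensors" is a placeholder path).
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open("ip-adapter.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        prefix, name = key.split(".", maxsplit=1)
        state_dict[prefix][name] = f.get_tensor(key)

transformer._load_ip_adapter_weights(state_dict)
```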
## SD3Transformer2DLoadersMixin
[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
- all
- _load_ip_adapter_weights
\ No newline at end of file
@@ -59,9 +59,76 @@ image.save("sd3_hello_world.png")
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
## Image Prompting with IP-Adapters
An IP-Adapter lets you prompt SD3 with images in addition to the text prompt. This is especially useful for complex concepts that are difficult to articulate through text alone but easy to convey with a reference image. To load and use an IP-Adapter, you need:
- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder.
- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`.
- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection.
IP-Adapters are trained for a specific model architecture, so they also work with fine-tuned variants of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] method to adjust how strongly the output aligns with the image prompt: the higher the value, the more closely the model follows the image prompt. The default value of 0.5 is typically a good balance, giving the text and image prompts equal weight.
```python
import torch
from PIL import Image
from diffusers import StableDiffusion3Pipeline
from transformers import SiglipVisionModel, SiglipImageProcessor
image_encoder_id = "google/siglip-so400m-patch14-384"
ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter"
feature_extractor = SiglipImageProcessor.from_pretrained(
image_encoder_id,
torch_dtype=torch.float16
)
image_encoder = SiglipVisionModel.from_pretrained(
image_encoder_id,
torch_dtype=torch.float16
).to( "cuda")
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3.5-large",
torch_dtype=torch.float16,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
).to("cuda")
pipe.load_ip_adapter(ip_adapter_id)
pipe.set_ip_adapter_scale(0.6)
ref_img = Image.open("image.jpg").convert('RGB')
image = pipe(
width=1024,
height=1024,
prompt="a cat",
negative_prompt="lowres, low quality, worst quality",
num_inference_steps=24,
guidance_scale=5.0,
ip_adapter_image=ref_img
).images[0]
image.save("result.jpg")
```
<div class="justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd3_ip_adapter_example.png"/>
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "a cat"</figcaption>
</div>
<Tip>
Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
</Tip>
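If you want to reuse the same image prompt across several generations, you can precompute its embeddings once and pass them through `ip_adapter_image_embeds` instead of re-encoding the image on every call. The sketch below reuses `pipe` and `ref_img` from the example above; `prepare_ip_adapter_image_embeds` returns the embeddings together with the zero negative embedding needed for classifier-free guidance.

```python
# Precompute the image embeddings once (includes the negative embedding for CFG).
image_embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=ref_img,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

# Reuse the cached embeddings across prompts instead of re-encoding the image each time.
for prompt in ["a cat", "a dog"]:
    image = pipe(
        prompt=prompt,
        negative_prompt="lowres, low quality, worst quality",
        num_inference_steps=24,
        guidance_scale=5.0,
        ip_adapter_image_embeds=image_embeds,
    ).images[0]
    image.save(f"result_{prompt.replace(' ', '_')}.jpg")
```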
## Memory Optimisations for SD3
SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
### Running Inference with Model Offloading
@@ -56,6 +56,7 @@ _import_structure = {}
if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"] _import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available(): if is_transformers_available():
...@@ -74,7 +75,10 @@ if is_torch_available(): ...@@ -74,7 +75,10 @@ if is_torch_available():
"SanaLoraLoaderMixin", "SanaLoraLoaderMixin",
] ]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"] _import_structure["ip_adapter"] = [
"IPAdapterMixin",
"SD3IPAdapterMixin",
]
_import_structure["peft"] = ["PeftAdapterMixin"] _import_structure["peft"] = ["PeftAdapterMixin"]
...@@ -82,11 +86,15 @@ _import_structure["peft"] = ["PeftAdapterMixin"] ...@@ -82,11 +86,15 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available(): if is_torch_available():
from .single_file_model import FromOriginalModelMixin from .single_file_model import FromOriginalModelMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
from .ip_adapter import (
IPAdapterMixin,
SD3IPAdapterMixin,
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
@@ -33,15 +33,18 @@ from .unet_loader_utils import _maybe_expand_lora_scales
if is_transformers_available():
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
JointAttnProcessor2_0,
SD3IPAdapterJointAttnProcessor2_0,
)
logger = logging.get_logger(__name__)
@@ -348,3 +351,235 @@ class IPAdapterMixin:
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
class SD3IPAdapterMixin:
"""Mixin for handling StableDiffusion 3 IP Adapters."""
@property
def is_ip_adapter_active(self) -> bool:
"""Checks if IP-Adapter is loaded and scale > 0.
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
the image context is irrelevant.
Returns:
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
"""
scales = [
attn_proc.scale
for attn_proc in self.transformer.attn_processors.values()
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
]
return len(scales) > 0 and any(scale > 0 for scale in scales)
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
weight_name: str = "ip-adapter.safetensors",
subfolder: Optional[str] = None,
image_encoder_folder: Optional[str] = "image_encoder",
**kwargs,
) -> None:
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
weight_name (`str`, defaults to "ip-adapter.safetensors"):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
subfolder (`str`, *optional*):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
`image_encoder_folder="different_subfolder/image_encoder"`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
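Example:
```python
>>> # A minimal sketch; assumes `pipeline` is a `StableDiffusion3Pipeline` created with a SigLIP
>>> # `image_encoder` and `feature_extractor` (see the pipeline documentation for a full example).
>>> pipeline.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
>>> pipeline.set_ip_adapter_scale(0.6)
```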
"""
# Load the main state dict first
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_folder is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
if image_encoder_folder.count("/") == 0:
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
else:
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
# Common args for loading the image encoder and image processor
kwargs = {
"low_cpu_mem_usage": low_cpu_mem_usage,
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
self.register_modules(
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
self.device, dtype=self.dtype
),
)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
"Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
)
# Load IP-Adapter into transformer
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: float) -> None:
"""
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
conditioned on the image prompt, and 0.0 means it is only conditioned on the text prompt. Lowering this value
encourages the model to produce more diverse images, but they may not be as aligned with the image prompt.
Example:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.set_ip_adapter_scale(0.6)
>>> ...
```
Args:
scale (float):
IP-Adapter scale to be set.
"""
for attn_processor in self.transformer.attn_processors.values():
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
attn_processor.scale = scale
def unload_ip_adapter(self) -> None:
"""
Unloads the IP Adapter weights.
Example:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# Remove image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=None)
# Remove feature extractor
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=None)
# Remove image projection
self.transformer.image_proj = None
# Restore original attention processors layers
attn_procs = {
name: (
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
)
for name, value in self.transformer.attn_processors.items()
}
self.transformer.set_attn_processor(attn_procs)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
class SD3Transformer2DLoadersMixin:
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
Args:
state_dict (`Dict`):
State dict with keys "ip_adapter", which contains parameters for attention processors, and
"image_proj", which contains parameters for image projection net.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# IP-Adapter cross attention parameters
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
# Dict where key is transformer layer index, value is attention processor's state dict
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
for key, weights in state_dict["ip_adapter"].items():
idx, name = key.split(".", maxsplit=1)
layer_state_dict[int(idx)][name] = weights
# Create IP-Adapter attention processor
attn_procs = {}
for idx, name in enumerate(self.attn_processors.keys()):
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size,
ip_hidden_states_dim=ip_hidden_states_dim,
head_dim=self.config.attention_head_dim,
timesteps_emb_dim=timesteps_emb_dim,
).to(self.device, dtype=self.dtype)
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
else:
load_model_dict_into_meta(
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
)
self.set_attn_processor(attn_procs)
# Image projection parameters
embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
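# The projection blocks use a head dimension of 64 (the default `dim_head`), so the head count follows from the width of `to_q`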
heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
num_queries = state_dict["image_proj"]["latents"].shape[1]
timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
# Image projection
self.image_proj = IPAdapterTimeImageProjection(
embed_dim=embed_dim,
output_dim=output_dim,
hidden_dim=hidden_dim,
heads=heads,
num_queries=num_queries,
timestep_in_dim=timestep_in_dim,
).to(device=self.device, dtype=self.dtype)
if not low_cpu_mem_usage:
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
else:
load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
@@ -5243,6 +5243,177 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
return hidden_states
class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
"""
Attention processor for IP-Adapter, typically used for processing SD3-like self-attention projections with
additional image-based information and timestep embeddings.
Args:
hidden_size (`int`):
The number of hidden channels.
ip_hidden_states_dim (`int`):
The image feature dimension.
head_dim (`int`):
The number of head channels.
timesteps_emb_dim (`int`, defaults to 1280):
The number of input channels for timestep embedding.
scale (`float`, defaults to 0.5):
IP-Adapter scale.
"""
def __init__(
self,
hidden_size: int,
ip_hidden_states_dim: int,
head_dim: int,
timesteps_emb_dim: int = 1280,
scale: float = 0.5,
):
super().__init__()
# To prevent circular import
from .normalization import AdaLayerNorm, RMSNorm
self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
self.norm_q = RMSNorm(head_dim, 1e-6)
self.norm_k = RMSNorm(head_dim, 1e-6)
self.norm_ip_k = RMSNorm(head_dim, 1e-6)
self.scale = scale
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
ip_hidden_states: torch.FloatTensor = None,
temb: torch.FloatTensor = None,
) -> torch.FloatTensor:
"""
Perform the attention computation, integrating image features (if provided) and timestep embeddings.
If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
Args:
attn (`Attention`):
Attention instance.
hidden_states (`torch.FloatTensor`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor`, *optional*):
The encoder hidden states.
attention_mask (`torch.FloatTensor`, *optional*):
Attention mask.
ip_hidden_states (`torch.FloatTensor`, *optional*):
Image embeddings.
temb (`torch.FloatTensor`, *optional*):
Timestep embeddings.
Returns:
`torch.FloatTensor`: Output hidden states.
"""
residual = hidden_states
batch_size = hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
img_query = query
img_key = key
img_value = value
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP Adapter
if self.scale != 0 and ip_hidden_states is not None:
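# Image-token queries attend to the original image keys/values concatenated with the IP-Adapter keys/values; the result is added to the joint attention output below, scaled by `self.scale`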
# Norm image features
norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
# To k and v
ip_key = self.to_k_ip(norm_ip_hidden_states)
ip_value = self.to_v_ip(norm_ip_hidden_states)
# Reshape
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Norm
query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
ip_key = self.norm_ip_k(ip_key)
# cat img
key = torch.cat([img_key, ip_key], dim=2)
value = torch.cat([img_value, ip_value], dim=2)
ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + ip_hidden_states * self.scale
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -5772,6 +5943,7 @@ AttentionProcessor = Union[
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
SD3IPAdapterJointAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
LoRAAttnProcessor,
@@ -2396,6 +2396,187 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
return out
class IPAdapterTimeImageProjectionBlock(nn.Module):
"""Block for IPAdapterTimeImageProjection.
Args:
hidden_dim (`int`, defaults to 1280):
The number of hidden channels.
dim_head (`int`, defaults to 64):
The number of head channels.
heads (`int`, defaults to 20):
Parallel attention heads.
ffn_ratio (`int`, defaults to 4):
The expansion ratio of feedforward network hidden layer channels.
"""
def __init__(
self,
hidden_dim: int = 1280,
dim_head: int = 64,
heads: int = 20,
ffn_ratio: int = 4,
) -> None:
super().__init__()
from .attention import FeedForward
self.ln0 = nn.LayerNorm(hidden_dim)
self.ln1 = nn.LayerNorm(hidden_dim)
self.attn = Attention(
query_dim=hidden_dim,
cross_attention_dim=hidden_dim,
dim_head=dim_head,
heads=heads,
bias=False,
out_bias=False,
)
self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
# AdaLayerNorm
self.adaln_silu = nn.SiLU()
self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
self.adaln_norm = nn.LayerNorm(hidden_dim)
# Set attention scale and fuse KV
self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
self.attn.fuse_projections()
self.attn.to_k = None
self.attn.to_v = None
def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
x (`torch.Tensor`):
Image features.
latents (`torch.Tensor`):
Latent features.
timestep_emb (`torch.Tensor`):
Timestep embedding.
Returns:
`torch.Tensor`: Output latent features.
"""
# Shift and scale for AdaLayerNorm
emb = self.adaln_proj(self.adaln_silu(timestep_emb))
shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
# Fused Attention
residual = latents
x = self.ln0(x)
latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
batch_size = latents.shape[0]
query = self.attn.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.attn.heads
query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
latents = weight @ value
latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
latents = self.attn.to_out[0](latents)
latents = self.attn.to_out[1](latents)
latents = latents + residual
# FeedForward
residual = latents
latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
return self.ff(latents) + residual
# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
class IPAdapterTimeImageProjection(nn.Module):
"""Resampler of SD3 IP-Adapter with timestep embedding.
Args:
embed_dim (`int`, defaults to 1152):
The feature dimension.
output_dim (`int`, defaults to 2432):
The number of output channels.
hidden_dim (`int`, defaults to 1280):
The number of hidden channels.
depth (`int`, defaults to 4):
The number of blocks.
dim_head (`int`, defaults to 64):
The number of head channels.
heads (`int`, defaults to 20):
Parallel attention heads.
num_queries (`int`, defaults to 64):
The number of queries.
ffn_ratio (`int`, defaults to 4):
The expansion ratio of feedforward network hidden layer channels.
timestep_in_dim (`int`, defaults to 320):
The number of input channels for timestep embedding.
timestep_flip_sin_to_cos (`bool`, defaults to True):
Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
timestep_freq_shift (`int`, defaults to 0):
Controls the timestep delta between frequencies between dimensions.
"""
def __init__(
self,
embed_dim: int = 1152,
output_dim: int = 2432,
hidden_dim: int = 1280,
depth: int = 4,
dim_head: int = 64,
heads: int = 20,
num_queries: int = 64,
ffn_ratio: int = 4,
timestep_in_dim: int = 320,
timestep_flip_sin_to_cos: bool = True,
timestep_freq_shift: int = 0,
) -> None:
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
self.proj_in = nn.Linear(embed_dim, hidden_dim)
self.proj_out = nn.Linear(hidden_dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList(
[IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
)
self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass.
Args:
x (`torch.Tensor`):
Image features.
timestep (`torch.Tensor`):
Timestep in denoising process.
Returns:
`Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
"""
timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
timestep_emb = self.time_embedding(timestep_emb)
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
x = x + timestep_emb[:, None]
for block in self.layers:
latents = block(x, latents, timestep_emb)
latents = self.proj_out(latents)
latents = self.norm_out(latents)
return latents, timestep_emb
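# Illustrative shape sketch for the resampler above under its default configuration, assuming a
# batch of 2 images and 729 SigLIP patch tokens of width 1152:
#   latents, temb = IPAdapterTimeImageProjection()(torch.randn(2, 729, 1152), torch.tensor([1.0, 1.0]))
#   latents: (2, 64, 2432)  -> (batch, num_queries, output_dim)
#   temb:    (2, 1280)      -> (batch, hidden_dim)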
class MultiIPAdapterImageProjection(nn.Module):
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
super().__init__()
@@ -18,7 +18,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
Attention,
@@ -103,7 +103,9 @@ class SD3SingleTransformerBlock(nn.Module):
return hidden_states
class SD3Transformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
""" """
The Transformer model introduced in Stable Diffusion 3. The Transformer model introduced in Stable Diffusion 3.
...@@ -349,8 +351,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi ...@@ -349,8 +351,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
Input `hidden_states`. Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
from the embeddings of input conditions. Embeddings projected from the embeddings of input conditions.
timestep (`torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states (`list` of `torch.Tensor`):
@@ -390,6 +392,12 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
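# IP-Adapter: project the image embeddings with the current timestep and hand the result to the attention processors through `joint_attention_kwargs`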
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
for index_block, block in enumerate(self.transformer_blocks):
# Skip specified layers
is_skip = True if skip_layers is not None and index_block in skip_layers else False
# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,14 +17,16 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -142,7 +144,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -174,10 +176,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`PreTrainedModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`BaseImageProcessor`, *optional*):
Image processor for IP Adapter.
""" """
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
_optional_components = [] _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
@@ -191,6 +197,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
):
super().__init__()
@@ -204,6 +212,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
tokenizer_3=tokenizer_3,
transformer=transformer,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -683,6 +693,83 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
def interrupt(self):
return self._interrupt
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
"""Encodes the given image into a feature representation using a pre-trained image encoder.
Args:
image (`PipelineImageInput`):
Input image to be encoded.
device (`torch.device`):
Torch device.
Returns:
`torch.Tensor`: The encoded image feature representation.
"""
"""
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=self.dtype)
return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
) -> torch.Tensor:
"""Prepares image embeddings for use in the IP-Adapter.
Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
Args:
ip_adapter_image (`PipelineImageInput`, *optional*):
The input image to extract features from for IP-Adapter.
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
Precomputed image embeddings.
device (`torch.device`, *optional*):
Torch device.
num_images_per_prompt (`int`, defaults to 1):
Number of images that should be generated per prompt.
do_classifier_free_guidance (`bool`, defaults to True):
Whether to use classifier free guidance or not.
"""
device = device or self._execution_device
if ip_adapter_image_embeds is not None:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
else:
single_image_embeds = ip_adapter_image_embeds
elif ip_adapter_image is not None:
single_image_embeds = self.encode_image(ip_adapter_image, device)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.zeros_like(single_image_embeds)
else:
raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
return image_embeds.to(device=device)
def enable_sequential_cpu_offload(self, *args, **kwargs):
if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
logger.warning(
"`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
"`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
"`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
)
super().enable_sequential_cpu_offload(*args, **kwargs)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -705,6 +792,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -713,9 +802,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
skip_guidance_layers: List[int] = None,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_start: float = 0.01,
mu: Optional[float] = None,
):
r"""
@@ -781,6 +870,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
`True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -938,7 +1032,22 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Prepare image embeddings
if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
else:
self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
@@ -103,6 +103,8 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
"tokenizer_3": tokenizer_3,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):