Unverified Commit 2e8d18e6 authored by YiYi Xu, committed by GitHub

[IP-Adapter] Support multiple IP-Adapters (#6573)




---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Alvaro Somoza <somoza.alvaro@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 03373de0
...@@ -506,22 +506,11 @@ import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image
noise_scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1
)
pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
scheduler=noise_scheduler,
).to("cuda")
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
pipeline.set_ip_adapter_scale(0.7)
...@@ -550,6 +539,66 @@ image = pipeline(
</div>
</div>
You can load multiple IP-Adapter models and use multiple reference images at the same time. In this example, we use the IP-Adapter-Plus-Face model to create a consistent character and also use the IP-Adapter-Plus model along with 10 images to create a coherent style in the image we generate.
```python
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"h94/IP-Adapter",
subfolder="models/image_encoder",
torch_dtype=torch.float16,
)
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
image_encoder=image_encoder,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter(
"h94/IP-Adapter",
subfolder="sdxl_models",
weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"]
)
pipeline.set_ip_adapter_scale([0.7, 0.3])
pipeline.enable_model_cpu_offload()
face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")
style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)]
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipeline(
prompt="wonderwoman",
ip_adapter_image=[style_images, face_image],
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=50, num_images_per_prompt=1,
generator=generator,
).images[0]
```
<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_style_grid.png" />
<figcaption class="mt-2 text-center text-sm text-gray-500">style input image</figcaption>
</div>
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">face input image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_multi_out.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
</div>
</div>
### LCM-Lora
You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights.
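The concrete LCM-Lora example is collapsed in this diff view. A minimal sketch of the pattern described above could look like the following; the LCM-LoRA repository id, the IP-Adapter weight file, and the generation settings are illustrative assumptions rather than values taken from this commit:

```python
import torch
from diffusers import StableDiffusionPipeline, LCMScheduler
from diffusers.utils import load_image

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# load the IP-Adapter weights first, then the LCM-LoRA weights
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)

ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")
image = pipeline(
    prompt="best quality, high quality",
    ip_adapter_image=ip_image,
    num_inference_steps=4,  # LCM-LoRA only needs a few denoising steps
    guidance_scale=1.0,     # and little to no classifier-free guidance
).images[0]
```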
......
...@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict, List, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
...@@ -45,9 +46,9 @@ class IPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
**kwargs,
):
"""
...@@ -87,6 +88,26 @@ class IPAdapterMixin:
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
...@@ -100,61 +121,68 @@ class IPAdapterMixin:
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
    pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
        model_file = _get_model_file(
            pretrained_model_name_or_path_or_dict,
            weights_name=weight_name,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
        )
        if weight_name.endswith(".safetensors"):
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(model_file, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(model_file, map_location="cpu")
    else:
        state_dict = pretrained_model_name_or_path_or_dict

    keys = list(state_dict.keys())
    if keys != ["image_proj", "ip_adapter"]:
        raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")

    state_dicts.append(state_dict)

    # load CLIP image encoder here if it has not been registered to the pipeline yet
    if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                pretrained_model_name_or_path_or_dict,
                subfolder=Path(subfolder, "image_encoder").as_posix(),
            ).to(self.device, dtype=self.dtype)
            self.image_encoder = image_encoder
            self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
        else:
            raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

    # create feature extractor if it has not been registered to the pipeline yet
    if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
        feature_extractor = CLIPImageProcessor()
        self.register_modules(feature_extractor=feature_extractor)

# load ip-adapter into unet
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet._load_ip_adapter_weights(state_dicts)
def set_ip_adapter_scale(self, scale):
if not isinstance(scale, list):
scale = [scale]
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
......
...@@ -25,7 +25,12 @@ import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..models.embeddings import (
ImageProjection,
IPAdapterFullImageProjection,
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
...@@ -763,7 +768,7 @@ class UNet2DConditionLoadersMixin:
image_projection.load_state_dict(updated_state_dict)
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
...@@ -771,20 +776,6 @@ class UNet2DConditionLoadersMixin:
IPAdapterAttnProcessor2_0,
)
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds = 4
elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 1
...@@ -798,6 +789,7 @@ class UNet2DConditionLoadersMixin:
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.config.block_out_channels[block_id]
if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
...@@ -807,6 +799,18 @@ class UNet2DConditionLoadersMixin:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds += [4]
elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
else:
# IP-Adapter Plus
num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
...@@ -815,16 +819,31 @@ class UNet2DConditionLoadersMixin:
).to(dtype=self.dtype, device=self.device)
value_dict = {}
for i, state_dict in enumerate(state_dicts):
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
attn_procs[name].load_state_dict(value_dict)
key_id += 2
return attn_procs
def _load_ip_adapter_weights(self, state_dicts):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
self.set_attn_processor(attn_procs)
# convert IP-Adapter Image Projection layers to diffusers
image_projection_layers = []
for state_dict in state_dicts:
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
image_projection_layer.to(device=self.device, dtype=self.dtype)
image_projection_layers.append(image_projection_layer)
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
...@@ -2087,29 +2087,41 @@ class LoRAAttnAddedKVProcessor(nn.Module):
class IPAdapterAttnProcessor(nn.Module):
r"""
Attention processor for Multiple IP-Adapters.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
self.num_tokens = num_tokens
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
self.to_v_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
def __call__(
self,
...@@ -2120,10 +2132,24 @@ class IPAdapterAttnProcessor(nn.Module):
temb=None,
scale=1.0,
):
if scale != 1.0:
logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.")
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
...@@ -2148,13 +2174,6 @@ class IPAdapterAttnProcessor(nn.Module):
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# split hidden states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
...@@ -2167,17 +2186,20 @@ class IPAdapterAttnProcessor(nn.Module):
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
    ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
    ip_key = to_k_ip(current_ip_hidden_states)
    ip_value = to_v_ip(current_ip_hidden_states)
    ip_key = attn.head_to_batch_dim(ip_key)
    ip_value = attn.head_to_batch_dim(ip_value)
    ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
    current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
    current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
    hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
...@@ -2204,13 +2226,13 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
...@@ -2220,11 +2242,23 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
self.num_tokens = num_tokens
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
self.to_v_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
def __call__(
self,
...@@ -2235,10 +2269,24 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
temb=None,
scale=1.0,
):
if scale != 1.0:
logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.")
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
...@@ -2268,13 +2316,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# split hidden states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
...@@ -2296,22 +2337,27 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
    ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
    ip_key = to_k_ip(current_ip_hidden_states)
    ip_value = to_v_ip(current_ip_hidden_states)
    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    current_ip_hidden_states = F.scaled_dot_product_attention(
        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
    hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
......
...@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from ..utils import USE_PEFT_BACKEND, deprecate
from .activations import get_activation
from .attention_processor import Attention
from .lora import LoRACompatibleLinear
...@@ -878,3 +878,38 @@ class IPAdapterPlusImageProjection(nn.Module):
latents = self.proj_out(latents)
return self.norm_out(latents)
class MultiIPAdapterImageProjection(nn.Module):
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
super().__init__()
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
def forward(self, image_embeds: List[torch.FloatTensor]):
projected_image_embeds = []
# currently, we accept `image_embeds` as
# 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
# 2. a list of `n` tensors where `n` is the number of ip-adapters, each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
if not isinstance(image_embeds, list):
deprecation_message = (
"You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
)
deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
image_embeds = [image_embeds.unsqueeze(1)]
if len(image_embeds) != len(self.image_projection_layers):
raise ValueError(
f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
)
for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
image_embed = image_projection_layer(image_embed)
image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
projected_image_embeds.append(image_embed)
return projected_image_embeds
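To make the reshaping contract of `MultiIPAdapterImageProjection` easier to follow, here is a minimal sketch (not part of this commit) that wraps two dummy projection layers and feeds one embedding tensor per IP-Adapter through them; `DummyProjection` and the chosen dimensions are assumptions used purely for illustration, while real pipelines pass `ImageProjection` or `IPAdapterPlusImageProjection` modules here:

```python
import torch
from torch import nn
# requires a diffusers version that already contains this change
from diffusers.models.embeddings import MultiIPAdapterImageProjection


class DummyProjection(nn.Module):
    """Stand-in for an IP-Adapter image projection layer (illustrative only)."""

    def __init__(self, embed_dim: int, cross_attention_dim: int):
        super().__init__()
        self.proj = nn.Linear(embed_dim, cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        # receives [batch_size * num_images, embed_dim], flattened by the wrapper
        return self.proj(image_embeds)


# two IP-Adapters -> two projection layers and a two-element `image_embeds` list
multi_proj = MultiIPAdapterImageProjection([DummyProjection(1280, 768), DummyProjection(1280, 768)])

face_embeds = torch.randn(2, 1, 1280)    # [batch_size, num_images, embed_dim]
style_embeds = torch.randn(2, 10, 1280)  # e.g. ten style reference images per prompt

projected = multi_proj([face_embeds, style_embeds])
print(projected[0].shape, projected[1].shape)  # torch.Size([2, 1, 768]) torch.Size([2, 10, 768])
```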
...@@ -1074,8 +1074,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds)
encoder_hidden_states = (encoder_hidden_states, image_embeds)
# 2. pre-process
sample = self.conv_in(sample)
......
...@@ -426,6 +426,35 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
...@@ -1002,12 +1031,9 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
    image_embeds = self.prepare_ip_adapter_image_embeds(
        ip_adapter_image, device, batch_size * num_videos_per_prompt
    )
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
......
...@@ -506,6 +506,35 @@ class StableDiffusionControlNetPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
...@@ -1083,12 +1112,9 @@ class StableDiffusionControlNetPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
    image_embeds = self.prepare_ip_adapter_image_embeds(
        ip_adapter_image, device, batch_size * num_images_per_prompt
    )
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
......
...@@ -499,6 +499,35 @@ class StableDiffusionControlNetImg2ImgPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
...@@ -1087,12 +1116,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
    image_embeds = self.prepare_ip_adapter_image_embeds(
        ip_adapter_image, device, batch_size * num_images_per_prompt
    )
# 4. Prepare image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
......
...@@ -624,6 +624,35 @@ class StableDiffusionControlNetInpaintPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
...@@ -1335,12 +1364,9 @@ class StableDiffusionControlNetInpaintPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
    image_embeds = self.prepare_ip_adapter_image_embeds(
        ip_adapter_image, device, batch_size * num_images_per_prompt
    )
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
......
...@@ -515,6 +515,35 @@ class StableDiffusionXLControlNetPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
...@@ -1182,12 +1211,9 @@ class StableDiffusionXLControlNetPipeline(
# 3.2 Encode ip_adapter_image
if ip_adapter_image is not None:
    image_embeds = self.prepare_ip_adapter_image_embeds(
        ip_adapter_image, device, batch_size * num_images_per_prompt
    )
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
......
...@@ -564,6 +564,35 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
...@@ -1340,12 +1369,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
# 3.2 Encode ip_adapter_image
if ip_adapter_image is not None:
    image_embeds = self.prepare_ip_adapter_image_embeds(
        ip_adapter_image, device, batch_size * num_images_per_prompt
    )
# 4. Prepare image and controlnet_conditioning_image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
......
...@@ -1280,8 +1280,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds)
encoder_hidden_states = (encoder_hidden_states, image_embeds)
# 2. pre-process
sample = self.conv_in(sample)
......
...@@ -477,6 +477,35 @@ class LatentConsistencyModelImg2ImgPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
...@@ -790,9 +819,8 @@ class LatentConsistencyModelImg2ImgPipeline(
# do_classifier_free_guidance = guidance_scale > 1.0
if ip_adapter_image is not None:
    image_embeds = self.prepare_ip_adapter_image_embeds(
        ip_adapter_image, device, batch_size * num_images_per_prompt
    )
# 3. Encode input prompt
......
...@@ -514,6 +514,35 @@ class StableDiffusionPipeline(
return image_embeds, uncond_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
...@@ -949,12 +977,9 @@ class StableDiffusionPipeline( ...@@ -949,12 +977,9 @@ class StableDiffusionPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None: if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True image_embeds = self.prepare_ip_adapter_image_embeds(
image_embeds, negative_image_embeds = self.encode_image( ip_adapter_image, device, batch_size * num_images_per_prompt
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
) )
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps # 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
......
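The helper's first branch is what keeps a single reference image working after this change: a bare image is wrapped into a one-element list, and only then is the length compared against the number of loaded image projection layers. Below is a small, self-contained restatement of just that check; the function name and placeholder values are illustrative, not part of the diffusers API.

```python
# Toy restatement of the validation at the top of prepare_ip_adapter_image_embeds.
def check_ip_adapter_inputs(ip_adapter_image, image_projection_layers):
    # A single (non-list) input is wrapped so one loaded adapter behaves as before.
    if not isinstance(ip_adapter_image, list):
        ip_adapter_image = [ip_adapter_image]
    if len(ip_adapter_image) != len(image_projection_layers):
        raise ValueError(
            f"`ip_adapter_image` must have same length as the number of IP Adapters. "
            f"Got {len(ip_adapter_image)} images and {len(image_projection_layers)} IP Adapters."
        )
    return ip_adapter_image

print(check_ip_adapter_inputs("face.png", ["proj_0"]))  # ['face.png']
# With two adapters loaded, a single image raises the ValueError shown in the diff:
# check_ip_adapter_inputs("face.png", ["proj_0", "proj_1"])
```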
...@@ -528,6 +528,35 @@ class StableDiffusionImg2ImgPipeline(
        return image_embeds, uncond_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        image_embeds = []
        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )
            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

            if self.do_classifier_free_guidance:
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                single_image_embeds = single_image_embeds.to(device)

            image_embeds.append(single_image_embeds)

        return image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:

...@@ -1001,12 +1030,9 @@ class StableDiffusionImg2ImgPipeline(
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
-            )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
+            )

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)
...
...@@ -600,6 +600,35 @@ class StableDiffusionInpaintPipeline(
        return image_embeds, uncond_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        image_embeds = []
        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )
            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

            if self.do_classifier_free_guidance:
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                single_image_embeds = single_image_embeds.to(device)

            image_embeds.append(single_image_embeds)

        return image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:

...@@ -1209,12 +1238,9 @@ class StableDiffusionInpaintPipeline(
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
-            )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
+            )

        # 4. set timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
...
...@@ -438,6 +438,35 @@ class StableDiffusionLDM3DPipeline(
        return image_embeds, uncond_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        image_embeds = []
        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )
            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

            if self.do_classifier_free_guidance:
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                single_image_embeds = single_image_embeds.to(device)

            image_embeds.append(single_image_embeds)

        return image_embeds

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None

...@@ -654,12 +683,9 @@ class StableDiffusionLDM3DPipeline(
        do_classifier_free_guidance = guidance_scale > 1.0

        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
-            )
-            if do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
+            )

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
...
...@@ -396,6 +396,35 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
        return image_embeds, uncond_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        image_embeds = []
        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )
            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

            if self.do_classifier_free_guidance:
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                single_image_embeds = single_image_embeds.to(device)

            image_embeds.append(single_image_embeds)

        return image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:

...@@ -669,12 +698,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
        do_classifier_free_guidance = guidance_scale > 1.0

        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
-            )
-            if do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
+            )

        # 3. Encode input prompt
        text_encoder_lora_scale = (
...
...@@ -549,6 +549,35 @@ class StableDiffusionXLPipeline(
        return image_embeds, uncond_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        image_embeds = []
        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )
            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

            if self.do_classifier_free_guidance:
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                single_image_embeds = single_image_embeds.to(device)

            image_embeds.append(single_image_embeds)

        return image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature

...@@ -1163,13 +1192,9 @@ class StableDiffusionXLPipeline(
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
-            )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
-            image_embeds = image_embeds.to(device)
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
+            )

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
...
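One detail shared by every copy of the helper above is the per-adapter `output_hidden_state` switch: the hidden-state flag is derived from the type of each projection layer rather than from a single pipeline-wide check as before. A toy illustration of that branch is sketched below; the two classes are local stand-ins (not imports from diffusers), and the comment about which embeddings each variant consumes reflects the general IP-Adapter vs. IP-Adapter-Plus distinction rather than anything asserted in this diff.

```python
# Toy illustration of the per-adapter output_hidden_state switch.
class ImageProjection:  # stand-in for the plain IP-Adapter projection layer
    pass

class PlusStyleImageProjection:  # stand-in for a "plus"/resampler-style projection layer
    pass

for layer in [ImageProjection(), PlusStyleImageProjection()]:
    output_hidden_state = not isinstance(layer, ImageProjection)
    # Plain projections use pooled image embeddings (False); plus-style
    # projections are fed the encoder's hidden states instead (True).
    print(type(layer).__name__, output_hidden_state)
```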