Unverified Commit 0a08d419 authored by takuoko, committed by GitHub

[Feature] Support IP-Adapter Plus (#5915)



* Support IP-Adapter Plus

* fix format

* restore before black format

* restore before black format

* generic

* Refactor PerceiverAttention

* format

* fix test and refactor PerceiverAttention

* generic encode_image

* keep attention implementation

* merge tests

* encode_image backward compatible

* code quality

* fix controlnet inpaint pipeline

* refactor FFN

* refactor FFN

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent e185084a
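For context, a minimal usage sketch of the feature this commit adds. The repository id, subfolder, and weight filename below are illustrative assumptions, not part of the diff; the loader now detects an IP-Adapter Plus checkpoint automatically from its state-dict layout.

# Hedged usage sketch; checkpoint names are assumptions, not from the diff.
import torch
from diffusers import StableDiffusionPipeline
from transformers import CLIPVisionModelWithProjection

# IP-Adapter Plus conditions on CLIP vision features, so the pipeline needs
# an image_encoder (assumed checkpoint location).
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16
)
# A "plus" checkpoint stores a Resampler ("latents", "proj_in", ...) instead of
# the plain linear ImageProjection; load_ip_adapter branches on that.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")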
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from contextlib import nullcontext
 from typing import Callable, Dict, List, Optional, Union
@@ -21,7 +21,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
-from ..models.embeddings import ImageProjection
+from ..models.embeddings import ImageProjection, Resampler
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     DIFFUSERS_CACHE,
@@ -672,6 +672,17 @@ class UNet2DConditionLoadersMixin:
             IPAdapterAttnProcessor2_0,
         )

+        if "proj.weight" in state_dict["image_proj"]:
+            # IP-Adapter
+            num_image_text_embeds = 4
+        else:
+            # IP-Adapter Plus
+            num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
+
+        # Set encoder_hid_proj after loading ip_adapter weights,
+        # because `Resampler` also has `attn_processors`.
+        self.encoder_hid_proj = None
+
         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
         key_id = 1
@@ -695,7 +706,10 @@ class UNet2DConditionLoadersMixin:
                 IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
             )
             attn_procs[name] = attn_processor_class(
-                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                hidden_size=hidden_size,
+                cross_attention_dim=cross_attention_dim,
+                scale=1.0,
+                num_tokens=num_image_text_embeds,
             ).to(dtype=self.dtype, device=self.device)

             value_dict = {}
@@ -708,26 +722,76 @@ class UNet2DConditionLoadersMixin:
         self.set_attn_processor(attn_procs)
         # create image projection layers.
-        clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
-        cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
-
-        image_projection = ImageProjection(
-            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
-        )
-        image_projection.to(dtype=self.dtype, device=self.device)
-
-        # load image projection layer weights
-        image_proj_state_dict = {}
-        image_proj_state_dict.update(
-            {
-                "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
-                "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
-                "norm.weight": state_dict["image_proj"]["norm.weight"],
-                "norm.bias": state_dict["image_proj"]["norm.bias"],
-            }
-        )
-
-        image_projection.load_state_dict(image_proj_state_dict)
+        if "proj.weight" in state_dict["image_proj"]:
+            # IP-Adapter
+            clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
+
+            image_projection = ImageProjection(
+                cross_attention_dim=cross_attention_dim,
+                image_embed_dim=clip_embeddings_dim,
+                num_image_text_embeds=num_image_text_embeds,
+            )
+            image_projection.to(dtype=self.dtype, device=self.device)
+
+            # load image projection layer weights
+            image_proj_state_dict = {}
+            image_proj_state_dict.update(
+                {
+                    "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
+                    "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
+                    "norm.weight": state_dict["image_proj"]["norm.weight"],
+                    "norm.bias": state_dict["image_proj"]["norm.bias"],
+                }
+            )
+
+            image_projection.load_state_dict(image_proj_state_dict)
+
+        else:
+            # IP-Adapter Plus
+            embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
+            output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
+            hidden_dims = state_dict["image_proj"]["latents"].shape[2]
+            heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
+
+            image_projection = Resampler(
+                embed_dims=embed_dims,
+                output_dims=output_dims,
+                hidden_dims=hidden_dims,
+                heads=heads,
+                num_queries=num_image_text_embeds,
+            )
+
+            image_proj_state_dict = state_dict["image_proj"]
+
+            new_sd = OrderedDict()
+
+            for k, v in image_proj_state_dict.items():
+                if "0.to" in k:
+                    k = k.replace("0.to", "2.to")
+                elif "1.0.weight" in k:
+                    k = k.replace("1.0.weight", "3.0.weight")
+                elif "1.0.bias" in k:
+                    k = k.replace("1.0.bias", "3.0.bias")
+                elif "1.1.weight" in k:
+                    k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
+                elif "1.3.weight" in k:
+                    k = k.replace("1.3.weight", "3.1.net.2.weight")
+
+                if "norm1" in k:
+                    new_sd[k.replace("0.norm1", "0")] = v
+                elif "norm2" in k:
+                    new_sd[k.replace("0.norm2", "1")] = v
+                elif "to_kv" in k:
+                    v_chunk = v.chunk(2, dim=0)
+                    new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
+                    new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_out" in k:
+                    new_sd[k.replace("to_out", "to_out.0")] = v
+                else:
+                    new_sd[k] = v
+
+            image_projection.load_state_dict(new_sd)
+            del image_proj_state_dict

         self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
         self.config.encoder_hid_dim_type = "ip_image_proj"
...
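A note on the remapping loop above: the original IP-Adapter Plus checkpoint fuses the key/value projection and numbers submodules differently from the diffusers Resampler, so keys are rewritten and `to_kv` weights split before loading. A simplified, hypothetical sketch of the idea (only a subset of the branches, with one representative key; the helper name is illustrative):

# Simplified sketch of the key remapping above (subset of branches).
from collections import OrderedDict

import torch

def remap_resampler_keys(sd):
    new_sd = OrderedDict()
    for k, v in sd.items():
        if "0.to" in k:
            # attention weights move from submodule index 0 to index 2
            k = k.replace("0.to", "2.to")
        if "to_kv" in k:
            # the fused key/value projection is split into to_k / to_v
            v_k, v_v = v.chunk(2, dim=0)
            new_sd[k.replace("to_kv", "to_k")] = v_k
            new_sd[k.replace("to_kv", "to_v")] = v_v
        elif "to_out" in k:
            # diffusers' Attention wraps the output projection in a ModuleList
            new_sd[k.replace("to_out", "to_out.0")] = v
        else:
            new_sd[k] = v
    return new_sd

sd = {"layers.0.0.to_kv.weight": torch.zeros(2 * 1280, 1280)}
print(list(remap_resampler_keys(sd)))
# ['layers.0.2.to_k.weight', 'layers.0.2.to_v.weight']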
@@ -34,6 +34,7 @@ if is_torch_available():
     _import_structure["controlnet"] = ["ControlNetModel"]
     _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
+    _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["prior_transformer"] = ["PriorTransformer"]
     _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformer_2d"] = ["Transformer2DModel"]
@@ -63,6 +64,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .consistency_decoder_vae import ConsistencyDecoderVAE
     from .controlnet import ControlNetModel
     from .dual_transformer_2d import DualTransformer2DModel
+    from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
     from .prior_transformer import PriorTransformer
     from .t5_film_transformer import T5FilmDecoder
...
@@ -55,11 +55,12 @@ class GELU(nn.Module):
         dim_in (`int`): The number of channels in the input.
         dim_out (`int`): The number of channels in the output.
         approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """

-    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out)
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
         self.approximate = approximate

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
@@ -81,13 +82,14 @@ class GEGLU(nn.Module):
     Parameters:
         dim_in (`int`): The number of channels in the input.
         dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """

-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
         linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

-        self.proj = linear_cls(dim_in, dim_out * 2)
+        self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
@@ -109,11 +111,12 @@ class ApproximateGELU(nn.Module):
     Parameters:
         dim_in (`int`): The number of channels in the input.
         dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """

-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out)
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
...
@@ -501,6 +501,7 @@ class FeedForward(nn.Module):
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """

     def __init__(
@@ -511,6 +512,7 @@ class FeedForward(nn.Module):
         dropout: float = 0.0,
         activation_fn: str = "geglu",
         final_dropout: bool = False,
+        bias: bool = True,
     ):
         super().__init__()
         inner_dim = int(dim * mult)
@@ -518,13 +520,13 @@ class FeedForward(nn.Module):
         linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

         if activation_fn == "gelu":
-            act_fn = GELU(dim, inner_dim)
+            act_fn = GELU(dim, inner_dim, bias=bias)
         if activation_fn == "gelu-approximate":
-            act_fn = GELU(dim, inner_dim, approximate="tanh")
+            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
         elif activation_fn == "geglu":
-            act_fn = GEGLU(dim, inner_dim)
+            act_fn = GEGLU(dim, inner_dim, bias=bias)
         elif activation_fn == "geglu-approximate":
-            act_fn = ApproximateGELU(dim, inner_dim)
+            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)

         self.net = nn.ModuleList([])
         # project in
@@ -532,7 +534,7 @@ class FeedForward(nn.Module):
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(linear_cls(inner_dim, dim_out))
+        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
...
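The `bias` flag threaded through the activations and `FeedForward` above exists so the Resampler in the next file can build bias-free feedforward blocks matching the original IP-Adapter Plus weights. A quick sketch of the effect (import path taken from this repository):

# Sketch: bias=False propagates to every linear layer inside FeedForward.
import torch
from diffusers.models.attention import FeedForward

ff = FeedForward(dim=1280, mult=4, activation_fn="gelu", bias=False)
assert ff.net[0].proj.bias is None  # GELU's input projection
assert ff.net[2].bias is None       # output projection
print(ff(torch.randn(1, 8, 1280)).shape)  # torch.Size([1, 8, 1280])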
@@ -20,6 +20,7 @@ from torch import nn
 from ..utils import USE_PEFT_BACKEND
 from .activations import get_activation
+from .attention_processor import Attention
 from .lora import LoRACompatibleLinear
@@ -790,3 +791,91 @@ class CaptionProjection(nn.Module):
         hidden_states = self.act_1(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         return hidden_states
+
+
+class Resampler(nn.Module):
+    """Resampler of IP-Adapter Plus.
+
+    Args:
+    ----
+        embed_dims (int): The feature dimension. Defaults to 768.
+        output_dims (int): The number of output channels, which must match
+            `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int): The number of hidden channels. Defaults to 1280.
+        depth (int): The number of blocks. Defaults to 4.
+        dim_head (int): The number of head channels. Defaults to 64.
+        heads (int): Parallel attention heads. Defaults to 16.
+        num_queries (int): The number of queries. Defaults to 8.
+        ffn_ratio (float): The expansion ratio of the feedforward network's
+            hidden layer channels. Defaults to 4.
+    """
+
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        output_dims: int = 1024,
+        hidden_dims: int = 1280,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward  # Lazy import to avoid circular import
+
+        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
+
+        self.proj_in = nn.Linear(embed_dims, hidden_dims)
+
+        self.proj_out = nn.Linear(hidden_dims, output_dims)
+        self.norm_out = nn.LayerNorm(output_dims)
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        nn.LayerNorm(hidden_dims),
+                        nn.LayerNorm(hidden_dims),
+                        Attention(
+                            query_dim=hidden_dims,
+                            dim_head=dim_head,
+                            heads=heads,
+                            out_bias=False,
+                        ),
+                        nn.Sequential(
+                            nn.LayerNorm(hidden_dims),
+                            FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+                        ),
+                    ]
+                )
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+        ----
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+        -------
+            torch.Tensor: Output tensor.
+        """
+        latents = self.latents.repeat(x.size(0), 1, 1)
+
+        x = self.proj_in(x)
+
+        for ln0, ln1, attn, ff in self.layers:
+            residual = latents
+
+            encoder_hidden_states = ln0(x)
+            latents = ln1(latents)
+            encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+            latents = attn(latents, encoder_hidden_states) + residual
+            latents = ff(latents) + latents
+
+        latents = self.proj_out(latents)
+        return self.norm_out(latents)
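A shape sketch of the Resampler just added. The constructor arguments below mirror what the loader derives for an SD 1.5 "plus" checkpoint (see the first file of this diff); treat the concrete numbers as assumptions:

# Sketch: the Resampler compresses per-patch CLIP hidden states into a fixed
# set of query tokens in the UNet's cross-attention dimension.
import torch
from diffusers.models.embeddings import Resampler  # added in this commit

resampler = Resampler(
    embed_dims=1280,   # CLIP ViT-H hidden size (assumption)
    output_dims=768,   # SD 1.5 unet.config.cross_attention_dim (assumption)
    hidden_dims=1280,
    heads=20,          # to_q.weight.shape[0] // 64, per the loader above
    num_queries=16,    # state_dict["image_proj"]["latents"].shape[1]
)
clip_hidden_states = torch.randn(2, 257, 1280)  # batch, patches (+CLS), channels
print(resampler(clip_hidden_states).shape)      # torch.Size([2, 16, 768])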
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMR
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -494,18 +494,29 @@ class AltDiffusionPipeline(
         return prompt_embeds, negative_prompt_embeds

-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype

         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values

         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)

-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds

     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -875,7 +886,10 @@ class AltDiffusionPipeline(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
...
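Every pipeline below receives the same two edits via the `# Copied from` mechanism: `encode_image` gains an `output_hidden_states` branch (IP-Adapter Plus conditions on the penultimate CLIP hidden states rather than the pooled embedding), and each call site picks the branch from the type of `unet.encoder_hid_proj`. A condensed sketch of that dispatch (the surrounding pipeline object is assumed; the helper name is illustrative):

# Sketch of the dispatch the call sites below perform; `pipe` is assumed to be
# a pipeline with an IP-Adapter already loaded.
from diffusers.models import ImageProjection

def get_ip_adapter_embeds(pipe, image, device, num_images_per_prompt=1):
    # Plain IP-Adapter keeps an ImageProjection in encoder_hid_proj; the Plus
    # variant keeps a Resampler, which consumes hidden states instead.
    output_hidden_state = not isinstance(pipe.unet.encoder_hid_proj, ImageProjection)
    return pipe.encode_image(image, device, num_images_per_prompt, output_hidden_state)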
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMR
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -505,18 +505,29 @@ class AltDiffusionImg2ImgPipeline(
         return prompt_embeds, negative_prompt_embeds

-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype

         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values

         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)

-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds

     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -919,7 +930,10 @@ class AltDiffusionImg2ImgPipeline(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
...
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unet_motion_model import MotionAdapter
 from ...schedulers import (
@@ -320,18 +320,29 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype

         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values

         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)

-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
     def decode_latents(self, latents):
@@ -651,7 +662,10 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
+            )
             if do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
...
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -479,18 +479,29 @@ class StableDiffusionControlNetPipeline(
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype

         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values

         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)

-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
@@ -1067,7 +1078,10 @@ class StableDiffusionControlNetPipeline(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
...
@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -597,18 +597,29 @@ class StableDiffusionControlNetInpaintPipeline(
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype

         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values

         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)

-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
@@ -1284,7 +1295,10 @@ class StableDiffusionControlNetInpaintPipeline(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
...
@@ -37,7 +37,7 @@ from ...loaders import (
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -489,18 +489,29 @@ class StableDiffusionXLControlNetPipeline(
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype

         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values

         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)

-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -1169,7 +1180,10 @@ class StableDiffusionXLControlNetPipeline(
         # 3.2 Encode ip_adapter_image
         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
...
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -489,18 +489,29 @@ class StableDiffusionPipeline(
         return prompt_embeds, negative_prompt_embeds

-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype

         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values

         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)

-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds

     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -871,7 +882,10 @@ class StableDiffusionPipeline(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
...
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -503,18 +503,29 @@ class StableDiffusionImg2ImgPipeline(
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype

         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values

         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)

-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
@@ -923,7 +934,10 @@ class StableDiffusionImg2ImgPipeline(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
...
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
+from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -574,18 +574,29 @@ class StableDiffusionInpaintPipeline(
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype

         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values

         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)

-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
@@ -1103,7 +1114,10 @@ class StableDiffusionInpaintPipeline(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
...
...@@ -31,7 +31,7 @@ from ...loaders import ( ...@@ -31,7 +31,7 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin, TextualInversionLoaderMixin,
) )
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
@@ -524,18 +524,29 @@ class StableDiffusionXLPipeline(
        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        dtype = next(self.image_encoder.parameters()).dtype
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values
        image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
@@ -1087,7 +1098,10 @@ class StableDiffusionXLPipeline(
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
        if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])
                image_embeds = image_embeds.to(device)
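The two return shapes above are the crux of the change: plain IP-Adapter conditions on the pooled `image_embeds`, while IP-Adapter Plus feeds the penultimate layer of CLIP `hidden_states` to a Perceiver-style resampler, and each pipeline picks between them by inspecting the type of `unet.encoder_hid_proj`. A minimal sketch of that dispatch, assuming a hypothetical `pipe` object with the usual `unet` and `image_encoder` components:

import torch
from diffusers.models.embeddings import ImageProjection


def pick_image_features(pipe, image: torch.Tensor) -> torch.Tensor:
    # Plain IP-Adapter installs an ImageProjection as encoder_hid_proj;
    # anything else (here: the Resampler) wants CLIP hidden states instead.
    output_hidden_state = not isinstance(pipe.unet.encoder_hid_proj, ImageProjection)
    if output_hidden_state:
        # penultimate layer: (batch, num_patches + 1, hidden_dim)
        return pipe.image_encoder(image, output_hidden_states=True).hidden_states[-2]
    # pooled projection: (batch, projection_dim)
    return pipe.image_encoder(image).image_embeds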
...
@@ -32,7 +32,7 @@ from ...loaders import (
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
@@ -741,18 +741,29 @@ class StableDiffusionXLImg2ImgPipeline(
        return latents

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        dtype = next(self.image_encoder.parameters()).dtype
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values
        image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds

    def _get_add_time_ids(
        self,
@@ -1259,7 +1270,10 @@ class StableDiffusionXLImg2ImgPipeline(
        add_time_ids = add_time_ids.to(device)
        if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])
                image_embeds = image_embeds.to(device)
...
@@ -33,7 +33,7 @@ from ...loaders import (
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
@@ -462,18 +462,29 @@ class StableDiffusionXLInpaintPipeline(
        self.vae.disable_tiling()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        dtype = next(self.image_encoder.parameters()).dtype
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values
        image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
    def encode_prompt(
@@ -1568,7 +1579,10 @@ class StableDiffusionXLInpaintPipeline(
        add_time_ids = add_time_ids.to(device)
        if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])
                image_embeds = image_embeds.to(device)
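All three SDXL pipelines share the classifier-free-guidance bookkeeping shown above: the unconditional (zero-image) features are concatenated in front of the conditional ones, matching the order in which the doubled batch is split further down the denoising loop. A self-contained sketch of that convention, with purely illustrative shapes and a stand-in for the UNet output:

import torch

# Illustrative shapes only: 2 prompts, 4 image tokens, feature dim 16.
image_embeds = torch.randn(2, 4, 16)                    # conditional image features
negative_image_embeds = torch.zeros_like(image_embeds)  # unconditional stand-in

# Convention: unconditional half first, conditional half second.
added_cond = torch.cat([negative_image_embeds, image_embeds])
assert added_cond.shape[0] == 4  # doubled batch

# Downstream the model output is split in the same order; `noise_pred`
# here is a stand-in tensor of matching batch size.
noise_pred = torch.randn(4, 4, 16)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)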
...
@@ -18,6 +18,7 @@ import gc
import os
import tempfile
import unittest
+from collections import OrderedDict

import torch
from parameterized import parameterized
@@ -25,7 +26,7 @@ from pytest import mark
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
-from diffusers.models.embeddings import ImageProjection
+from diffusers.models.embeddings import ImageProjection, Resampler
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
@@ -97,6 +98,85 @@ def create_ip_adapter_state_dict(model):
    return ip_state_dict
+def create_ip_adapter_plus_state_dict(model):
+    # "ip_adapter" (cross-attention weights)
+    ip_cross_attn_state_dict = {}
+    key_id = 1
+
+    for name in model.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = model.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = model.config.block_out_channels[block_id]
+        if cross_attention_dim is not None:
+            sd = IPAdapterAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+            ).state_dict()
+            ip_cross_attn_state_dict.update(
+                {
+                    f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
+                    f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
+                }
+            )
+            key_id += 2
+
+    # "image_proj" (Resampler weights, remapped from the IP-Adapter Plus
+    # checkpoint layout to the diffusers module names)
+    cross_attention_dim = model.config["cross_attention_dim"]
+    image_projection = Resampler(
+        embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
+    )
+
+    ip_image_projection_state_dict = OrderedDict()
+    for k, v in image_projection.state_dict().items():
+        if "2.to" in k:
+            k = k.replace("2.to", "0.to")
+        elif "3.0.weight" in k:
+            k = k.replace("3.0.weight", "1.0.weight")
+        elif "3.0.bias" in k:
+            k = k.replace("3.0.bias", "1.0.bias")
+        elif "3.1.net.0.proj.weight" in k:
+            k = k.replace("3.1.net.0.proj.weight", "1.1.weight")
+        elif "3.net.2.weight" in k:
+            k = k.replace("3.net.2.weight", "1.3.weight")
+        elif "layers.0.0" in k:
+            k = k.replace("layers.0.0", "layers.0.0.norm1")
+        elif "layers.0.1" in k:
+            k = k.replace("layers.0.1", "layers.0.0.norm2")
+        elif "layers.1.0" in k:
+            k = k.replace("layers.1.0", "layers.1.0.norm1")
+        elif "layers.1.1" in k:
+            k = k.replace("layers.1.1", "layers.1.0.norm2")
+        elif "layers.2.0" in k:
+            k = k.replace("layers.2.0", "layers.2.0.norm1")
+        elif "layers.2.1" in k:
+            k = k.replace("layers.2.1", "layers.2.0.norm2")
+
+        if "norm_cross" in k:
+            ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v
+        elif "layer_norm" in k:
+            ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v
+        elif "to_k" in k:
+            # Resampler fuses k/v into a single to_kv projection; reuse the
+            # to_k tensor for both halves (to_v keys are skipped below).
+            ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0)
+        elif "to_v" in k:
+            continue
+        elif "to_out.0" in k:
+            ip_image_projection_state_dict[k.replace("to_out.0", "to_out")] = v
+        else:
+            ip_image_projection_state_dict[k] = v
+
+    ip_state_dict = {}
+    ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
+
+    return ip_state_dict
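Most of the remapping above is bookkeeping between the checkpoint's layer names and the diffusers `Resampler` module names; the one non-obvious step is `to_k` becoming `to_kv`, because the attention used here fuses the key and value projections into a single matrix (the test simply reuses the `to_k` tensor for both halves and drops the `to_v` keys). A small sketch of why a fused projection is equivalent to two separate ones, with illustrative names and sizes:

import torch

dim, inner = 16, 8
x = torch.randn(2, 5, dim)     # (batch, tokens, dim)
w_k = torch.randn(inner, dim)  # separate key projection weight
w_v = torch.randn(inner, dim)  # separate value projection weight

# Fused layout: stack the two weight matrices along the output dimension.
w_kv = torch.cat([w_k, w_v], dim=0)  # (2 * inner, dim)

k_fused, v_fused = (x @ w_kv.T).chunk(2, dim=-1)
assert torch.allclose(k_fused, x @ w_k.T, atol=1e-6)
assert torch.allclose(v_fused, x @ w_v.T, atol=1e-6)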
def create_custom_diffusion_layers(model, mock_weights: bool = True):
    train_kv = True
    train_q_out = True
@@ -724,6 +804,56 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
        assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
+    def test_ip_adapter_plus(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["attention_head_dim"] = (8, 16)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        # forward pass without ip-adapter
+        with torch.no_grad():
+            sample1 = model(**inputs_dict).sample
+
+        # update inputs_dict for ip-adapter
+        batch_size = inputs_dict["encoder_hidden_states"].shape[0]
+        image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
+        inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
+
+        # make ip_adapter_1 and ip_adapter_2
+        ip_adapter_1 = create_ip_adapter_plus_state_dict(model)
+
+        image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()}
+        cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()}
+        ip_adapter_2 = {}
+        ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
+
+        # forward pass with ip_adapter_1
+        model._load_ip_adapter_weights(ip_adapter_1)
+        assert model.config.encoder_hid_dim_type == "ip_image_proj"
+        assert model.encoder_hid_proj is not None
+        assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
+            "IPAdapterAttnProcessor",
+            "IPAdapterAttnProcessor2_0",
+        )
+        with torch.no_grad():
+            sample2 = model(**inputs_dict).sample
+
+        # forward pass with ip_adapter_2
+        model._load_ip_adapter_weights(ip_adapter_2)
+        with torch.no_grad():
+            sample3 = model(**inputs_dict).sample
+
+        # forward pass with ip_adapter_1 again
+        model._load_ip_adapter_weights(ip_adapter_1)
+        with torch.no_grad():
+            sample4 = model(**inputs_dict).sample
+
+        assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
+        assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
+        assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
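The load/swap/reload pattern exercised here carries over to the public loader API: each `load_ip_adapter` call replaces whatever adapter is currently installed, so switching between a plain and a Plus checkpoint is just two calls. A sketch, assuming the `h94/IP-Adapter` repository layout used by the nightly tests below:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Plain IP-Adapter (ImageProjection, 4 image tokens)...
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
# ...then swap to IP-Adapter Plus (Resampler); the previous weights are replaced.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")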
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
...
@@ -116,7 +116,17 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()
-        expected_slice = np.array([0.8047, 0.8774, 0.9248, 0.9155, 0.9814, 1.0, 0.9678, 1.0, 1.0])
+        expected_slice = np.array([0.8110, 0.8843, 0.9326, 0.9224, 0.9878, 1.0, 0.9736, 1.0, 1.0])
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+        inputs = self.get_dummy_inputs()
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        expected_slice = np.array([0.3013, 0.2615, 0.2202, 0.2722, 0.2510, 0.2023, 0.2498, 0.2415, 0.2139])
        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
@@ -132,7 +142,17 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()
-        expected_slice = np.array([0.2307, 0.2341, 0.2305, 0.24, 0.2268, 0.25, 0.2322, 0.2588, 0.2935])
+        expected_slice = np.array([0.2253, 0.2251, 0.2219, 0.2312, 0.2236, 0.2434, 0.2275, 0.2575, 0.2805])
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+        inputs = self.get_dummy_inputs(for_image_to_image=True)
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        expected_slice = np.array([0.3550, 0.2600, 0.2520, 0.2412, 0.1870, 0.3831, 0.1453, 0.1880, 0.5371])
        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
@@ -148,7 +168,17 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()
-        expected_slice = np.array([0.2705, 0.2395, 0.2209, 0.2312, 0.2102, 0.2104, 0.2178, 0.2065, 0.1997])
+        expected_slice = np.array([0.2700, 0.2388, 0.2202, 0.2304, 0.2095, 0.2097, 0.2173, 0.2058, 0.1987])
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+        inputs = self.get_dummy_inputs(for_inpainting=True)
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        expected_slice = np.array([0.2744, 0.2410, 0.2202, 0.2334, 0.2090, 0.2053, 0.2175, 0.2033, 0.1934])
        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
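For reference, an end-to-end sketch of the scenario these SD 1.5 tests exercise, assuming a local reference image `reference.png`; when the pipeline was created without an `image_encoder`, `load_ip_adapter` pulls the matching one from the repository:

import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

ip_image = load_image("reference.png")  # any reference image you provide
image = pipe(
    prompt="best quality, high quality",
    ip_adapter_image=ip_image,
    num_inference_steps=50,
).images[0]
image.save("ip_adapter_plus_out.png")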
@@ -173,7 +203,30 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()
-        expected_slice = np.array([0.0968, 0.0959, 0.0852, 0.0912, 0.0948, 0.093, 0.0893, 0.0932, 0.0923])
+        expected_slice = np.array([0.0965, 0.0956, 0.0849, 0.0908, 0.0944, 0.0927, 0.0888, 0.0929, 0.0920])
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            torch_dtype=self.dtype,
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter",
+            subfolder="sdxl_models",
+            weight_name="ip-adapter-plus_sdxl_vit-h.bin",
+        )
+
+        inputs = self.get_dummy_inputs()
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        expected_slice = np.array([0.0592, 0.0573, 0.0459, 0.0542, 0.0559, 0.0523, 0.0500, 0.0540, 0.0501])
        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
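Note the encoder choice in the hunk above: the `vit-h` suffix in `ip-adapter-plus_sdxl_vit-h.bin` indicates the adapter expects the ViT-H image encoder stored under `models/image_encoder`, not the bigG one under `sdxl_models/image_encoder`, which is why the test rebuilds the pipeline with an explicit `image_encoder`. A standalone equivalent of that step (the variable name is illustrative):

from transformers import CLIPVisionModelWithProjection

# ViT-H encoder that the *_vit-h adapter checkpoints pair with
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder"
)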
@@ -194,7 +247,31 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()
-        expected_slice = np.array([0.0653, 0.0704, 0.0725, 0.0741, 0.0702, 0.0647, 0.0782, 0.0799, 0.0752])
+        expected_slice = np.array([0.0652, 0.0698, 0.0723, 0.0744, 0.0699, 0.0636, 0.0784, 0.0803, 0.0742])
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
+
+        pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            torch_dtype=self.dtype,
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter",
+            subfolder="sdxl_models",
+            weight_name="ip-adapter-plus_sdxl_vit-h.bin",
+        )
+
+        inputs = self.get_dummy_inputs(for_image_to_image=True)
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        expected_slice = np.array([0.0708, 0.0701, 0.0735, 0.0760, 0.0739, 0.0679, 0.0756, 0.0824, 0.0837])
        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
@@ -216,6 +293,31 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
        image_slice = images[0, :3, :3, -1].flatten()
        image_slice.tolist()
-        expected_slice = np.array([0.1418, 0.1493, 0.1428, 0.146, 0.1491, 0.1501, 0.1473, 0.1501, 0.1516])
+        expected_slice = np.array([0.1420, 0.1495, 0.1430, 0.1462, 0.1493, 0.1502, 0.1474, 0.1502, 0.1517])
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
+
+        pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            torch_dtype=self.dtype,
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter",
+            subfolder="sdxl_models",
+            weight_name="ip-adapter-plus_sdxl_vit-h.bin",
+        )
+
+        inputs = self.get_dummy_inputs(for_inpainting=True)
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        image_slice.tolist()
+        expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494])
        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)