Unverified Commit ba1bfac2 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] Refactor `IPAdapterPlusImageProjection` a bit (#7994)

* use IPAdapterPlusImageProjectionBlock in IPAdapterPlusImageProjection

* reposition IPAdapterPlusImageProjection

* refactor complete?

* fix heads param retrieval.

* update test dict creation method.
parent 5edd0b34
...@@ -847,7 +847,12 @@ class UNet2DConditionLoadersMixin: ...@@ -847,7 +847,12 @@ class UNet2DConditionLoadersMixin:
embed_dims = state_dict["proj_in.weight"].shape[1] embed_dims = state_dict["proj_in.weight"].shape[1]
output_dims = state_dict["proj_out.weight"].shape[0] output_dims = state_dict["proj_out.weight"].shape[0]
hidden_dims = state_dict["latents"].shape[2] hidden_dims = state_dict["latents"].shape[2]
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64 attn_key_present = any("attn" in k for k in state_dict)
heads = (
state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
if attn_key_present
else state_dict["layers.0.0.to_q.weight"].shape[0] // 64
)
with init_context(): with init_context():
image_projection = IPAdapterPlusImageProjection( image_projection = IPAdapterPlusImageProjection(
...@@ -860,26 +865,53 @@ class UNet2DConditionLoadersMixin: ...@@ -860,26 +865,53 @@ class UNet2DConditionLoadersMixin:
for key, value in state_dict.items(): for key, value in state_dict.items():
diffusers_name = key.replace("0.to", "2.to") diffusers_name = key.replace("0.to", "2.to")
diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
if "norm1" in diffusers_name: diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0")
updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1")
elif "norm2" in diffusers_name: diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0")
updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1")
elif "to_kv" in diffusers_name: diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0")
diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1")
diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0")
diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1")
if "to_kv" in diffusers_name:
parts = diffusers_name.split(".")
parts[2] = "attn"
diffusers_name = ".".join(parts)
v_chunk = value.chunk(2, dim=0) v_chunk = value.chunk(2, dim=0)
updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_q" in diffusers_name:
parts = diffusers_name.split(".")
parts[2] = "attn"
diffusers_name = ".".join(parts)
updated_state_dict[diffusers_name] = value
elif "to_out" in diffusers_name: elif "to_out" in diffusers_name:
parts = diffusers_name.split(".")
parts[2] = "attn"
diffusers_name = ".".join(parts)
updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
else: else:
diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0")
diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj")
diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2")
diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0")
diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj")
diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2")
diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0")
diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj")
diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2")
diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0")
diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj")
diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2")
updated_state_dict[diffusers_name] = value updated_state_dict[diffusers_name] = value
if not low_cpu_mem_usage: if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict) image_projection.load_state_dict(updated_state_dict, strict=True)
else: else:
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
......
...@@ -806,6 +806,39 @@ class PixArtAlphaTextProjection(nn.Module): ...@@ -806,6 +806,39 @@ class PixArtAlphaTextProjection(nn.Module):
return hidden_states return hidden_states
class IPAdapterPlusImageProjectionBlock(nn.Module):
    """One perceiver-resampler block of the IP-Adapter Plus image projection.

    Normalizes the image features and the latent queries, lets the latents
    cross-attend over the concatenation of (features, latents), then applies a
    LayerNorm + FeedForward sub-network. Both the attention output and the
    feed-forward output are added residually.

    Args:
        embed_dims (int): Channel width of features and latents. Defaults to 768.
        dim_head (int): Per-head dimension for the attention. Defaults to 64.
        heads (int): Number of attention heads. Defaults to 16.
        ffn_ratio (float): Hidden-layer expansion ratio of the feed-forward. Defaults to 4.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        dim_head: int = 64,
        heads: int = 16,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        # Lazy import to avoid a circular dependency with .attention.
        from .attention import FeedForward

        self.ln0 = nn.LayerNorm(embed_dims)
        self.ln1 = nn.LayerNorm(embed_dims)
        self.attn = Attention(
            query_dim=embed_dims,
            dim_head=dim_head,
            heads=heads,
            out_bias=False,
        )
        self.ff = nn.Sequential(
            nn.LayerNorm(embed_dims),
            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
        )

    def forward(self, x, latents, residual):
        """Run one resampler step and return the updated latents.

        Args:
            x (torch.Tensor): Image features, shape (..., seq_x, embed_dims) —
                presumably batched; verify against caller.
            latents (torch.Tensor): Latent queries, shape (..., num_queries, embed_dims).
            residual (torch.Tensor): Residual added to the attention output
                (the caller passes the pre-block latents).

        Returns:
            torch.Tensor: Updated latents, same shape as `latents`.
        """
        normed_features = self.ln0(x)
        normed_latents = self.ln1(latents)
        # Keys/values come from the concatenation of image features and latents.
        kv_input = torch.cat([normed_features, normed_latents], dim=-2)
        attended = self.attn(normed_latents, kv_input) + residual
        return self.ff(attended) + attended
class IPAdapterPlusImageProjection(nn.Module): class IPAdapterPlusImageProjection(nn.Module):
"""Resampler of IP-Adapter Plus. """Resampler of IP-Adapter Plus.
...@@ -834,8 +867,6 @@ class IPAdapterPlusImageProjection(nn.Module): ...@@ -834,8 +867,6 @@ class IPAdapterPlusImageProjection(nn.Module):
ffn_ratio: float = 4, ffn_ratio: float = 4,
) -> None: ) -> None:
super().__init__() super().__init__()
from .attention import FeedForward # Lazy import to avoid circular import
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
self.proj_in = nn.Linear(embed_dims, hidden_dims) self.proj_in = nn.Linear(embed_dims, hidden_dims)
...@@ -843,26 +874,9 @@ class IPAdapterPlusImageProjection(nn.Module): ...@@ -843,26 +874,9 @@ class IPAdapterPlusImageProjection(nn.Module):
self.proj_out = nn.Linear(hidden_dims, output_dims) self.proj_out = nn.Linear(hidden_dims, output_dims)
self.norm_out = nn.LayerNorm(output_dims) self.norm_out = nn.LayerNorm(output_dims)
self.layers = nn.ModuleList([]) self.layers = nn.ModuleList(
for _ in range(depth): [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
self.layers.append( )
nn.ModuleList(
[
nn.LayerNorm(hidden_dims),
nn.LayerNorm(hidden_dims),
Attention(
query_dim=hidden_dims,
dim_head=dim_head,
heads=heads,
out_bias=False,
),
nn.Sequential(
nn.LayerNorm(hidden_dims),
FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
),
]
)
)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass. """Forward pass.
...@@ -876,52 +890,14 @@ class IPAdapterPlusImageProjection(nn.Module): ...@@ -876,52 +890,14 @@ class IPAdapterPlusImageProjection(nn.Module):
x = self.proj_in(x) x = self.proj_in(x)
for ln0, ln1, attn, ff in self.layers: for block in self.layers:
residual = latents residual = latents
latents = block(x, latents, residual)
encoder_hidden_states = ln0(x)
latents = ln1(latents)
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
latents = attn(latents, encoder_hidden_states) + residual
latents = ff(latents) + latents
latents = self.proj_out(latents) latents = self.proj_out(latents)
return self.norm_out(latents) return self.norm_out(latents)
class IPAdapterPlusImageProjectionBlock(nn.Module):
    """Single block of the IP-Adapter Plus perceiver resampler.

    LayerNorms the image features (`ln0`) and the latent queries (`ln1`),
    runs cross-attention with the latents as queries over the concatenation
    of normalized features and latents, then applies a LayerNorm + FeedForward
    stack. Attention and feed-forward paths are both residual.

    Args:
        embed_dims (int): Channel width of features and latents.
        dim_head (int): Per-head attention dimension.
        heads (int): Number of attention heads.
        ffn_ratio (float): Feed-forward hidden expansion ratio.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        dim_head: int = 64,
        heads: int = 16,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        # Lazy import to avoid a circular dependency with .attention.
        from .attention import FeedForward

        self.ln0 = nn.LayerNorm(embed_dims)
        self.ln1 = nn.LayerNorm(embed_dims)
        self.attn = Attention(
            query_dim=embed_dims,
            dim_head=dim_head,
            heads=heads,
            out_bias=False,
        )
        # LayerNorm fused ahead of the FeedForward so the block needs only one
        # sequential call in forward().
        self.ff = nn.Sequential(
            nn.LayerNorm(embed_dims),
            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
        )

    def forward(self, x, latents, residual):
        """Return updated latents after one cross-attention + FFN step.

        Args:
            x (torch.Tensor): Image features; concatenated with latents along
                the sequence axis (dim=-2) to form attention keys/values.
            latents (torch.Tensor): Latent queries.
            residual (torch.Tensor): Residual added to the attention output.

        Returns:
            torch.Tensor: Updated latents, same shape as `latents`.
        """
        encoder_hidden_states = self.ln0(x)
        latents = self.ln1(latents)
        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
        latents = self.attn(latents, encoder_hidden_states) + residual
        latents = self.ff(latents) + latents
        return latents
class IPAdapterFaceIDPlusImageProjection(nn.Module): class IPAdapterFaceIDPlusImageProjection(nn.Module):
"""FacePerceiverResampler of IP-Adapter Plus. """FacePerceiverResampler of IP-Adapter Plus.
......
...@@ -146,42 +146,64 @@ def create_ip_adapter_plus_state_dict(model): ...@@ -146,42 +146,64 @@ def create_ip_adapter_plus_state_dict(model):
) )
ip_image_projection_state_dict = OrderedDict() ip_image_projection_state_dict = OrderedDict()
keys = [k for k in image_projection.state_dict() if "layers." in k]
print(keys)
for k, v in image_projection.state_dict().items(): for k, v in image_projection.state_dict().items():
if "2.to" in k: if "2.to" in k:
k = k.replace("2.to", "0.to") k = k.replace("2.to", "0.to")
elif "3.0.weight" in k: elif "layers.0.ln0" in k:
k = k.replace("3.0.weight", "1.0.weight") k = k.replace("layers.0.ln0", "layers.0.0.norm1")
elif "3.0.bias" in k: elif "layers.0.ln1" in k:
k = k.replace("3.0.bias", "1.0.bias") k = k.replace("layers.0.ln1", "layers.0.0.norm2")
elif "3.0.weight" in k: elif "layers.1.ln0" in k:
k = k.replace("3.0.weight", "1.0.weight") k = k.replace("layers.1.ln0", "layers.1.0.norm1")
elif "3.1.net.0.proj.weight" in k: elif "layers.1.ln1" in k:
k = k.replace("3.1.net.0.proj.weight", "1.1.weight") k = k.replace("layers.1.ln1", "layers.1.0.norm2")
elif "3.net.2.weight" in k: elif "layers.2.ln0" in k:
k = k.replace("3.net.2.weight", "1.3.weight") k = k.replace("layers.2.ln0", "layers.2.0.norm1")
elif "layers.0.0" in k: elif "layers.2.ln1" in k:
k = k.replace("layers.0.0", "layers.0.0.norm1") k = k.replace("layers.2.ln1", "layers.2.0.norm2")
elif "layers.0.1" in k: elif "layers.3.ln0" in k:
k = k.replace("layers.0.1", "layers.0.0.norm2") k = k.replace("layers.3.ln0", "layers.3.0.norm1")
elif "layers.1.0" in k: elif "layers.3.ln1" in k:
k = k.replace("layers.1.0", "layers.1.0.norm1") k = k.replace("layers.3.ln1", "layers.3.0.norm2")
elif "layers.1.1" in k: elif "to_q" in k:
k = k.replace("layers.1.1", "layers.1.0.norm2") parts = k.split(".")
elif "layers.2.0" in k: parts[2] = "attn"
k = k.replace("layers.2.0", "layers.2.0.norm1") k = ".".join(parts)
elif "layers.2.1" in k: elif "to_out.0" in k:
k = k.replace("layers.2.1", "layers.2.0.norm2") parts = k.split(".")
parts[2] = "attn"
if "norm_cross" in k: k = ".".join(parts)
ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v k = k.replace("to_out.0", "to_out")
elif "layer_norm" in k: else:
ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v k = k.replace("0.ff.0", "0.1.0")
elif "to_k" in k: k = k.replace("0.ff.1.net.0.proj", "0.1.1")
k = k.replace("0.ff.1.net.2", "0.1.3")
k = k.replace("1.ff.0", "1.1.0")
k = k.replace("1.ff.1.net.0.proj", "1.1.1")
k = k.replace("1.ff.1.net.2", "1.1.3")
k = k.replace("2.ff.0", "2.1.0")
k = k.replace("2.ff.1.net.0.proj", "2.1.1")
k = k.replace("2.ff.1.net.2", "2.1.3")
k = k.replace("3.ff.0", "3.1.0")
k = k.replace("3.ff.1.net.0.proj", "3.1.1")
k = k.replace("3.ff.1.net.2", "3.1.3")
# if "norm_cross" in k:
# ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v
# elif "layer_norm" in k:
# ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v
if "to_k" in k:
parts = k.split(".")
parts[2] = "attn"
k = ".".join(parts)
ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0) ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0)
elif "to_v" in k: elif "to_v" in k:
continue continue
elif "to_out.0" in k:
ip_image_projection_state_dict[k.replace("to_out.0", "to_out")] = v
else: else:
ip_image_projection_state_dict[k] = v ip_image_projection_state_dict[k] = v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment