Unverified Commit 68e96239 authored by Fabio Rigano's avatar Fabio Rigano Committed by GitHub
Browse files

Add converter method for ip adapters (#6150)



* Add converter method for ip adapters

* Move converter method

* Update to image proj converter

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 781775ea
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import OrderedDict, defaultdict
from collections import defaultdict
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union
......@@ -664,6 +664,80 @@ class UNet2DConditionLoadersMixin:
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
updated_state_dict = {}
image_projection = None
if "proj.weight" in state_dict:
# IP-Adapter
num_image_text_embeds = 4
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value
elif "proj.3.weight" in state_dict:
# IP-Adapter Full
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
image_projection = MLPProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj.0", "ff.net.0.proj")
diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
diffusers_name = diffusers_name.replace("proj.3", "norm")
updated_state_dict[diffusers_name] = value
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["latents"].shape[1]
embed_dims = state_dict["proj_in.weight"].shape[1]
output_dims = state_dict["proj_out.weight"].shape[0]
hidden_dims = state_dict["latents"].shape[2]
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
image_projection = Resampler(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
heads=heads,
num_queries=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("0.to", "2.to")
diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
if "norm1" in diffusers_name:
updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
elif "norm2" in diffusers_name:
updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
elif "to_kv" in diffusers_name:
v_chunk = value.chunk(2, dim=0)
updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in diffusers_name:
updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
else:
updated_state_dict[diffusers_name] = value
image_projection.load_state_dict(updated_state_dict)
return image_projection
def _load_ip_adapter_weights(self, state_dict):
from ..models.attention_processor import (
AttnProcessor,
......@@ -724,103 +798,8 @@ class UNet2DConditionLoadersMixin:
self.set_attn_processor(attn_procs)
# create image projection layers.
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
image_projection.to(dtype=self.dtype, device=self.device)
# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict
elif "proj.3.weight" in state_dict["image_proj"]:
clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0]
image_projection = MLPProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
)
image_projection.to(dtype=self.dtype, device=self.device)
# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
"ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
"ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
"ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
"norm.weight": state_dict["image_proj"]["proj.3.weight"],
"norm.bias": state_dict["image_proj"]["proj.3.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict
else:
# IP-Adapter Plus
embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
hidden_dims = state_dict["image_proj"]["latents"].shape[2]
heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
image_projection = Resampler(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
heads=heads,
num_queries=num_image_text_embeds,
)
image_proj_state_dict = state_dict["image_proj"]
new_sd = OrderedDict()
for k, v in image_proj_state_dict.items():
if "0.to" in k:
k = k.replace("0.to", "2.to")
elif "1.0.weight" in k:
k = k.replace("1.0.weight", "3.0.weight")
elif "1.0.bias" in k:
k = k.replace("1.0.bias", "3.0.bias")
elif "1.1.weight" in k:
k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
elif "1.3.weight" in k:
k = k.replace("1.3.weight", "3.1.net.2.weight")
if "norm1" in k:
new_sd[k.replace("0.norm1", "0")] = v
elif "norm2" in k:
new_sd[k.replace("0.norm2", "1")] = v
elif "to_kv" in k:
v_chunk = v.chunk(2, dim=0)
new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in k:
new_sd[k.replace("to_out", "to_out.0")] = v
else:
new_sd[k] = v
image_projection.load_state_dict(new_sd)
del image_proj_state_dict
# convert IP-Adapter Image Projection layers to diffusers
image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.config.encoder_hid_dim_type = "ip_image_proj"
delete_adapter_layers
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment