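# extract_weights.py
#
# Split an IP-Adapter training checkpoint (saved as model.safetensors) into the
# two-part "ip_adapter.bin" layout used at inference time: the image projection
# weights under "image_proj" and the IP cross-attention K/V weights under
# "ip_adapter".
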
import torch
import safetensors

ckpt = "/home/modelzoo/IP-Adapter/test_output/checkpoint-120000/model.safetensors"

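# Read every tensor from the training checkpoint into a plain dict on CPU.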
sd = {}
with safetensors.safe_open(ckpt, framework="pt", device='cpu') as f:
    for k in f.keys():
        sd[k] = f.get_tensor(k)

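# The two sub-state-dicts that make up ip_adapter.bin.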
image_proj_sd = {}
ip_sd = {}

# Diffusers parameter names of the trained IP cross-attention K/V projections,
# listed in the order down_blocks -> up_blocks -> mid_block.
names_1 = [
    'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight',
    'mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight',
    'mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight',
]

# Matching keys in the ip_adapter.bin convention: the index of each attention
# processor in the adapter ModuleList (the cross-attention processors fall at
# the odd positions 1, 3, ..., 31), followed by the projection name.
names_2 = [
    "1.to_k_ip.weight", "1.to_v_ip.weight", "3.to_k_ip.weight", "3.to_v_ip.weight",
    "5.to_k_ip.weight", "5.to_v_ip.weight", "7.to_k_ip.weight", "7.to_v_ip.weight",
    "9.to_k_ip.weight", "9.to_v_ip.weight", "11.to_k_ip.weight", "11.to_v_ip.weight",
    "13.to_k_ip.weight", "13.to_v_ip.weight", "15.to_k_ip.weight", "15.to_v_ip.weight",
    "17.to_k_ip.weight", "17.to_v_ip.weight", "19.to_k_ip.weight", "19.to_v_ip.weight",
    "21.to_k_ip.weight", "21.to_v_ip.weight", "23.to_k_ip.weight", "23.to_v_ip.weight",
    "25.to_k_ip.weight", "25.to_v_ip.weight", "27.to_k_ip.weight", "27.to_v_ip.weight",
    "29.to_k_ip.weight", "29.to_v_ip.weight", "31.to_k_ip.weight", "31.to_v_ip.weight",
]

mapping = dict(zip(names_1, names_2))


# Split the checkpoint: image-projection weights keep their sub-module names,
# while the IP cross-attention K/V weights are renamed via the mapping above.
for k in sd:
    if k.startswith("image_proj_model"):
        image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
    elif "_ip." in k:
        ip_sd[mapping[k.replace("unet.", "")]] = sd[k]


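# Save in the layout expected by the IP-Adapter loading code:
# {"image_proj": ..., "ip_adapter": ...}.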
torch.save({"image_proj": image_proj_sd, "ip_adapter": ip_sd}, "ip_adapter.bin")