sd.py 13.7 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
2
3
4
import torch

import sd1_clip
import sd2_clip
5
import model_management
comfyanonymous's avatar
comfyanonymous committed
6
7
8
9
from ldm.util import instantiate_from_config
from ldm.models.autoencoder import AutoencoderKL
from omegaconf import OmegaConf

10
def load_torch_file(ckpt):
    """Read a checkpoint from disk and return its state dict on the CPU.

    Paths ending in ``.safetensors`` (case-insensitive) are read with
    ``safetensors.torch.load_file``; everything else goes through
    ``torch.load``.  When the loaded object has a ``"state_dict"`` entry that
    entry is returned, otherwise the loaded object itself is returned.
    """
    if ckpt.lower().endswith(".safetensors"):
        import safetensors.torch
        return safetensors.torch.load_file(ckpt, device="cpu")

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        step = checkpoint["global_step"]
        print(f"Global Step: {step}")

    return checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=None):
    """Instantiate the model described by ``config`` and load ``ckpt`` into it.

    The checkpoint is first loaded unchanged (non-strictly) into the freshly
    instantiated model.  Afterwards the text-encoder entries of the state dict
    are renamed in place to the ``cond_stage_model.transformer.text_model.*``
    layout and the converted state dict is loaded (non-strictly) into every
    module listed in ``load_state_dict_to``.

    Args:
        config: Configuration object; ``config.model`` is handed to
            ``instantiate_from_config``.
        ckpt: Path of the checkpoint file, read with ``load_torch_file``.
        verbose: When true, print the missing/unexpected keys reported by the
            first ``load_state_dict`` call.
        load_state_dict_to: Optional iterable of modules that additionally
            receive the converted state dict.  ``None`` means no extra modules.

    Returns:
        The instantiated model, switched to eval mode.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if load_state_dict_to is None:
        load_state_dict_to = []

    print(f"Loading model from {ckpt}")

    sd = load_torch_file(ckpt)
    model = instantiate_from_config(config.model)

    # The model itself gets the checkpoint before any key is renamed.
    m, u = model.load_state_dict(sd, strict=False)

    # Keys stored directly under "transformer." get the "text_model." level
    # inserted.  Iterate over a snapshot because sd is mutated in the loop.
    old_prefix = "cond_stage_model.transformer."
    new_prefix = "cond_stage_model.transformer.text_model."
    for x in list(sd.keys()):
        if x.startswith(old_prefix) and not x.startswith(new_prefix):
            sd[x.replace(old_prefix, new_prefix)] = sd.pop(x)

    # position_ids stored as float32 are rounded to whole numbers.
    position_ids_key = 'cond_stage_model.transformer.text_model.embeddings.position_ids'
    if position_ids_key in sd:
        ids = sd[position_ids_key]
        if ids.dtype == torch.float32:
            sd[position_ids_key] = ids.round()

    # Remaining renames cover keys under "cond_stage_model.model.*"
    # (presumably the OpenCLIP text-encoder layout of SD2.x checkpoints --
    # TODO confirm).
    keys_to_replace = {
        "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
        "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
        "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
        "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
    }

    for x in keys_to_replace:
        if x in sd:
            sd[keys_to_replace[x]] = sd.pop(x)

    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }
    qkv_targets = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]

    for resblock in range(24):
        for x in resblock_to_replace:
            for y in ["weight", "bias"]:
                k_from = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y)
                k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y)
                if k_from in sd:
                    sd[k_to] = sd.pop(k_from)

        # The fused in_proj tensor is cut into three equal slices along dim 0
        # and stored as separate q/k/v projections.  The slice width is taken
        # from the tensor itself rather than hard-coded to 1024.
        for y in ["weight", "bias"]:
            k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y)
            if k_from in sd:
                weights = sd.pop(k_from)
                width = weights.shape[0] // 3
                for i, proj in enumerate(qkv_targets):
                    k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, proj, y)
                    sd[k_to] = weights[width * i:width * (i + 1)]

    for x in load_state_dict_to:
        x.load_state_dict(sd, strict=False)

    if verbose and len(m) > 0:
        print("missing keys:")
        print(m)
    if verbose and len(u) > 0:
        print("unexpected keys:")
        print(u)

    model.eval()
    return model

93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Both tables map a module path inside the model (dot separated) to the
# fragment used when building LoRA key names in model_lora_keys: the same
# path with every "." replaced by "_".
LORA_CLIP_MAP = {
    module_path: module_path.replace(".", "_")
    for module_path in (
        "mlp.fc1",
        "mlp.fc2",
        "self_attn.k_proj",
        "self_attn.q_proj",
        "self_attn.v_proj",
        "self_attn.out_proj",
    )
}

LORA_UNET_MAP = {
    module_path: module_path.replace(".", "_")
    for module_path in (
        "proj_in",
        "proj_out",
        "transformer_blocks.0.attn1.to_q",
        "transformer_blocks.0.attn1.to_k",
        "transformer_blocks.0.attn1.to_v",
        "transformer_blocks.0.attn1.to_out.0",
        "transformer_blocks.0.attn2.to_q",
        "transformer_blocks.0.attn2.to_k",
        "transformer_blocks.0.attn2.to_v",
        "transformer_blocks.0.attn2.to_out.0",
        "transformer_blocks.0.ff.net.0.proj",
        "transformer_blocks.0.ff.net.2",
    )
}


def load_lora(path, to_load):
    """Load a LoRA file and collect the patches for the requested keys.

    Args:
        path: Path of the LoRA file, read with ``load_torch_file``.
        to_load: Mapping from LoRA key prefix to the model state-dict key the
            patch applies to.

    Returns:
        Dict mapping model key to ``(lora_up, lora_down, alpha)``; ``alpha`` is
        ``None`` when the file has no ``<prefix>.alpha`` entry.  Every entry of
        the file that was not consumed is reported on stdout.
    """
    lora = load_torch_file(path)
    patch_dict = {}
    loaded_keys = set()

    for prefix, model_key in to_load.items():
        up_key = "{}.lora_up.weight".format(prefix)
        if up_key not in lora:
            continue
        down_key = "{}.lora_down.weight".format(prefix)
        alpha_key = "{}.alpha".format(prefix)

        alpha = None
        if alpha_key in lora:
            alpha = lora[alpha_key].item()
            loaded_keys.add(alpha_key)

        patch_dict[model_key] = (lora[up_key], lora[down_key], alpha)
        loaded_keys.update((up_key, down_key))

    for name in lora:
        if name not in loaded_keys:
            print("lora key not loaded", name)
    return patch_dict

def model_lora_keys(model, key_map=None):
    """Map LoRA key prefixes to the weight keys present in ``model``.

    Only keys that actually occur in ``model.state_dict()`` are added.  Three
    UNet regions (``input_blocks``, ``middle_block``, ``output_blocks`` under
    ``model.diffusion_model``) and the text-encoder layers
    (``transformer.text_model.encoder.layers``) are scanned.

    Args:
        model: Object exposing ``state_dict()``.
        key_map: Optional dict to extend in place.  When omitted a new dict is
            created for this call, so results of separate calls never share
            entries (the previous mutable default leaked keys between calls).

    Returns:
        ``key_map`` (the passed dict, or the newly created one) mapping LoRA
        key prefix to model state-dict key.
    """
    if key_map is None:
        key_map = {}

    sdk = model.state_dict().keys()

    # Down blocks: the running counter is split into (block, attention) with
    # two attention slots per block.  It only advances past an input block
    # that has at least four of the LORA_UNET_MAP weights.
    counter = 0
    for b in range(12):
        tk = "model.diffusion_model.input_blocks.{}.1".format(b)
        up_counter = 0
        for c in LORA_UNET_MAP:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
                key_map[lora_key] = k
                up_counter += 1
        if up_counter >= 4:
            counter += 1

    for c in LORA_UNET_MAP:
        k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
        if k in sdk:
            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
            key_map[lora_key] = k

    # Up blocks: three attention slots per block, numbering starts at
    # counter 3, i.e. up block 1 / attention 0.
    counter = 3
    for b in range(12):
        tk = "model.diffusion_model.output_blocks.{}.1".format(b)
        up_counter = 0
        for c in LORA_UNET_MAP:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
                key_map[lora_key] = k
                up_counter += 1
        if up_counter >= 4:
            counter += 1

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    for b in range(24):
        for c in LORA_CLIP_MAP:
            k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k

    return key_map

class ModelPatcher:
    """Applies weight patches to a model in place and can undo them.

    Each registered patch set is a ``(strength, {key: (mat1, mat2, alpha)})``
    pair.  Patching adds ``scale * mat1 @ mat2`` to the weight stored under
    ``key``; the untouched weight is saved first so it can be restored.
    """

    def __init__(self, model):
        self.model = model
        # list of (strength, {state-dict key: (mat1, mat2, alpha)})
        self.patches = []
        # original weights of patched keys, filled by patch_model
        self.backup = {}

    def clone(self):
        """Return a patcher for the same model with a copy of the patch list."""
        duplicate = ModelPatcher(self.model)
        duplicate.patches = list(self.patches)
        return duplicate

    def add_patches(self, patches, strength=1.0):
        """Register the patches whose key exists in the model.

        Returns the keys that were accepted.
        """
        known = self.model.state_dict()
        accepted = {key: patches[key] for key in patches if key in known}
        self.patches.append((strength, accepted))
        return accepted.keys()

    def patch_model(self):
        """Apply every registered patch to the model weights in place."""
        state = self.model.state_dict()
        for strength, patch_set in self.patches:
            for key, (mat1, mat2, lora_alpha) in patch_set.items():
                if key not in state:
                    print("could not patch. key doesn't exist in model:", key)
                    continue

                weight = state[key]
                # Only the first patch touching a key stores the original.
                if key not in self.backup:
                    self.backup[key] = weight.clone()

                scale = strength
                if lora_alpha is not None:
                    scale *= lora_alpha / mat2.shape[0]

                product = torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())
                delta = (scale * product).reshape(weight.shape)
                weight += delta.type(weight.dtype).to(weight.device)
        return self.model

    def unpatch_model(self):
        """Copy the saved original weights back and forget the backup."""
        state = self.model.state_dict()
        for key, original in self.backup.items():
            state[key][:] = original
        self.backup = {}

def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
    """Apply one LoRA file to clones of ``model`` and ``clip``.

    The LoRA key map is built from ``model.model`` and then extended with the
    keys of ``clip.cond_stage_model``.  The loaded patches are offered to a
    clone of each object; any loaded patch accepted by neither clone is
    reported on stdout.

    Returns:
        ``(patched_model_clone, patched_clip_clone)``.
    """
    unet_keys = model_lora_keys(model.model)
    all_keys = model_lora_keys(clip.cond_stage_model, unet_keys)
    loaded = load_lora(lora_path, all_keys)

    new_modelpatcher = model.clone()
    model_accepted = new_modelpatcher.add_patches(loaded, strength_model)
    new_clip = clip.clone()
    clip_accepted = new_clip.add_patches(loaded, strength_clip)

    accepted = set(model_accepted) | set(clip_accepted)
    for key in loaded:
        if key not in accepted:
            print("NOT LOADED", key)

    return (new_modelpatcher, new_clip)
comfyanonymous's avatar
comfyanonymous committed
244
245
246


class CLIP:
    """Text encoder bundle: CLIP model, matching tokenizer and a ModelPatcher.

    Attributes set by a full ``__init__``:
        target_clip: the ``config["target"]`` string that selected the model.
        cond_stage_model: the instantiated text-encoder model.
        tokenizer: the tokenizer belonging to that model.
        patcher: ``ModelPatcher`` wrapping ``cond_stage_model``.
    """

    def __init__(self, config=None, embedding_directory=None, no_init=False):
        """Build the model and tokenizer selected by ``config["target"]``.

        Args:
            config: Mapping with a ``"target"`` entry and optional ``"params"``
                that are passed to the model constructor as keyword arguments.
            embedding_directory: Forwarded to the tokenizer constructor.
            no_init: When true, return immediately without setting any
                attribute (used by ``clone``).

        Raises:
            KeyError: ``config`` has no ``"target"`` entry.
            ValueError: ``config["target"]`` is not one of the supported
                targets (this used to surface as an ``UnboundLocalError``).
        """
        if no_init:
            return

        # A fresh dict per call instead of a shared mutable default argument.
        if config is None:
            config = {}

        self.target_clip = config["target"]
        if "params" in config:
            params = config["params"]
        else:
            params = {}

        if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
            clip = sd2_clip.SD2ClipModel
            tokenizer = sd2_clip.SD2Tokenizer
        elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
            clip = sd1_clip.SD1ClipModel
            tokenizer = sd1_clip.SD1Tokenizer
        else:
            raise ValueError("unsupported CLIP target: {}".format(self.target_clip))

        self.cond_stage_model = clip(**(params))
        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
        self.patcher = ModelPatcher(self.cond_stage_model)

    def clone(self):
        """Return a copy sharing model and tokenizer but with its own patcher clone."""
        n = CLIP(no_init=True)
        n.target_clip = self.target_clip
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        return n

    def load_from_state_dict(self, sd):
        """Load ``sd`` non-strictly into ``cond_stage_model.transformer``."""
        self.cond_stage_model.transformer.load_state_dict(sd, strict=False)

    def add_patches(self, patches, strength=1.0):
        """Register ``patches`` on the patcher; returns the accepted keys."""
        return self.patcher.add_patches(patches, strength)

    def clip_layer(self, layer_idx):
        """Forward ``layer_idx`` to ``cond_stage_model.clip_layer``."""
        return self.cond_stage_model.clip_layer(layer_idx)

    def encode(self, text):
        """Tokenize ``text`` and encode it with the patched model.

        The patches are applied only for the duration of the call.  They are
        removed in a ``finally`` clause so the shared model is restored even
        when encoding is aborted by an exception that is not an ``Exception``
        subclass (for example ``KeyboardInterrupt``).

        Returns:
            Whatever ``cond_stage_model.encode_token_weights`` returns.
        """
        tokens = self.tokenizer.tokenize_with_weights(text)
        try:
            self.patcher.patch_model()
            cond = self.cond_stage_model.encode_token_weights(tokens)
        finally:
            self.patcher.unpatch_model()
        return cond

class VAE:
    """Wrapper around ``AutoencoderKL`` that converts between images and latents.

    The autoencoder is moved to ``self.device`` only while ``decode`` or
    ``encode`` runs and is moved back to the CPU afterwards, including when
    the call raises.
    """

    def __init__(self, ckpt_path=None, scale_factor=0.18215, device="cuda", config=None):
        """Create the autoencoder.

        Args:
            ckpt_path: Forwarded to ``AutoencoderKL`` as ``ckpt_path``.
            scale_factor: Factor applied to latents after encoding and divided
                out before decoding.
            device: Device the autoencoder is moved to while it runs.
            config: Optional mapping whose ``"params"`` entry supplies the
                ``AutoencoderKL`` keyword arguments; when ``None`` the built-in
                SD1.x/SD2.x parameters below are used.
        """
        if config is None:
            #default SD1.x/SD2.x VAE parameters
            ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
            self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path)
        else:
            self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
        self.first_stage_model = self.first_stage_model.eval()
        self.scale_factor = scale_factor
        self.device = device

    def decode(self, samples):
        """Decode latents into pixel values clamped to ``[0, 1]``.

        Returns a CPU tensor with dimension 1 moved to the last position
        (presumably channels-first to channels-last -- TODO confirm).
        """
        model_management.unload_model()
        try:
            self.first_stage_model = self.first_stage_model.to(self.device)
            samples = samples.to(self.device)
            pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
            # The decoder output is shifted from [-1, 1] to [0, 1] and clamped.
            pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
        finally:
            # Always hand the device back, even when decoding failed.
            self.first_stage_model = self.first_stage_model.cpu()
        pixel_samples = pixel_samples.cpu().movedim(1,-1)
        return pixel_samples

    def encode(self, pixel_samples):
        """Encode pixel values into scaled latents on the CPU.

        The last dimension of ``pixel_samples`` is moved to position 1 and
        the values are mapped with ``2 * x - 1`` before encoding.
        """
        model_management.unload_model()
        try:
            self.first_stage_model = self.first_stage_model.to(self.device)
            pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
            samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor
        finally:
            # Always hand the device back, even when encoding failed.
            self.first_stage_model = self.first_stage_model.cpu()
        samples = samples.cpu()
        return samples

326
327
328
329
330
331
332
333
334
335
def load_clip(ckpt_path, embedding_directory=None):
    """Build a ``CLIP`` from a standalone text-encoder checkpoint.

    The encoder type is picked from the checkpoint contents: if the key
    ``text_model.encoder.layers.22.mlp.fc1.weight`` is present the
    ``FrozenOpenCLIPEmbedder`` target is used, otherwise
    ``FrozenCLIPEmbedder``.
    """
    clip_data = load_torch_file(ckpt_path)

    if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
        target = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
    else:
        target = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'

    clip = CLIP(config={'target': target}, embedding_directory=embedding_directory)
    clip.load_from_state_dict(clip_data)
    return clip
comfyanonymous's avatar
comfyanonymous committed
336

337
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
    """Load a full checkpoint and return ``(ModelPatcher(model), clip, vae)``.

    ``clip`` and ``vae`` are ``None`` when the corresponding ``output_*`` flag
    is false.  The requested sub-models are attached to a throwaway container
    module that receives the converted state dict inside
    ``load_model_from_config``.
    """
    config = OmegaConf.load(config_path)
    params = config['model']['params']
    clip_config = params['cond_stage_config']
    scale_factor = params['scale_factor']
    vae_config = params['first_stage_config']

    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    vae = None
    clip = None

    if output_vae:
        vae = VAE(scale_factor=scale_factor, config=vae_config)
        w.first_stage_model = vae.first_stage_model

    if output_clip:
        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model

    # The container is only passed along when something was attached to it.
    load_state_dict_to = [w] if (output_vae or output_clip) else []

    model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to)
    return (ModelPatcher(model), clip, vae)