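# sd.py: checkpoint loading, CLIP/VAE wrappers, and LoRA patching utilities.
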
import torch

import sd1_clip
import sd2_clip
from ldm.util import instantiate_from_config
from ldm.models.autoencoder import AutoencoderKL
from omegaconf import OmegaConf

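# Load a state dict from a .safetensors file or a pickled PyTorch/Lightning
# checkpoint, unwrapping the "state_dict" key when present.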
def load_torch_file(ckpt):
    if ckpt.lower().endswith(".safetensors"):
        import safetensors.torch
        sd = safetensors.torch.load_file(ckpt, device="cpu")
    else:
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd
    return sd

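# Instantiate the model described by an OmegaConf config and load the
# checkpoint weights into it; every module in load_state_dict_to receives the
# same state dict non-strictly (used to fill the VAE and CLIP wrappers).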
def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
    print(f"Loading model from {ckpt}")

    sd = load_torch_file(ckpt)
    model = instantiate_from_config(config.model)

    m, u = model.load_state_dict(sd, strict=False)

    # Older SD1.x checkpoints store CLIP weights without the text_model.
    # prefix; rename them to the layout the CLIP loader below expects.
    k = list(sd.keys())
    for x in k:
        if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
            y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
            sd[y] = sd.pop(x)

    # Some checkpoints save position_ids as float32; round the values so they
    # can be used as exact indices into the position embedding table.
    if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd:
        ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids']
        if ids.dtype == torch.float32:
            sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

    for x in load_state_dict_to:
        x.load_state_dict(sd, strict=False)

    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.eval()
    return model

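# Module names inside the SD1.x CLIP text encoder, mapped to the flattened
# suffixes used in LoRA key names.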
LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "mlp.fc2": "mlp_fc2",
    "self_attn.k_proj": "self_attn_k_proj",
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.v_proj": "self_attn_v_proj",
    "self_attn.out_proj": "self_attn_out_proj",
}

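# The same mapping for the SD2.x open_clip text encoder; its fused q/k/v
# projection is handled separately in model_lora_keys().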
LORA_CLIP2_MAP = {
    "mlp.c_fc": "mlp_fc1",
    "mlp.c_proj": "mlp_fc2",
    "attn.out_proj": "self_attn_out_proj",
}

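# UNet transformer block module names mapped to LoRA key suffixes.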
LORA_UNET_MAP = {
    "proj_in": "proj_in",
    "proj_out": "proj_out",
    "transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
    "transformer_blocks.0.attn1.to_k": "transformer_blocks_0_attn1_to_k",
    "transformer_blocks.0.attn1.to_v": "transformer_blocks_0_attn1_to_v",
    "transformer_blocks.0.attn1.to_out.0": "transformer_blocks_0_attn1_to_out_0",
    "transformer_blocks.0.attn2.to_q": "transformer_blocks_0_attn2_to_q",
    "transformer_blocks.0.attn2.to_k": "transformer_blocks_0_attn2_to_k",
    "transformer_blocks.0.attn2.to_v": "transformer_blocks_0_attn2_to_v",
    "transformer_blocks.0.attn2.to_out.0": "transformer_blocks_0_attn2_to_out_0",
    "transformer_blocks.0.ff.net.0.proj": "transformer_blocks_0_ff_net_0_proj",
    "transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
}


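# Collect the (up, down, alpha) weight triples from a LoRA file for every
# key we know how to map, and warn about keys that were not consumed.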
def load_lora(path, to_load):
    lora = load_torch_file(path)
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        A_name = "{}.lora_up.weight".format(x)
        B_name = "{}.lora_down.weight".format(x)
        alpha_name = "{}.alpha".format(x)
        if A_name in lora.keys():
            alpha = None
            if alpha_name in lora.keys():
                alpha = lora[alpha_name].item()
                loaded_keys.add(alpha_name)
            patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha)
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
    for x in lora.keys():
        if x not in loaded_keys:
            print("lora key not loaded", x)
    return patch_dict

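# Build the mapping from LoRA key names to (model state dict key, slice index).
# The slice index selects which rows of a fused q/k/v weight to patch; it is
# zero everywhere else.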
def model_lora_keys(model, key_map=None):
    # Use a fresh dict per call; a mutable default would be shared across calls.
    if key_map is None:
        key_map = {}
    sdk = model.state_dict().keys()

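    # UNet down blocks: input_blocks.{b}.1 attention modules map onto
    # diffusers-style "down_blocks_{i}_attentions_{j}" LoRA names; counter
    # advances once per input block that actually contains attention layers.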
    counter = 0
    for b in range(12):
        tk = "model.diffusion_model.input_blocks.{}.1".format(b)
        up_counter = 0
        for c in LORA_UNET_MAP:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
                key_map[lora_key] = (k, 0)
                up_counter += 1
        if up_counter >= 4:
            counter += 1
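    # UNet middle block attention.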
    for c in LORA_UNET_MAP:
        k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
        if k in sdk:
            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
            key_map[lora_key] = (k, 0)
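    # UNet up blocks: counter starts at 3 because the first up block in the
    # diffusers naming has no attention modules.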
    counter = 3
    for b in range(12):
        tk = "model.diffusion_model.output_blocks.{}.1".format(b)
        up_counter = 0
        for c in LORA_UNET_MAP:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
                key_map[lora_key] = (k, 0)
                up_counter += 1
        if up_counter >= 4:
            counter += 1
    counter = 0
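    # Text encoder LoRA keys share one naming template for SD1.x and SD2.x.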
    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    for b in range(12):
        for c in LORA_CLIP_MAP:
            k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = (k, 0)
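    # SD2.x open_clip text model: 24 layers, with q/k/v fused into a single
    # in_proj weight; each projection maps to that key with its own slice index.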
    for b in range(24):
        for c in LORA_CLIP2_MAP:
            k = "model.transformer.resblocks.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP2_MAP[c])
                key_map[lora_key] = (k, 0)
        k = "model.transformer.resblocks.{}.attn.in_proj_weight".format(b)
        if k in sdk:
            key_map[text_model_lora_key.format(b, "self_attn_q_proj")] = (k, 0)
            key_map[text_model_lora_key.format(b, "self_attn_k_proj")] = (k, 1)
            key_map[text_model_lora_key.format(b, "self_attn_v_proj")] = (k, 2)

    return key_map

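# Applies weight patches (LoRA-style low-rank updates) to a model in place,
# backing up each original weight so unpatch_model() can restore it.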
class ModelPatcher:
    def __init__(self, model):
        self.model = model
        self.patches = []
        self.backup = {}

    def clone(self):
        n = ModelPatcher(self.model)
        n.patches = self.patches[:]
        return n

    def add_patches(self, patches, strength=1.0):
        p = {}
        model_sd = self.model.state_dict()
        for k in patches:
            if k[0] in model_sd:
                p[k] = patches[k]
        self.patches += [(strength, p)]
        return p.keys()

    def patch_model(self):
        model_sd = self.model.state_dict()
        for p in self.patches:
            for k in p[1]:
                v = p[1][k]
                key = k[0]
                index = k[1]
                if key not in model_sd:
                    print("could not patch. key doesn't exist in model:", k)
                    continue

                weight = model_sd[key]
                if key not in self.backup:
                    self.backup[key] = weight.clone()

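                # LoRA update: weight += strength * (alpha / rank) * (up @ down),
                # written into the slice of the weight selected by index.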
                alpha = p[0]
                mat1 = v[0]
                mat2 = v[1]
                if v[2] is not None:
                    alpha *= v[2] / mat2.shape[0]
                calc = (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float()))
                if len(weight.shape) > 2:
                    calc = calc.reshape(weight.shape)
                weight[index * mat1.shape[0]:(index + 1) * mat1.shape[0]] += calc.type(weight.dtype).to(weight.device)
        return self.model

    def unpatch_model(self):
        model_sd = self.model.state_dict()
        for k in self.backup:
            model_sd[k][:] = self.backup[k]
        self.backup = {}

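# Attach a LoRA to clones of the model and CLIP patchers at the given
# strengths, reporting any LoRA keys that neither of them consumed.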
def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
    key_map = model_lora_keys(model.model)
    key_map = model_lora_keys(clip.cond_stage_model, key_map)
    loaded = load_lora(lora_path, key_map)
    new_modelpatcher = model.clone()
    k = new_modelpatcher.add_patches(loaded, strength_model)
    new_clip = clip.clone()
    k1 = new_clip.add_patches(loaded, strength_clip)
    k = set(k)
    k1 = set(k1)
    for x in loaded:
        if (x not in k) and (x not in k1):
            print("NOT LOADED", x)

    return (new_modelpatcher, new_clip)


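# Wraps the CLIP text encoder and tokenizer together with a ModelPatcher so
# LoRA patches can be applied for the duration of each encode call.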
class CLIP:
    def __init__(self, config={}, embedding_directory=None, no_init=False):
        if no_init:
            return
        self.target_clip = config["target"]
        if "params" in config:
            params = config["params"]
        else:
            params = {}

        tokenizer_params = {}

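        # Only the SD2.x (FrozenOpenCLIPEmbedder) and SD1.x (FrozenCLIPEmbedder)
        # text encoder targets are supported here.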
        if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
            clip = sd2_clip.SD2ClipModel
            tokenizer = sd2_clip.SD2Tokenizer
        elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
            clip = sd1_clip.SD1ClipModel
            tokenizer = sd1_clip.SD1Tokenizer
            tokenizer_params['embedding_directory'] = embedding_directory

        self.cond_stage_model = clip(**(params))
        self.tokenizer = tokenizer(**(tokenizer_params))
        self.patcher = ModelPatcher(self.cond_stage_model)

    def clone(self):
        n = CLIP(no_init=True)
        n.target_clip = self.target_clip
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        return n

    def add_patches(self, patches, strength=1.0):
        return self.patcher.add_patches(patches, strength)

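    # Tokenize the prompt, apply pending patches (e.g. LoRA) to the text
    # encoder for the duration of the encode, then restore original weights.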
    def encode(self, text):
        tokens = self.tokenizer.tokenize_with_weights(text)
        try:
            self.patcher.patch_model()
            cond = self.cond_stage_model.encode_token_weights(tokens)
        finally:
            # Always restore the original weights, even if encoding fails.
            self.patcher.unpatch_model()
        return cond

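# Wraps the first-stage autoencoder. Latents are scaled by scale_factor
# (0.18215 for SD1.x/SD2.x); images are [0, 1] channels-last tensors on the
# outside and [-1, 1] channels-first inside the autoencoder.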
class VAE:
    def __init__(self, ckpt_path=None, scale_factor=0.18215, device="cuda", config=None):
        if config is None:
            # Default SD1.x/SD2.x VAE parameters.
            ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256,
                        'in_channels': 3, 'out_ch': 3, 'ch': 128,
                        'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2,
                        'attn_resolutions': [], 'dropout': 0.0}
            self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path)
        else:
            self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
        self.first_stage_model = self.first_stage_model.eval()
        self.scale_factor = scale_factor
        self.device = device

    def decode(self, samples):
        self.first_stage_model = self.first_stage_model.to(self.device)
        samples = samples.to(self.device)
        pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
        pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
        self.first_stage_model = self.first_stage_model.cpu()
        pixel_samples = pixel_samples.cpu().movedim(1,-1)
        return pixel_samples

    def encode(self, pixel_samples):
        self.first_stage_model = self.first_stage_model.to(self.device)
        pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
        samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor
        self.first_stage_model = self.first_stage_model.cpu()
        samples = samples.cpu()
        return samples


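# Load a full checkpoint from an OmegaConf YAML config plus weights file,
# optionally building the VAE and CLIP wrappers and filling their weights
# from the same state dict in a single pass.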
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
    config = OmegaConf.load(config_path)
    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']

    clip = None
    vae = None

    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    load_state_dict_to = []
    if output_vae:
        vae = VAE(scale_factor=scale_factor, config=vae_config)
        w.first_stage_model = vae.first_stage_model
        load_state_dict_to = [w]

    if output_clip:
        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model
        load_state_dict_to = [w]

    model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to)
    return (ModelPatcher(model), clip, vae)
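
# Minimal usage sketch (file names are hypothetical):
#   model_patcher, clip, vae = load_checkpoint("v1-inference.yaml", "model.ckpt")
#   cond = clip.encode("a photo of a cat")
#   pixels = vae.decode(latent_samples)  # returns a [B, H, W, C] tensor in [0, 1]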