"template/llama3-instruct.json" did not exist on "b0135f4b9b176eab9155b660d04c9ca2a1ec2341"
model_patcher.py 20.9 KB
Newer Older
1
2
3
import torch
import copy
import inspect
4
import logging
5
import uuid
6
7

import comfy.utils
8
import comfy.model_management
9

def apply_weight_decompose(dora_scale, weight):
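    """DoRA-style weight decomposition: renormalize each input-channel column
    of `weight` and rescale it by the learned `dora_scale` magnitude vector."""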
    weight_norm = (
        weight.transpose(0, 1)
        .reshape(weight.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
        .transpose(0, 1)
    )

    return weight * (dora_scale / weight_norm)

def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
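    """Register `patch` as the replacement for block (block_name, number[,
    transformer_index]) under `name` (e.g. "attn1"). The nested dicts are
    copied before mutation so that model_options shared with clones of the
    patcher is not modified in place."""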
    to = model_options["transformer_options"].copy()

    if "patches_replace" not in to:
        to["patches_replace"] = {}
    else:
        to["patches_replace"] = to["patches_replace"].copy()

    if name not in to["patches_replace"]:
        to["patches_replace"][name] = {}
    else:
        to["patches_replace"][name] = to["patches_replace"][name].copy()

    if transformer_index is not None:
        block = (block_name, number, transformer_index)
    else:
        block = (block_name, number)
    to["patches_replace"][name][block] = patch
    model_options["transformer_options"] = to
    return model_options

class ModelPatcher:
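    """Tracks weight patches (LoRA and friends), object patches and sampler
    options for a model, applying them with patch_model() and restoring the
    original state with unpatch_model(). Clones share the underlying model so
    several sets of patches can reuse the same weights."""
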
    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
        self.size = size
        self.model = model
        self.patches = {}
        self.backup = {}
        self.object_patches = {}
        self.object_patches_backup = {}
        self.model_options = {"transformer_options":{}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        if current_device is None:
            self.current_device = self.offload_device
        else:
            self.current_device = current_device

        self.weight_inplace_update = weight_inplace_update
        self.model_lowvram = False
        self.patches_uuid = uuid.uuid4()

    def model_size(self):
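        """Return the model size in bytes, computing and caching it (together
        with the model's state-dict keys, used by add_patches) on first call."""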
        if self.size > 0:
            return self.size
        model_sd = self.model.state_dict()
        self.size = comfy.model_management.module_size(self.model)
        self.model_keys = set(model_sd.keys())
        return self.size

    def clone(self):
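        """Return a ModelPatcher sharing the same underlying model and backup
        dicts; the per-key patch lists are copied so the clone can add or
        remove patches independently."""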
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]
        n.patches_uuid = self.patches_uuid

        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        n.backup = self.backup
        n.object_patches_backup = self.object_patches_backup
        return n

    def is_clone(self, other):
        if hasattr(other, 'model') and self.model is other.model:
            return True
        return False

    def clone_has_same_weights(self, clone):
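        """Check whether `clone` wraps the same model with the same patches;
        the shared patches_uuid makes this cheap compared to comparing the
        patch tensors themselves."""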
        if not self.is_clone(clone):
            return False

        if len(self.patches) == 0 and len(clone.patches) == 0:
            return True

        if self.patches_uuid == clone.patches_uuid:
            if len(self.patches) != len(clone.patches):
                logging.warning("something went wrong: same patches_uuid but a different number of patches")
            else:
                return True

    def memory_required(self, input_shape):
        return self.model.memory_required(input_shape=input_shape)

    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_unet_function_wrapper(self, unet_wrapper_function):
        self.model_options["model_function_wrapper"] = unet_wrapper_function

    def set_model_denoise_mask_function(self, denoise_mask_function):
        self.model_options["denoise_mask_function"] = denoise_mask_function

    def set_model_patch(self, patch, name):
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
        self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

    def get_model_object(self, name):
        if name in self.object_patches:
            return self.object_patches[name]
        else:
            if name in self.object_patches_backup:
                return self.object_patches_backup[name]
            else:
                return comfy.utils.get_attr(self.model, name)

    def model_patches_to(self, device):
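        """Move patch callables that expose a .to() method (attention patches,
        replacement patches, the model function wrapper) onto `device`."""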
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
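        """Queue `patches` (key -> patch data) for every key that exists in the
        model's state dict and return the list of matched keys; weights are not
        touched until patch_model() is called."""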
        p = set()
        for k in patches:
            if k in self.model_keys:
                p.add(k)
                current_patches = self.patches.get(k, [])
                current_patches.append((strength_patch, patches[k], strength_model))
                self.patches[k] = current_patches

        self.patches_uuid = uuid.uuid4()
        return list(p)

    def get_key_patches(self, filter_prefix=None):
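        """Return key -> (original weight, followed by any queued patches),
        optionally restricted to keys starting with `filter_prefix`. Clones are
        unloaded first so the returned weights are the unpatched originals."""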
        comfy.model_management.unload_model_clones(self)
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            if k in self.patches:
                p[k] = [model_sd[k]] + self.patches[k]
            else:
                p[k] = (model_sd[k],)
        return p

    def model_state_dict(self, filter_prefix=None):
        sd = self.model.state_dict()
        keys = list(sd.keys())
        if filter_prefix is not None:
            for k in keys:
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd

    def patch_weight_to_device(self, key, device_to=None):
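        """Apply the queued patches for `key`: back up the original weight on
        the offload device, merge the patches in float32 (on `device_to` when
        given), and write the result back in the weight's original dtype."""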
        if key not in self.patches:
            return

        weight = comfy.utils.get_attr(self.model, key)

        inplace_update = self.weight_inplace_update

        if key not in self.backup:
            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)

        if device_to is not None:
            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
        else:
            temp_weight = weight.to(torch.float32, copy=True)
        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
        if inplace_update:
            comfy.utils.copy_to_param(self.model, key, out_weight)
        else:
            comfy.utils.set_attr_param(self.model, key, out_weight)

    def patch_model(self, device_to=None, patch_weights=True):
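        """Install object patches and, when `patch_weights` is True, merge all
        queued weight patches, optionally moving the model to `device_to`.
        Returns the underlying model."""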
        for k in self.object_patches:
            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = old

        if patch_weights:
            model_sd = self.model_state_dict()
            for key in self.patches:
                if key not in model_sd:
                    logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                    continue

                self.patch_weight_to_device(key, device_to)

            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to

        return self.model

    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
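        """Patch the model while keeping it within `lowvram_model_memory` bytes:
        modules are loaded regularly until the budget is reached; the rest stay
        unmerged and have their patches applied on the fly through LowVramPatch
        at weight-cast time."""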
        self.patch_model(device_to, patch_weights=False)

        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
        class LowVramPatch:
            def __init__(self, key, model_patcher):
                self.key = key
                self.model_patcher = model_patcher
            def __call__(self, weight):
                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)

        mem_counter = 0
        for n, m in self.model.named_modules():
            lowvram_weight = False
            if hasattr(m, "comfy_cast_weights"):
                module_mem = comfy.model_management.module_size(m)
                if mem_counter + module_mem >= lowvram_model_memory:
                    lowvram_weight = True

            weight_key = "{}.weight".format(n)
            bias_key = "{}.bias".format(n)

            if lowvram_weight:
                if weight_key in self.patches:
                    m.weight_function = LowVramPatch(weight_key, self)
                if bias_key in self.patches:
                    m.bias_function = LowVramPatch(bias_key, self)

                m.prev_comfy_cast_weights = m.comfy_cast_weights
                m.comfy_cast_weights = True
            else:
                if hasattr(m, "weight"):
                    self.patch_weight_to_device(weight_key, device_to)
                    self.patch_weight_to_device(bias_key, device_to)
                    m.to(device_to)
                    mem_counter += comfy.model_management.module_size(m)
                    logging.debug("lowvram: loaded module regularly {}".format(m))

        self.model_lowvram = True
        return self.model

    def calculate_weight(self, patches, weight, key):
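        """Merge `patches` into `weight` and return it. Each patch is a tuple
        (strength_patch, data, strength_model); `data` determines the patch
        type: "diff", "lora" (also locon), "lokr", "loha" or "glora", with an
        optional trailing dora_scale for weight decomposition."""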
        for p in patches:
            alpha = p[0]
            v = p[1]
            strength_model = p[2]

            if strength_model != 1.0:
                weight *= strength_model

            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key), )

            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
                w1 = v[0]
                if alpha != 0.0:
                    if w1.shape != weight.shape:
                        logging.warning("shape mismatch {}, weight not merged: {} != {}".format(key, w1.shape, weight.shape))
                    else:
                        weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
            elif patch_type == "lora": #lora/locon
                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
                dora_scale = v[4]
                if v[2] is not None:
                    alpha *= v[2] / mat2.shape[0]
                if v[3] is not None:
                    #locon mid weights, hopefully the math is fine because I didn't properly test it
                    mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                try:
                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
                    if dora_scale is not None:
                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "lokr":
                w1 = v[0]
                w2 = v[1]
                w1_a = v[3]
                w1_b = v[4]
                w2_a = v[5]
                w2_b = v[6]
                t2 = v[7]
                dora_scale = v[8]
                dim = None

                if w1 is None:
                    dim = w1_b.shape[0]
                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
                else:
                    w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)

                if w2 is None:
                    dim = w2_b.shape[0]
                    if t2 is None:
                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
                    else:
                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                          comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
                                          comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
                                          comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
                else:
                    w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)

                if len(w2.shape) == 4:
                    w1 = w1.unsqueeze(2).unsqueeze(2)
                if v[2] is not None and dim is not None:
                    alpha *= v[2] / dim

                try:
                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                    if dora_scale is not None:
                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "loha":
                w1a = v[0]
                w1b = v[1]
                if v[2] is not None:
                    alpha *= v[2] / w1b.shape[0]
                w2a = v[3]
                w2b = v[4]
                dora_scale = v[7]
                if v[5] is not None: #cp decomposition
                    t1 = v[5]
                    t2 = v[6]
                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))

                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
                else:
                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))

                try:
                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                    if dora_scale is not None:
                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "glora":
                if v[4] is not None:
                    alpha *= v[4] / v[0].shape[0]

                dora_scale = v[5]

                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

                try:
                    weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
                    if dora_scale is not None:
                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            else:
                logging.warning("patch type not recognized {} {}".format(patch_type, key))

        return weight

    def unpatch_model(self, device_to=None, unpatch_weights=True):
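        """Undo patch_model(): restore backed-up weights (removing any lowvram
        cast functions first), optionally move the model to `device_to`, and
        put back the original objects replaced by object patches."""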
        if unpatch_weights:
            if self.model_lowvram:
                for m in self.model.modules():
                    if hasattr(m, "prev_comfy_cast_weights"):
                        m.comfy_cast_weights = m.prev_comfy_cast_weights
                        del m.prev_comfy_cast_weights
                    m.weight_function = None
                    m.bias_function = None

                self.model_lowvram = False

            keys = list(self.backup.keys())

            if self.weight_inplace_update:
                for k in keys:
                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
            else:
                for k in keys:
                    comfy.utils.set_attr_param(self.model, k, self.backup[k])

            self.backup.clear()

            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to

        keys = list(self.object_patches_backup.keys())
        for k in keys:
            comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup.clear()
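
# Minimal usage sketch (illustrative only: `unet` and `lora_patches` stand in
# for objects produced by the usual ComfyUI loaders; they are not defined in
# this module):
#
#   patcher = ModelPatcher(unet,
#                          load_device=comfy.model_management.get_torch_device(),
#                          offload_device=torch.device("cpu"))
#   patcher.add_patches(lora_patches, strength_patch=0.8)
#   patcher.patch_model(device_to=patcher.load_device)
#   # ... run inference with patcher.model ...
#   patcher.unpatch_model(device_to=patcher.offload_device)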