model_patcher.py 20.1 KB
Newer Older
1
2
3
import torch
import copy
import inspect
4
import logging
5
import uuid
6
7

import comfy.utils
8
import comfy.model_management
9

10
11
12
13
14
15
16
17
18
19
20
21
def apply_weight_decompose(dora_scale, weight):
    """Rescale *weight* so each slice along dim 1 has magnitude ``dora_scale`` (DoRA).

    The L2 norm is taken over every dimension except dim 1, giving one norm per
    index of dim 1 (for standard PyTorch Linear/Conv layouts that is the input
    axis). The norms are kept in shape ``(1, weight.shape[1], 1, ...)`` so they
    broadcast against *weight*; *dora_scale* presumably has a broadcast-compatible
    shape -- confirm against the DoRA loader.

    Args:
        dora_scale: per-slice target magnitudes.
        weight: tensor with at least 2 dimensions.

    Returns:
        A new tensor ``weight * (dora_scale / norm)``.
    """
    weight_norm = (
        weight.transpose(0, 1)
        .reshape(weight.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
        .transpose(0, 1)
    )
    # An all-zero slice has no direction to rescale; dividing by its zero norm
    # would produce 0 * inf = NaN. Substitute 1 so that slice simply stays zero.
    weight_norm = torch.where(weight_norm == 0, torch.ones_like(weight_norm), weight_norm)

    return weight * (dora_scale / weight_norm)


22
class ModelPatcher:
23
    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
24
25
26
27
        self.size = size
        self.model = model
        self.patches = {}
        self.backup = {}
28
29
        self.object_patches = {}
        self.object_patches_backup = {}
30
31
32
33
34
35
36
37
38
        self.model_options = {"transformer_options":{}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        if current_device is None:
            self.current_device = self.offload_device
        else:
            self.current_device = current_device

39
        self.weight_inplace_update = weight_inplace_update
40
        self.model_lowvram = False
41
        self.patches_uuid = uuid.uuid4()
42

43
44
45
46
    def model_size(self):
        if self.size > 0:
            return self.size
        model_sd = self.model.state_dict()
47
        self.size = comfy.model_management.module_size(self.model)
48
        self.model_keys = set(model_sd.keys())
49
        return self.size
50
51

    def clone(self):
comfyanonymous's avatar
comfyanonymous committed
52
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
53
54
55
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]
56
        n.patches_uuid = self.patches_uuid
57

58
        n.object_patches = self.object_patches.copy()
59
60
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
61
62
        n.backup = self.backup
        n.object_patches_backup = self.object_patches_backup
63
64
65
66
67
68
69
        return n

    def is_clone(self, other):
        if hasattr(other, 'model') and self.model is other.model:
            return True
        return False

70
71
72
73
74
75
76
77
78
79
80
81
82
    def clone_has_same_weights(self, clone):
        if not self.is_clone(clone):
            return False

        if len(self.patches) == 0 and len(clone.patches) == 0:
            return True

        if self.patches_uuid == clone.patches_uuid:
            if len(self.patches) != len(clone.patches):
                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
            else:
                return True

83
84
85
    def memory_required(self, input_shape):
        return self.model.memory_required(input_shape=input_shape)

86
    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
87
88
89
90
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
91
92
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True
93

94
    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
95
        self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
96
97
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True
98

99
100
101
    def set_model_unet_function_wrapper(self, unet_wrapper_function):
        self.model_options["model_function_wrapper"] = unet_wrapper_function

102
103
104
    def set_model_denoise_mask_function(self, denoise_mask_function):
        self.model_options["denoise_mask_function"] = denoise_mask_function

105
106
107
108
109
110
    def set_model_patch(self, patch, name):
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

111
    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
112
113
114
115
116
        to = self.model_options["transformer_options"]
        if "patches_replace" not in to:
            to["patches_replace"] = {}
        if name not in to["patches_replace"]:
            to["patches_replace"][name] = {}
117
118
119
120
121
        if transformer_index is not None:
            block = (block_name, number, transformer_index)
        else:
            block = (block_name, number)
        to["patches_replace"][name][block] = patch
122
123
124
125
126
127
128

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

129
130
    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
131

132
133
    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
134
135
136
137
138
139
140

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

141
142
143
    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

144
145
146
    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

147
148
149
    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

150
151
152
    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
    def model_patches_to(self, device):
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
169
170
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
171
            if hasattr(wrap_func, "to"):
172
                self.model_options["model_function_wrapper"] = wrap_func.to(device)
173
174
175
176
177
178
179
180
181
182
183
184
185
186

    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        p = set()
        for k in patches:
            if k in self.model_keys:
                p.add(k)
                current_patches = self.patches.get(k, [])
                current_patches.append((strength_patch, patches[k], strength_model))
                self.patches[k] = current_patches

187
        self.patches_uuid = uuid.uuid4()
188
189
190
        return list(p)

    def get_key_patches(self, filter_prefix=None):
        """Return, per state-dict key, the current tensor followed by its queued patches.

        Model clones are unloaded first via
        ``comfy.model_management.unload_model_clones``. Keys with queued patches
        map to a list ``[tensor, *patches]``; all other keys map to a 1-tuple
        ``(tensor,)``. With *filter_prefix*, only keys starting with it are
        included.
        """
        comfy.model_management.unload_model_clones(self)
        model_sd = self.model_state_dict(filter_prefix)

        result = {}
        for key, value in model_sd.items():
            if key in self.patches:
                result[key] = [value] + self.patches[key]
            else:
                result[key] = (value,)
        return result

    def model_state_dict(self, filter_prefix=None):
        sd = self.model.state_dict()
        keys = list(sd.keys())
        if filter_prefix is not None:
            for k in keys:
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
    def patch_weight_to_device(self, key, device_to=None):
        if key not in self.patches:
            return

        weight = comfy.utils.get_attr(self.model, key)

        inplace_update = self.weight_inplace_update

        if key not in self.backup:
            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)

        if device_to is not None:
            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
        else:
            temp_weight = weight.to(torch.float32, copy=True)
        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
        if inplace_update:
            comfy.utils.copy_to_param(self.model, key, out_weight)
        else:
            comfy.utils.set_attr_param(self.model, key, out_weight)

234
    def patch_model(self, device_to=None, patch_weights=True):
235
        for k in self.object_patches:
236
            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
237
238
239
            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = old

240
241
242
243
        if patch_weights:
            model_sd = self.model_state_dict()
            for key in self.patches:
                if key not in model_sd:
244
                    logging.warning("could not patch. key doesn't exist in model: {}".format(key))
245
                    continue
246

247
                self.patch_weight_to_device(key, device_to)
248

249
250
251
            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to
252
253
254

        return self.model

255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
        self.patch_model(device_to, patch_weights=False)

        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
        class LowVramPatch:
            def __init__(self, key, model_patcher):
                self.key = key
                self.model_patcher = model_patcher
            def __call__(self, weight):
                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)

        mem_counter = 0
        for n, m in self.model.named_modules():
            lowvram_weight = False
            if hasattr(m, "comfy_cast_weights"):
                module_mem = comfy.model_management.module_size(m)
                if mem_counter + module_mem >= lowvram_model_memory:
                    lowvram_weight = True

            weight_key = "{}.weight".format(n)
            bias_key = "{}.bias".format(n)

            if lowvram_weight:
                if weight_key in self.patches:
                    m.weight_function = LowVramPatch(weight_key, self)
                if bias_key in self.patches:
                    m.bias_function = LowVramPatch(weight_key, self)

                m.prev_comfy_cast_weights = m.comfy_cast_weights
                m.comfy_cast_weights = True
            else:
                if hasattr(m, "weight"):
                    self.patch_weight_to_device(weight_key, device_to)
                    self.patch_weight_to_device(bias_key, device_to)
                    m.to(device_to)
                    mem_counter += comfy.model_management.module_size(m)
                    logging.debug("lowvram: loaded module regularly {}".format(m))

        self.model_lowvram = True
        return self.model

296
297
298
299
300
301
302
303
304
305
306
307
308
    def calculate_weight(self, patches, weight, key):
        for p in patches:
            alpha = p[0]
            v = p[1]
            strength_model = p[2]

            if strength_model != 1.0:
                weight *= strength_model

            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key), )

            if len(v) == 1:
comfyanonymous's avatar
comfyanonymous committed
309
310
311
312
313
314
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
315
316
317
                w1 = v[0]
                if alpha != 0.0:
                    if w1.shape != weight.shape:
318
                        logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
319
                    else:
320
                        weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
comfyanonymous's avatar
comfyanonymous committed
321
            elif patch_type == "lora": #lora/locon
322
323
                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
324
                dora_scale = v[4]
325
326
327
328
                if v[2] is not None:
                    alpha *= v[2] / mat2.shape[0]
                if v[3] is not None:
                    #locon mid weights, hopefully the math is fine because I didn't properly test it
329
                    mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
330
331
332
333
                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                try:
                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
334
335
                    if dora_scale is not None:
                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
336
                except Exception as e:
337
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
comfyanonymous's avatar
comfyanonymous committed
338
            elif patch_type == "lokr":
339
340
341
342
343
344
345
                w1 = v[0]
                w2 = v[1]
                w1_a = v[3]
                w1_b = v[4]
                w2_a = v[5]
                w2_b = v[6]
                t2 = v[7]
346
                dora_scale = v[8]
347
348
349
350
                dim = None

                if w1 is None:
                    dim = w1_b.shape[0]
351
352
                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
353
                else:
354
                    w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
355
356
357
358

                if w2 is None:
                    dim = w2_b.shape[0]
                    if t2 is None:
359
360
                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
361
                    else:
362
363
364
365
                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                          comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
                                          comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
                                          comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
366
                else:
367
                    w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
368
369
370
371
372
373
374
375

                if len(w2.shape) == 4:
                    w1 = w1.unsqueeze(2).unsqueeze(2)
                if v[2] is not None and dim is not None:
                    alpha *= v[2] / dim

                try:
                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
376
377
                    if dora_scale is not None:
                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
378
                except Exception as e:
379
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
comfyanonymous's avatar
comfyanonymous committed
380
            elif patch_type == "loha":
381
382
383
384
385
386
                w1a = v[0]
                w1b = v[1]
                if v[2] is not None:
                    alpha *= v[2] / w1b.shape[0]
                w2a = v[3]
                w2b = v[4]
387
                dora_scale = v[7]
388
389
390
                if v[5] is not None: #cp decomposition
                    t1 = v[5]
                    t2 = v[6]
391
392
393
394
395
396
397
398
399
                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))

                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
400
                else:
401
402
403
404
                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
405
406
407

                try:
                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
408
409
                    if dora_scale is not None:
                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
410
                except Exception as e:
411
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
comfyanonymous's avatar
comfyanonymous committed
412
413
414
415
            elif patch_type == "glora":
                if v[4] is not None:
                    alpha *= v[4] / v[0].shape[0]

416
417
                dora_scale = v[5]

comfyanonymous's avatar
comfyanonymous committed
418
419
420
421
422
                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

423
424
                try:
                    weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
425
426
                    if dora_scale is not None:
                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
427
428
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
comfyanonymous's avatar
comfyanonymous committed
429
            else:
430
                logging.warning("patch type not recognized {} {}".format(patch_type, key))
431
432
433

        return weight

434
435
436
437
438
439
440
441
442
    def unpatch_model(self, device_to=None, unpatch_weights=True):
        if unpatch_weights:
            if self.model_lowvram:
                for m in self.model.modules():
                    if hasattr(m, "prev_comfy_cast_weights"):
                        m.comfy_cast_weights = m.prev_comfy_cast_weights
                        del m.prev_comfy_cast_weights
                    m.weight_function = None
                    m.bias_function = None
443

444
                self.model_lowvram = False
445

446
            keys = list(self.backup.keys())
447

448
449
450
451
452
453
            if self.weight_inplace_update:
                for k in keys:
                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
            else:
                for k in keys:
                    comfy.utils.set_attr_param(self.model, k, self.backup[k])
454

455
            self.backup.clear()
456

457
458
459
            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to
460
461
462

        keys = list(self.object_patches_backup.keys())
        for k in keys:
463
            comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
464
465

        self.object_patches_backup = {}