import torch
import copy
import inspect
import logging
import uuid

import comfy.utils
import comfy.model_management
class ModelPatcher:
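    """Manages patches applied on top of a torch model: weight patches
    (LoRA-style offsets keyed by state-dict key), object patches that swap
    out attributes of the model, and movement between a load device and an
    offload device. Patches are merged into the weights by patch_model()
    and reverted by unpatch_model().

    Example (illustrative sketch; ``unet``, ``cuda``, ``cpu`` and ``lora``
    are stand-ins, not names from this module):

        patcher = ModelPatcher(unet, load_device=cuda, offload_device=cpu)
        patcher.add_patches(lora, strength_patch=0.8)
        patcher.patch_model(device_to=cuda)    # merge patches into weights
        # ... run inference ...
        patcher.unpatch_model(device_to=cpu)   # restore original weights
    """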
    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
        self.size = size
        self.model = model
        self.patches = {}
        self.backup = {}
        self.object_patches = {}
        self.object_patches_backup = {}
        self.model_options = {"transformer_options":{}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        if current_device is None:
            self.current_device = self.offload_device
        else:
            self.current_device = current_device

        self.weight_inplace_update = weight_inplace_update
        self.model_lowvram = False
        self.patches_uuid = uuid.uuid4()

    def model_size(self):
        if self.size > 0:
            return self.size
        model_sd = self.model.state_dict()
        self.size = comfy.model_management.module_size(self.model)
        self.model_keys = set(model_sd.keys())
        return self.size

    def clone(self):
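        """Return a ModelPatcher that shares the underlying model, backup
        dicts and model_keys, but owns shallow copies of the patch lists so
        patches can be added to the clone without touching this instance.

        Example (illustrative):

            base = loaded_patcher          # stand-in for an existing patcher
            variant = base.clone()         # cheap; weights are shared
            variant.add_patches(extra)     # base.patches is unaffected
        """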
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]
        n.patches_uuid = self.patches_uuid

        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        n.backup = self.backup
        n.object_patches_backup = self.object_patches_backup
        return n

    def is_clone(self, other):
        if hasattr(other, 'model') and self.model is other.model:
            return True
        return False

    def clone_has_same_weights(self, clone):
        if not self.is_clone(clone):
            return False

        if len(self.patches) == 0 and len(clone.patches) == 0:
            return True

        if self.patches_uuid == clone.patches_uuid:
            if len(self.patches) != len(clone.patches):
                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
            else:
                return True

    def memory_required(self, input_shape):
        return self.model.memory_required(input_shape=input_shape)

    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
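        """Install a custom classifier-free guidance combine function.

        A new-style function takes a single ``args`` dict whose keys include
        "cond", "uncond" and "cond_scale"; a legacy three-parameter function
        is detected by signature and wrapped. Sketch of a new-style function
        (standard CFG, for illustration only):

            def my_cfg(args):
                cond, uncond, scale = args["cond"], args["uncond"], args["cond_scale"]
                return uncond + (cond - uncond) * scale
        """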
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_unet_function_wrapper(self, unet_wrapper_function):
        self.model_options["model_function_wrapper"] = unet_wrapper_function

    def set_model_denoise_mask_function(self, denoise_mask_function):
        self.model_options["denoise_mask_function"] = denoise_mask_function

    def set_model_patch(self, patch, name):
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
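        """Register ``patch`` as a replacement for one specific block,
        stored under transformer_options["patches_replace"][name] keyed by
        (block_name, number) or (block_name, number, transformer_index).

        Example (illustrative; the block coordinates are made up):

            m.set_model_patch_replace(my_attn, "attn2", "input", 4)
        """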
        to = self.model_options["transformer_options"]
        if "patches_replace" not in to:
            to["patches_replace"] = {}
        if name not in to["patches_replace"]:
            to["patches_replace"][name] = {}
        if transformer_index is not None:
            block = (block_name, number, transformer_index)
        else:
            block = (block_name, number)
        to["patches_replace"][name][block] = patch

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

    def add_object_patch(self, name, obj):
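        """Schedule ``obj`` to replace the attribute at dotted path ``name``
        when patch_model() runs; the original object is backed up there and
        restored by unpatch_model().

        Example (illustrative; the attribute path is hypothetical):

            m.add_object_patch("diffusion_model.input_blocks.0", my_module)
        """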
        self.object_patches[name] = obj

    def model_patches_to(self, device):
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
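        """Queue weight patches without modifying the model yet.

        Each entry of ``patches`` maps a state-dict key to patch data, which
        is recorded as a (strength_patch, data, strength_model) tuple and
        merged later by calculate_weight(). Only keys present in this model
        are kept; the list of accepted keys is returned.

        Example (illustrative; ``lora_patches`` would come from a loader):

            applied_keys = m.add_patches(lora_patches, strength_patch=0.8)
        """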
        p = set()
        for k in patches:
            if k in self.model_keys:
                p.add(k)
                current_patches = self.patches.get(k, [])
                current_patches.append((strength_patch, patches[k], strength_model))
                self.patches[k] = current_patches

        self.patches_uuid = uuid.uuid4()
        return list(p)

    def get_key_patches(self, filter_prefix=None):
        comfy.model_management.unload_model_clones(self)
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            if k in self.patches:
                p[k] = [model_sd[k]] + self.patches[k]
            else:
                p[k] = (model_sd[k],)
        return p

    def model_state_dict(self, filter_prefix=None):
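        """Return the model's state dict, optionally keeping only the keys
        that start with ``filter_prefix``.

        Example (illustrative prefix):

            sd = m.model_state_dict(filter_prefix="diffusion_model.")
        """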
        sd = self.model.state_dict()
        keys = list(sd.keys())
        if filter_prefix is not None:
            for k in keys:
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd

    def patch_weight_to_device(self, key, device_to=None):
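        """Apply every queued patch for ``key`` to that single weight.

        The unpatched weight is saved to self.backup first; the patch math
        runs in float32 (on ``device_to`` when given), and the result is
        cast back to the original dtype and written either in place or by
        attribute replacement, depending on self.weight_inplace_update.
        """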
        if key not in self.patches:
            return

        weight = comfy.utils.get_attr(self.model, key)

        inplace_update = self.weight_inplace_update

        if key not in self.backup:
            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)

        if device_to is not None:
            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
        else:
            temp_weight = weight.to(torch.float32, copy=True)
        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
        if inplace_update:
            comfy.utils.copy_to_param(self.model, key, out_weight)
        else:
            comfy.utils.set_attr_param(self.model, key, out_weight)

    def patch_model(self, device_to=None, patch_weights=True):
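        """Apply the object patches and, if ``patch_weights`` is True, merge
        all queued weight patches into the model, moving it to ``device_to``
        when given.

        Example (illustrative device):

            m.patch_model(device_to=torch.device("cuda"))
        """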
        for k in self.object_patches:
            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = old

        if patch_weights:
            model_sd = self.model_state_dict()
            for key in self.patches:
                if key not in model_sd:
                    logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                    continue

                self.patch_weight_to_device(key, device_to)

            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to

        return self.model

    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
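        """Patch the model while keeping roughly ``lowvram_model_memory``
        bytes resident: modules that fit the budget are patched and moved
        normally, while the remaining ones keep offloaded weights and get a
        LowVramPatch hook that applies calculate_weight() on the fly when
        the weight is cast for use.
        """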
        self.patch_model(device_to, patch_weights=False)

        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
        class LowVramPatch:
            def __init__(self, key, model_patcher):
                self.key = key
                self.model_patcher = model_patcher
            def __call__(self, weight):
                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)

        mem_counter = 0
        for n, m in self.model.named_modules():
            lowvram_weight = False
            if hasattr(m, "comfy_cast_weights"):
                module_mem = comfy.model_management.module_size(m)
                if mem_counter + module_mem >= lowvram_model_memory:
                    lowvram_weight = True

            weight_key = "{}.weight".format(n)
            bias_key = "{}.bias".format(n)

            if lowvram_weight:
                if weight_key in self.patches:
                    m.weight_function = LowVramPatch(weight_key, self)
                if bias_key in self.patches:
                    m.bias_function = LowVramPatch(bias_key, self)

                m.prev_comfy_cast_weights = m.comfy_cast_weights
                m.comfy_cast_weights = True
            else:
                if hasattr(m, "weight"):
                    self.patch_weight_to_device(weight_key, device_to)
                    self.patch_weight_to_device(bias_key, device_to)
                    m.to(device_to)
                    mem_counter += comfy.model_management.module_size(m)
                    logging.debug("lowvram: loaded module regularly {}".format(m))

        self.model_lowvram = True
        return self.model

    def calculate_weight(self, patches, weight, key):
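        """Merge a list of patches into ``weight`` and return it.

        Each patch is an (alpha, data, strength_model) tuple whose data
        format selects the math: "diff" adds a scaled delta, "lora"/"locon"
        adds a low-rank product, and "lokr", "loha" and "glora" rebuild
        their respective factorizations first. For a plain LoRA patch the
        update is roughly (up/down being the conventional factor names):

            W' = W + alpha * (lora_alpha / rank) * (up @ down)

        where the (lora_alpha / rank) rescale only applies when v[2] is set.
        """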
        for p in patches:
            alpha = p[0]
            v = p[1]
            strength_model = p[2]

            if strength_model != 1.0:
                weight *= strength_model

            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key), )

            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
                w1 = v[0]
                if alpha != 0.0:
                    if w1.shape != weight.shape:
                        logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                    else:
                        weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
            elif patch_type == "lora": #lora/locon
                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
                if v[2] is not None:
                    alpha *= v[2] / mat2.shape[0]
                if v[3] is not None:
                    #locon mid weights, hopefully the math is fine because I didn't properly test it
                    mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                try:
                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "lokr":
                w1 = v[0]
                w2 = v[1]
                w1_a = v[3]
                w1_b = v[4]
                w2_a = v[5]
                w2_b = v[6]
                t2 = v[7]
                dim = None

                if w1 is None:
                    dim = w1_b.shape[0]
                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
                else:
                    w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)

                if w2 is None:
                    dim = w2_b.shape[0]
                    if t2 is None:
                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
                    else:
                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                          comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
                                          comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
                                          comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
                else:
                    w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)

                if len(w2.shape) == 4:
                    w1 = w1.unsqueeze(2).unsqueeze(2)
                if v[2] is not None and dim is not None:
                    alpha *= v[2] / dim

                try:
                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "loha":
                w1a = v[0]
                w1b = v[1]
                if v[2] is not None:
                    alpha *= v[2] / w1b.shape[0]
                w2a = v[3]
                w2b = v[4]
                if v[5] is not None: #cp decomposition
                    t1 = v[5]
                    t2 = v[6]
                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))

                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
                else:
                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))

                try:
                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "glora":
                if v[4] is not None:
                    alpha *= v[4] / v[0].shape[0]

                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

                try:
                    weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            else:
                logging.warning("patch type not recognized {} {}".format(patch_type, key))

        return weight

    def unpatch_model(self, device_to=None, unpatch_weights=True):
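        """Undo patch_model()/patch_model_lowvram(): remove any lowvram
        hooks, restore the backed-up weights (in place or by attribute,
        matching how they were written), optionally move the model to
        ``device_to``, and put back the original objects replaced by
        object patches.
        """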
        if unpatch_weights:
            if self.model_lowvram:
                for m in self.model.modules():
                    if hasattr(m, "prev_comfy_cast_weights"):
                        m.comfy_cast_weights = m.prev_comfy_cast_weights
                        del m.prev_comfy_cast_weights
                    m.weight_function = None
                    m.bias_function = None

                self.model_lowvram = False

            keys = list(self.backup.keys())

            if self.weight_inplace_update:
                for k in keys:
                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
            else:
                for k in keys:
                    comfy.utils.set_attr_param(self.model, k, self.backup[k])

            self.backup.clear()

            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to

        keys = list(self.object_patches_backup.keys())
        for k in keys:
            comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup = {}