import torch
import copy
import inspect
import logging

import comfy.utils
import comfy.model_management

class ModelPatcher:
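    """Tracks weight and object patches layered on top of a shared model.

    Clones made with clone() share the underlying model but keep their own
    patch lists; patch_model()/unpatch_model() apply and revert the patches,
    backing up the original values so the model can always be restored.
    """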
    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
        self.size = size
        self.model = model
        self.patches = {}
        self.backup = {}
        self.object_patches = {}
        self.object_patches_backup = {}
        self.model_options = {"transformer_options":{}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        if current_device is None:
            self.current_device = self.offload_device
        else:
            self.current_device = current_device

        self.weight_inplace_update = weight_inplace_update
        self.model_lowvram = False

    def model_size(self):
        if self.size > 0:
            return self.size
        model_sd = self.model.state_dict()
        self.size = comfy.model_management.module_size(self.model)
        self.model_keys = set(model_sd.keys())
        return self.size

    def clone(self):
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]

        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        return n

    def is_clone(self, other):
        if hasattr(other, 'model') and self.model is other.model:
            return True
        return False

    def memory_required(self, input_shape):
        return self.model.memory_required(input_shape=input_shape)

    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
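        """Override how cond/uncond predictions are combined during sampling.

        Hypothetical usage sketch -- a new-style function receives an args dict
        carrying at least "cond", "uncond" and "cond_scale" (the keys unpacked
        by the legacy 3-argument adapter below), so a standard CFG combine is:

            def my_cfg(args):
                return args["uncond"] + args["cond_scale"] * (args["cond"] - args["uncond"])
            patcher.set_model_sampler_cfg_function(my_cfg)
        """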
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_unet_function_wrapper(self, unet_wrapper_function):
        self.model_options["model_function_wrapper"] = unet_wrapper_function

    def set_model_denoise_mask_function(self, denoise_mask_function):
        self.model_options["denoise_mask_function"] = denoise_mask_function

    def set_model_patch(self, patch, name):
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
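        """Register `patch` as the replacement for a single attention module.

        Replacements are keyed by (block_name, number) or, when a
        transformer_index is given, (block_name, number, transformer_index).
        """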
        to = self.model_options["transformer_options"]
        if "patches_replace" not in to:
            to["patches_replace"] = {}
        if name not in to["patches_replace"]:
            to["patches_replace"][name] = {}
        if transformer_index is not None:
            block = (block_name, number, transformer_index)
        else:
            block = (block_name, number)
        to["patches_replace"][name][block] = patch

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

    def model_patches_to(self, device):
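        """Move tensor-like patches (anything exposing a .to method) to `device`."""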
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
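        """Register weight patches for keys that exist in the model.

        Each entry is stored as (strength_patch, value, strength_model) and
        folded in later by calculate_weight(). Hypothetical usage sketch: a
        dict of same-shaped weight deltas can be merged as "diff" patches by
        wrapping each tensor in a 1-tuple (a length-1 value means "diff"):

            diffs = {k: (delta_sd[k],) for k in delta_sd}
            patched_keys = patcher.add_patches(diffs, strength_patch=0.5)
        """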
        p = set()
        for k in patches:
            if k in self.model_keys:
                p.add(k)
                current_patches = self.patches.get(k, [])
                current_patches.append((strength_patch, patches[k], strength_model))
                self.patches[k] = current_patches

        return list(p)

    def get_key_patches(self, filter_prefix=None):
        comfy.model_management.unload_model_clones(self)
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            if k in self.patches:
                p[k] = [model_sd[k]] + self.patches[k]
            else:
                p[k] = (model_sd[k],)
        return p

    def model_state_dict(self, filter_prefix=None):
        sd = self.model.state_dict()
        keys = list(sd.keys())
        if filter_prefix is not None:
            for k in keys:
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd

    def patch_weight_to_device(self, key, device_to=None):
        if key not in self.patches:
            return

        weight = comfy.utils.get_attr(self.model, key)

        inplace_update = self.weight_inplace_update

        if key not in self.backup:
            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)

        if device_to is not None:
            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
        else:
            temp_weight = weight.to(torch.float32, copy=True)
        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
        if inplace_update:
            comfy.utils.copy_to_param(self.model, key, out_weight)
        else:
            comfy.utils.set_attr_param(self.model, key, out_weight)

    def patch_model(self, device_to=None, patch_weights=True):
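        """Apply object patches and, if patch_weights, weight patches to the
        model, backing up the originals so unpatch_model() can restore them."""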
        for k in self.object_patches:
            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = old

        if patch_weights:
            model_sd = self.model_state_dict()
            for key in self.patches:
                if key not in model_sd:
                    logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                    continue

                self.patch_weight_to_device(key, device_to)

            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to

        return self.model

    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
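        """Patch the model while staying under a VRAM budget.

        Modules are patched and moved to `device_to` until roughly
        `lowvram_model_memory` bytes are in use; the remaining modules keep
        their offloaded weights and apply their patches on the fly at cast
        time via LowVramPatch."""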
        self.patch_model(device_to, patch_weights=False)

        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
        class LowVramPatch:
            def __init__(self, key, model_patcher):
                self.key = key
                self.model_patcher = model_patcher
            def __call__(self, weight):
                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)

        mem_counter = 0
        for n, m in self.model.named_modules():
            lowvram_weight = False
            if hasattr(m, "comfy_cast_weights"):
                module_mem = comfy.model_management.module_size(m)
                if mem_counter + module_mem >= lowvram_model_memory:
                    lowvram_weight = True

            weight_key = "{}.weight".format(n)
            bias_key = "{}.bias".format(n)

            if lowvram_weight:
                if weight_key in self.patches:
                    m.weight_function = LowVramPatch(weight_key, self)
                if bias_key in self.patches:
                    m.bias_function = LowVramPatch(bias_key, self)

                m.prev_comfy_cast_weights = m.comfy_cast_weights
                m.comfy_cast_weights = True
            else:
                if hasattr(m, "weight"):
                    self.patch_weight_to_device(weight_key, device_to)
                    self.patch_weight_to_device(bias_key, device_to)
                    m.to(device_to)
                    mem_counter += comfy.model_management.module_size(m)
                    logging.debug("lowvram: loaded module regularly {}".format(m))

        self.model_lowvram = True
        return self.model

    def calculate_weight(self, patches, weight, key):
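        """Fold a list of (strength_patch, value, strength_model) patches into
        `weight`, dispatching on the patch type (diff/lora/lokr/loha/glora)."""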
        for p in patches:
            alpha = p[0]
            v = p[1]
            strength_model = p[2]

            if strength_model != 1.0:
                weight *= strength_model

            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key), )

            patch_type = ""
            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
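                # "diff": add a same-shaped delta, i.e. W += alpha * w1.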
                w1 = v[0]
                if alpha != 0.0:
                    if w1.shape != weight.shape:
                        logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                    else:
                        weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
            elif patch_type == "lora": #lora/locon
                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
                if v[2] is not None:
                    alpha *= v[2] / mat2.shape[0]
                if v[3] is not None:
                    #locon mid weights, hopefully the math is fine because I didn't properly test it
                    mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                try:
                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "lokr":
                w1 = v[0]
                w2 = v[1]
                w1_a = v[3]
                w1_b = v[4]
                w2_a = v[5]
                w2_b = v[6]
                t2 = v[7]
                dim = None

                if w1 is None:
                    dim = w1_b.shape[0]
                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
                else:
                    w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)

                if w2 is None:
                    dim = w2_b.shape[0]
                    if t2 is None:
                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
                    else:
                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                          comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
                                          comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
                                          comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
                else:
                    w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)

                if len(w2.shape) == 4:
                    w1 = w1.unsqueeze(2).unsqueeze(2)
                if v[2] is not None and dim is not None:
                    alpha *= v[2] / dim

                try:
                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "loha":
                w1a = v[0]
                w1b = v[1]
                if v[2] is not None:
                    alpha *= v[2] / w1b.shape[0]
                w2a = v[3]
                w2b = v[4]
                if v[5] is not None: #cp decomposition
                    t1 = v[5]
                    t2 = v[6]
                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))

                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
                                      comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
                else:
                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
                                  comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))

                try:
                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "glora":
                if v[4] is not None:
                    alpha *= v[4] / v[0].shape[0]

                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

                try:
                    weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            else:
                logging.warning("patch type not recognized {} {}".format(patch_type, key))

        return weight

    def unpatch_model(self, device_to=None):
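        """Revert weight and object patches from the saved backups, removing
        the lowvram cast hooks first if they were installed."""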
        if self.model_lowvram:
            for m in self.model.modules():
                if hasattr(m, "prev_comfy_cast_weights"):
                    m.comfy_cast_weights = m.prev_comfy_cast_weights
                    del m.prev_comfy_cast_weights
                m.weight_function = None
                m.bias_function = None

            self.model_lowvram = False

        keys = list(self.backup.keys())

        if self.weight_inplace_update:
            for k in keys:
                comfy.utils.copy_to_param(self.model, k, self.backup[k])
        else:
            for k in keys:
                comfy.utils.set_attr_param(self.model, k, self.backup[k])

        self.backup = {}

        if device_to is not None:
            self.model.to(device_to)
            self.current_device = device_to

        keys = list(self.object_patches_backup.keys())
        for k in keys:
            comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup = {}