# model_patcher.py
import torch
import copy
import inspect

import comfy.utils
import comfy.model_management

class ModelPatcher:
    """Wrap a model and track reversible modifications to it.

    Three kinds of modification are tracked:

    * ``patches``: per state-dict-key lists of
      ``(strength_patch, patch_value, strength_model)`` tuples that
      ``patch_model`` merges into the weights via ``calculate_weight``.
    * ``object_patches``: attribute name -> replacement object, applied with
      ``setattr`` on the model by ``patch_model``.
    * ``model_options``: a dict (with a ``"transformer_options"`` sub-dict)
      holding sampler/unet wrappers and attention patches.

    ``backup`` and ``object_patches_backup`` hold the original values so that
    ``unpatch_model`` can restore them.
    """

    def __init__(self, model, load_device, offload_device, size=0, current_device=None):
        """Create a patcher around ``model``.

        Args:
            model: object exposing ``state_dict()`` and ``to()``.
            load_device: stored on the instance and forwarded to clones.
            offload_device: where weight backups are placed by ``patch_model``;
                also the default for ``current_device``.
            size: byte size of the model; when 0 it is computed from the
                state dict by ``model_size``.
            current_device: device the model is currently on; defaults to
                ``offload_device`` when None.
        """
        self.size = size
        self.model = model
        self.patches = {}
        self.backup = {}
        self.object_patches = {}
        self.object_patches_backup = {}
        self.model_options = {"transformer_options":{}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        if current_device is None:
            self.current_device = self.offload_device
        else:
            self.current_device = current_device

    def model_size(self):
        """Return the model size in bytes, computing and caching it if needed.

        When ``self.size`` is already positive it is returned as-is and the
        state dict is not inspected (so ``model_keys`` is not populated here;
        see ``_get_model_keys``).
        """
        if self.size > 0:
            return self.size
        model_sd = self.model.state_dict()
        size = 0
        for k in model_sd:
            t = model_sd[k]
            size += t.nelement() * t.element_size()
        self.size = size
        self.model_keys = set(model_sd.keys())
        return size

    def _get_model_keys(self):
        """Return the set of state-dict keys, computing it on first use.

        ``model_size`` only fills ``model_keys`` when it has to compute the
        size, so a patcher created with an explicit ``size`` may not have the
        attribute yet.
        """
        keys = getattr(self, "model_keys", None)
        if keys is None:
            keys = set(self.model.state_dict().keys())
            self.model_keys = keys
        return keys

    def clone(self):
        """Return a new ``ModelPatcher`` sharing the same underlying model.

        Patch lists are shallow-copied per key, ``object_patches`` is
        shallow-copied, and ``model_options`` is deep-copied. Backups are not
        carried over.
        """
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
        n.patches = {k: plist[:] for k, plist in self.patches.items()}
        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        # The key set may not exist yet if this patcher was built with an
        # explicit size; in that case the clone resolves it lazily as well.
        if hasattr(self, "model_keys"):
            n.model_keys = self.model_keys
        return n

    def is_clone(self, other):
        """Return True when ``other`` wraps the very same model object."""
        return hasattr(other, 'model') and self.model is other.model

    def set_model_sampler_cfg_function(self, sampler_cfg_function):
        """Register a CFG function under ``model_options["sampler_cfg_function"]``.

        A callable taking exactly three parameters is treated as the old
        ``(cond, uncond, cond_scale)`` style and wrapped so that it can be
        called with a single ``args`` dict; anything else is stored unchanged.
        """
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function

    def set_model_unet_function_wrapper(self, unet_wrapper_function):
        """Store ``unet_wrapper_function`` under ``"model_function_wrapper"``."""
        self.model_options["model_function_wrapper"] = unet_wrapper_function

    def set_model_patch(self, patch, name):
        """Append ``patch`` to the list stored at ``transformer_options["patches"][name]``.

        A new list is built on each call, so lists held elsewhere are not
        mutated.
        """
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number):
        """Store ``patch`` at ``transformer_options["patches_replace"][name][(block_name, number)]``."""
        to = self.model_options["transformer_options"]
        if "patches_replace" not in to:
            to["patches_replace"] = {}
        if name not in to["patches_replace"]:
            to["patches_replace"][name] = {}
        to["patches_replace"][name][(block_name, number)] = patch

    def set_model_attn1_patch(self, patch):
        """Add ``patch`` under the ``"attn1_patch"`` name."""
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        """Add ``patch`` under the ``"attn2_patch"`` name."""
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number):
        """Set the ``"attn1"`` replacement for ``(block_name, number)``."""
        self.set_model_patch_replace(patch, "attn1", block_name, number)

    def set_model_attn2_replace(self, patch, block_name, number):
        """Set the ``"attn2"`` replacement for ``(block_name, number)``."""
        self.set_model_patch_replace(patch, "attn2", block_name, number)

    def set_model_attn1_output_patch(self, patch):
        """Add ``patch`` under the ``"attn1_output_patch"`` name."""
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        """Add ``patch`` under the ``"attn2_output_patch"`` name."""
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_output_block_patch(self, patch):
        """Add ``patch`` under the ``"output_block_patch"`` name."""
        self.set_model_patch(patch, "output_block_patch")

    def add_object_patch(self, name, obj):
        """Register ``obj`` to replace attribute ``name`` on the model when patched."""
        self.object_patches[name] = obj

    def model_patches_to(self, device):
        """Move every registered patch that has a ``to`` method to ``device``.

        Covers ``transformer_options["patches"]``,
        ``transformer_options["patches_replace"]`` and the
        ``"model_function_wrapper"`` option. Entries without a ``to``
        attribute are left untouched.
        """
        to = self.model_options["transformer_options"]
        if "patches" in to:
            for patch_list in to["patches"].values():
                for i, patch in enumerate(patch_list):
                    if hasattr(patch, "to"):
                        patch_list[i] = patch.to(device)
        if "patches_replace" in to:
            for replacements in to["patches_replace"].values():
                # Only existing keys are reassigned, so iterating is safe.
                for k, patch in replacements.items():
                    if hasattr(patch, "to"):
                        replacements[k] = patch.to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

    def model_dtype(self):
        """Return ``self.model.get_dtype()`` if the model provides it, else None."""
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()
        return None

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        """Queue weight patches for keys that exist in the model.

        Args:
            patches: mapping of state-dict key -> patch value (see
                ``calculate_weight`` for the accepted layouts).
            strength_patch: multiplier applied to the patch.
            strength_model: multiplier applied to the existing weight.

        Returns:
            list of the keys from ``patches`` that were accepted; keys not
            present in the model are silently skipped.
        """
        model_keys = self._get_model_keys()
        p = set()
        for k in patches:
            if k in model_keys:
                p.add(k)
                current_patches = self.patches.get(k, [])
                current_patches.append((strength_patch, patches[k], strength_model))
                self.patches[k] = current_patches

        return list(p)

    def get_key_patches(self, filter_prefix=None):
        """Return, per state-dict key, the weight followed by its queued patches.

        Keys with patches map to a list ``[weight, patch, ...]``; keys without
        map to the 1-tuple ``(weight,)``. When ``filter_prefix`` is given only
        keys starting with it are included.
        """
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None and not k.startswith(filter_prefix):
                continue
            if k in self.patches:
                p[k] = [model_sd[k]] + self.patches[k]
            else:
                p[k] = (model_sd[k],)
        return p

    def model_state_dict(self, filter_prefix=None):
        """Return the model state dict, optionally restricted to a key prefix."""
        sd = self.model.state_dict()
        if filter_prefix is not None:
            # Snapshot the keys first because entries are removed in the loop.
            for k in list(sd.keys()):
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd

    def patch_model(self, device_to=None):
        """Apply object patches and weight patches to the model.

        The original value of every patched attribute/weight is saved (once)
        so ``unpatch_model`` can restore it. Weight patches are computed in
        float32 and cast back to the weight's dtype before being written with
        ``comfy.utils.set_attr``.

        Args:
            device_to: when not None, patched weights are computed on this
                device and the whole model is moved there afterwards.

        Returns:
            the patched model.
        """
        for k in self.object_patches:
            old = getattr(self.model, k)
            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = old
            setattr(self.model, k, self.object_patches[k])

        model_sd = self.model_state_dict()
        for key in self.patches:
            if key not in model_sd:
                print("could not patch. key doesn't exist in model:", key)
                continue

            weight = model_sd[key]

            if key not in self.backup:
                self.backup[key] = weight.to(self.offload_device)

            if device_to is not None:
                temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
            else:
                temp_weight = weight.to(torch.float32, copy=True)
            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
            comfy.utils.set_attr(self.model, key, out_weight)
            del temp_weight

        if device_to is not None:
            self.model.to(device_to)
            self.current_device = device_to

        return self.model

    def calculate_weight(self, patches, weight, key):
        """Merge ``patches`` into ``weight`` in place and return it.

        Each patch is ``(alpha, v, strength_model)``. ``weight`` is first
        scaled by ``strength_model`` (when it is not 1.0), then ``v`` is
        applied according to its layout:

        * ``list``: ``[base_weight, patch, ...]`` as produced by
          ``get_key_patches``; it is resolved recursively into a single tensor
          and then handled like the 1-element case.
        * length 1: ``(w1,)`` is added scaled by ``alpha`` when shapes match.
        * length 4 (lora/locon): ``(mat1, mat2, alpha_or_None, mid_or_None)``.
        * length 8 (lokr): ``(w1, w2, alpha, w1_a, w1_b, w2_a, w2_b, t2)``.
        * anything else (loha): ``(w1a, w1b, alpha, w2a, w2b, t1, t2)``.

        Failures while adding a lora/lokr/loha delta are printed and skipped
        rather than raised. ``key`` is only used in printed messages.
        """
        def to_f32(t):
            # comfy is looked up at call time; the helper casts `t` to float32
            # on the weight's device via the project's cast_to_device.
            return comfy.model_management.cast_to_device(t, weight.device, torch.float32)

        for p in patches:
            alpha = p[0]
            v = p[1]
            strength_model = p[2]

            if strength_model != 1.0:
                weight *= strength_model

            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key), )

            if len(v) == 1:
                w1 = v[0]
                if alpha != 0.0:
                    if w1.shape != weight.shape:
                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                    else:
                        weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
            elif len(v) == 4: #lora/locon
                mat1 = to_f32(v[0])
                mat2 = to_f32(v[1])
                if v[2] is not None:
                    alpha *= v[2] / mat2.shape[0]
                if v[3] is not None:
                    #locon mid weights, hopefully the math is fine because I didn't properly test it
                    mat3 = to_f32(v[3])
                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                try:
                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    print("ERROR", key, e)
            elif len(v) == 8: #lokr
                w1 = v[0]
                w2 = v[1]
                w1_a = v[3]
                w1_b = v[4]
                w2_a = v[5]
                w2_b = v[6]
                t2 = v[7]
                dim = None

                if w1 is None:
                    dim = w1_b.shape[0]
                    w1 = torch.mm(to_f32(w1_a), to_f32(w1_b))
                else:
                    w1 = to_f32(w1)

                if w2 is None:
                    dim = w2_b.shape[0]
                    if t2 is None:
                        w2 = torch.mm(to_f32(w2_a), to_f32(w2_b))
                    else:
                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                          to_f32(t2),
                                          to_f32(w2_b),
                                          to_f32(w2_a))
                else:
                    w2 = to_f32(w2)

                if len(w2.shape) == 4:
                    w1 = w1.unsqueeze(2).unsqueeze(2)
                if v[2] is not None and dim is not None:
                    alpha *= v[2] / dim

                try:
                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    print("ERROR", key, e)
            else: #loha
                w1a = v[0]
                w1b = v[1]
                if v[2] is not None:
                    alpha *= v[2] / w1b.shape[0]
                w2a = v[3]
                w2b = v[4]
                if v[5] is not None: #cp decomposition
                    t1 = v[5]
                    t2 = v[6]
                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      to_f32(t1),
                                      to_f32(w1b),
                                      to_f32(w1a))

                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      to_f32(t2),
                                      to_f32(w2b),
                                      to_f32(w2a))
                else:
                    m1 = torch.mm(to_f32(w1a), to_f32(w1b))
                    m2 = torch.mm(to_f32(w2a), to_f32(w2b))

                try:
                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    print("ERROR", key, e)

        return weight

    def unpatch_model(self, device_to=None):
        """Restore backed-up weights and object attributes on the model.

        Args:
            device_to: when not None, the model is moved to this device after
                the weights are restored.
        """
        for k in list(self.backup.keys()):
            comfy.utils.set_attr(self.model, k, self.backup[k])

        self.backup = {}

        if device_to is not None:
            self.model.to(device_to)
            self.current_device = device_to

        for k in list(self.object_patches_backup.keys()):
            setattr(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup = {}