# controlnet.py 22.3 KB
# Newer Older
1
2
import torch
import math
3
import os
4
import logging
5
6
7
import comfy.utils
import comfy.model_management
import comfy.model_detection
8
import comfy.model_patcher
9
import comfy.ops
10
11
12

import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
13
import comfy.ldm.cascade.controlnet
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38


def broadcast_image_to(tensor, target_batch_size, batched_number):
    """Trim or tile a batch of images along dim 0 to fit a target batch.

    Args:
        tensor: tensor whose first dimension is the batch.
        target_batch_size: batch size the caller wants to end up with.
        batched_number: number of groups the target batch is divided into.

    Returns:
        ``tensor`` itself when its batch size is 1. Otherwise a tensor holding
        ``target_batch_size // batched_number`` entries (the input truncated,
        or cycled from its start when it is too short), repeated
        ``batched_number`` times unless it already has ``target_batch_size``
        entries.
    """
    # A batch of one is handed back unchanged.
    if tensor.shape[0] == 1:
        return tensor

    per_batch = target_batch_size // batched_number
    tensor = tensor[:per_batch]

    available = tensor.shape[0]
    if available < per_batch:
        # Not enough entries: cycle through them until per_batch is reached.
        whole_copies, remainder = divmod(per_batch, available)
        pieces = [tensor for _ in range(whole_copies)]
        pieces.append(tensor[:remainder])
        tensor = torch.cat(pieces, dim=0)

    if tensor.shape[0] == target_batch_size:
        return tensor
    return torch.cat([tensor] * batched_number, dim=0)

class ControlBase:
    """Common state and helpers for control modules that can be chained.

    An instance stores the conditioning hint, its strength, and the portion of
    the sampling schedule during which it is active. Instances form a chain
    through ``previous_controlnet``; most methods forward to that link so a
    single call on the last element reaches the whole chain.
    """

    def __init__(self, device=None):
        """Set default state.

        Args:
            device: device stored on the instance; when ``None`` it is taken
                from ``comfy.model_management.get_torch_device()``.
        """
        self.cond_hint_original = None
        self.cond_hint = None
        self.strength = 1.0
        self.timestep_percent_range = (0.0, 1.0)
        self.global_average_pooling = False
        self.timestep_range = None
        self.compression_ratio = 8
        self.upscale_algorithm = 'nearest-exact'

        if device is None:
            device = comfy.model_management.get_torch_device()
        self.device = device
        self.previous_controlnet = None

    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
        """Store the hint, strength and percent range; returns ``self``."""
        self.cond_hint_original = cond_hint
        self.strength = strength
        self.timestep_percent_range = timestep_percent_range
        return self

    def pre_run(self, model, percent_to_timestep_function):
        """Convert the percent range to timesteps, here and down the chain.

        Args:
            model: not used by this base implementation; forwarded unchanged
                to the previous link.
            percent_to_timestep_function: callable applied to each end of
                ``timestep_percent_range``.
        """
        start_percent, end_percent = self.timestep_percent_range[0], self.timestep_percent_range[1]
        self.timestep_range = (percent_to_timestep_function(start_percent), percent_to_timestep_function(end_percent))
        if self.previous_controlnet is not None:
            self.previous_controlnet.pre_run(model, percent_to_timestep_function)

    def set_previous_controlnet(self, controlnet):
        """Link ``controlnet`` as the previous element; returns ``self``."""
        self.previous_controlnet = controlnet
        return self

    def cleanup(self):
        """Drop the cached hint and timestep range, here and down the chain."""
        if self.previous_controlnet is not None:
            self.previous_controlnet.cleanup()
        # Rebinding to None releases this object's reference to the cached hint.
        self.cond_hint = None
        self.timestep_range = None

    def get_models(self):
        """Return the models collected from the previous links (none of its own)."""
        out = []
        if self.previous_controlnet is not None:
            out += self.previous_controlnet.get_models()
        return out

    def copy_to(self, c):
        """Copy the user-facing settings of this instance onto ``c``."""
        c.cond_hint_original = self.cond_hint_original
        c.strength = self.strength
        c.timestep_percent_range = self.timestep_percent_range
        c.global_average_pooling = self.global_average_pooling
        c.compression_ratio = self.compression_ratio
        c.upscale_algorithm = self.upscale_algorithm

    def inference_memory_requirements(self, dtype):
        """Return the previous link's requirement, or 0 when there is none."""
        if self.previous_controlnet is not None:
            return self.previous_controlnet.inference_memory_requirements(dtype)
        return 0

    def control_merge(self, control, control_prev, output_dtype):
        """Scale this module's outputs and add those of the previous link.

        Args:
            control: mapping whose keys must be among ``'input'``,
                ``'middle'`` and ``'output'``; each value is a sequence of
                tensors or ``None`` entries.
            control_prev: result from the previous link using the same three
                keys, or ``None``.
            output_dtype: dtype every non-``None`` entry is converted to.

        Returns:
            Dict with the keys ``'input'``, ``'middle'`` and ``'output'``.

        Note:
            Scaling and accumulation happen in place where possible, so
            tensors passed in ``control`` may be modified.
        """
        out = {'input': [], 'middle': [], 'output': []}

        for key in control:
            for x in control[key]:
                if x is not None:
                    if self.global_average_pooling:
                        # Replace every spatial position (dims 2 and 3) with the mean over them.
                        x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

                    x *= self.strength
                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)

                out[key].append(x)

        if control_prev is not None:
            for section in ['input', 'middle', 'output']:
                merged = out[section]
                for i, prev_val in enumerate(control_prev[section]):
                    if i >= len(merged):
                        merged.append(prev_val)
                    elif prev_val is not None:
                        if merged[i] is None:
                            merged[i] = prev_val
                        elif merged[i].shape[0] < prev_val.shape[0]:
                            # The current tensor has the smaller batch, so add
                            # out of place and keep the larger result.
                            merged[i] = prev_val + merged[i]
                        else:
                            merged[i] += prev_val
        return out

class ControlNet(ControlBase):
    """Control module that runs a control model on an image hint."""

    def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
        """Store the control model and wrap it for model management.

        Args:
            control_model: model called in ``get_control``. When ``None`` no
                ``control_model_wrapped`` attribute is created (``copy``
                assigns both attributes afterwards).
            global_average_pooling: stored on the instance and read by
                ``control_merge``.
            device: forwarded to ``ControlBase.__init__``.
            load_device: load device handed to the ``ModelPatcher`` wrapper.
            manual_cast_dtype: when not ``None`` it overrides the control
                model's dtype for the inputs prepared in ``get_control``.
        """
        super().__init__(device)
        self.control_model = control_model
        self.load_device = load_device
        if control_model is not None:
            self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

        self.global_average_pooling = global_average_pooling
        self.model_sampling_current = None
        self.manual_cast_dtype = manual_cast_dtype

    def get_control(self, x_noisy, t, cond, batched_number):
        """Run the control model and merge its output with the previous link's.

        Args:
            x_noisy: input tensor; dims 2 and 3 are multiplied by
                ``compression_ratio`` to get the hint size.
            t: timesteps; only ``t[0]`` is compared against the active range.
            cond: mapping providing ``'c_crossattn'`` (or the override
                ``'crossattn_controlnet'``) and optionally ``'y'``.
            batched_number: forwarded to ``broadcast_image_to``.

        Returns:
            The previous link's result (possibly ``None``) when ``t[0]`` lies
            outside the active range, otherwise the result of
            ``control_merge``.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                # Inactive at this timestep: pass through whatever the chain produced.
                return control_prev

        dtype = self.control_model.dtype
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        output_dtype = x_noisy.dtype
        target_height = x_noisy.shape[2] * self.compression_ratio
        target_width = x_noisy.shape[3] * self.compression_ratio
        if self.cond_hint is None or target_height != self.cond_hint.shape[2] or target_width != self.cond_hint.shape[3]:
            # Drop the reference to the stale hint before building the replacement.
            self.cond_hint = None
            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, target_width, target_height, self.upscale_algorithm, "center").to(dtype).to(self.device)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        # Look up the fallback only when the override is absent, so a cond
        # carrying just 'crossattn_controlnet' does not raise KeyError.
        if 'crossattn_controlnet' in cond:
            context = cond['crossattn_controlnet']
        else:
            context = cond['c_crossattn']
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
        return self.control_merge(control, control_prev, output_dtype)

    def copy(self):
        """Return a new ControlNet sharing this one's model and wrapper."""
        c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        c.control_model = self.control_model
        c.control_model_wrapped = self.control_model_wrapped
        self.copy_to(c)
        return c

    def get_models(self):
        """Return the chain's models followed by this module's wrapped model."""
        out = super().get_models()
        out.append(self.control_model_wrapped)
        return out

    def pre_run(self, model, percent_to_timestep_function):
        """Prepare the timestep range and remember ``model.model_sampling``."""
        super().pre_run(model, percent_to_timestep_function)
        self.model_sampling_current = model.model_sampling

    def cleanup(self):
        """Forget the sampling object, then run the base cleanup."""
        self.model_sampling_current = None
        super().cleanup()

193
class ControlLoraOps:
    """Layer classes whose weight can be offset by a low-rank ``up @ down`` product.

    ``weight``, ``bias``, ``up`` and ``down`` all start as ``None``; they are
    presumably assigned from outside after construction (e.g. through
    ``comfy.utils.set_attr_param``) - verify against the loader.
    """

    class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        """Linear layer that adds ``up @ down`` to its weight when ``up`` is set."""

        def __init__(self, in_features: int, out_features: int, bias: bool = True,
                    device=None, dtype=None) -> None:
            # bias, device and dtype are accepted for signature compatibility
            # but are not used: no parameters are allocated here.
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.up = None
            self.down = None
            self.bias = None

        def forward(self, input):
            """Apply the linear map, using the offset weight when ``up`` is set."""
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            if self.up is not None:
                # Multiply the flattened factors, then restore the weight's shape and the input's dtype.
                delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))
                weight = weight + delta.reshape(self.weight.shape).type(input.dtype)
            return torch.nn.functional.linear(input, weight, bias)

    class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        """2D convolution that adds ``up @ down`` to its weight when ``up`` is set."""

        def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            padding_mode='zeros',
            device=None,
            dtype=None
        ):
            # bias, device and dtype are accepted but unused: no parameters
            # are allocated here.
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.stride = stride
            self.padding = padding
            self.dilation = dilation
            self.transposed = False
            self.output_padding = 0
            self.groups = groups
            self.padding_mode = padding_mode

            self.weight = None
            self.bias = None
            self.up = None
            self.down = None

        def forward(self, input):
            """Apply the convolution, using the offset weight when ``up`` is set."""
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            if self.up is not None:
                # Multiply the flattened factors, then restore the weight's shape and the input's dtype.
                delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))
                weight = weight + delta.reshape(self.weight.shape).type(input.dtype)
            return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
252

253
254
255
256
257
258
259
260
261
262
263
264

class ControlLora(ControlNet):
    """ControlNet variant whose control model is built in ``pre_run``.

    The control model is created from the sampled model's UNet config, filled
    with the sampled model's diffusion weights where the keys fit, and then
    with the stored ``control_weights``.
    """

    def __init__(self, control_weights, global_average_pooling=False, device=None):
        """Store the control weights.

        ``ControlNet.__init__`` is bypassed on purpose (only
        ``ControlBase.__init__`` runs), so the attributes it would have set
        and that other methods read are initialised here.
        """
        ControlBase.__init__(self, device)
        self.control_weights = control_weights
        self.global_average_pooling = global_average_pooling
        # Defined up front so cleanup() is safe before the first pre_run().
        self.control_model = None
        self.model_sampling_current = None
        self.manual_cast_dtype = None

    def pre_run(self, model, percent_to_timestep_function):
        """Build the control model for ``model`` and load weights into it."""
        super().pre_run(model, percent_to_timestep_function)
        controlnet_config = model.model_config.unet_config.copy()
        controlnet_config.pop("out_channels")
        controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
        self.manual_cast_dtype = model.manual_cast_dtype
        dtype = model.get_dtype()
        if self.manual_cast_dtype is None:
            class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
                pass
        else:
            class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
                pass
            dtype = self.manual_cast_dtype

        controlnet_config["operations"] = control_lora_ops
        controlnet_config["dtype"] = dtype
        torch_device = comfy.model_management.get_torch_device()
        self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
        self.control_model.to(torch_device)
        sd = model.diffusion_model.state_dict()

        for k in sd:
            # Best effort: keys that cannot be set on the control model are skipped.
            try:
                comfy.utils.set_attr_param(self.control_model, k, sd[k])
            except Exception:
                pass

        for k in self.control_weights:
            # "lora_controlnet" is the marker key, not a model parameter.
            if k != "lora_controlnet":
                comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(torch_device))

    def copy(self):
        """Return a new ControlLora sharing the same control weights."""
        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
        self.copy_to(c)
        return c

    def cleanup(self):
        """Release the control model built by ``pre_run``, then clean up the chain."""
        # Plain rebinding (instead of `del`) also works when pre_run never ran.
        self.control_model = None
        super().cleanup()

    def get_models(self):
        """Return only the chain's models; this class contributes none."""
        out = ControlBase.get_models(self)
        return out

    def inference_memory_requirements(self, dtype):
        """Return the control weights' size in ``dtype`` plus the chain's requirement."""
        return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

def _convert_diffusers_controlnet(controlnet_data):
    """Rename a diffusers-format controlnet state dict to the native key names.

    Recognised keys are popped from ``controlnet_data`` (which is modified);
    whatever is left behind is reported with a warning.

    Returns:
        ``(controlnet_config, new_sd)``: the detected UNet config and the
        renamed state dict.
    """
    controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
    diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
    diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
    diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"

    suffix = (".weight", ".bias")

    # controlnet_down_blocks.N -> zero_convs.N.0, until the first missing key.
    count = 0
    loop = True
    while loop:
        for s in suffix:
            k_in = "controlnet_down_blocks.{}{}".format(count, s)
            k_out = "zero_convs.{}.0{}".format(count, s)
            if k_in not in controlnet_data:
                loop = False
                break
            diffusers_keys[k_in] = k_out
        count += 1

    # conv_in, blocks.0, blocks.1, ... -> input_hint_block.0, .2, .4, ...
    # The first index with no matching key is taken by conv_out instead.
    count = 0
    loop = True
    while loop:
        for s in suffix:
            if count == 0:
                k_in = "controlnet_cond_embedding.conv_in{}".format(s)
            else:
                k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
            k_out = "input_hint_block.{}{}".format(count * 2, s)
            if k_in not in controlnet_data:
                k_in = "controlnet_cond_embedding.conv_out{}".format(s)
                loop = False
            diffusers_keys[k_in] = k_out
        count += 1

    new_sd = {}
    for k in diffusers_keys:
        if k in controlnet_data:
            new_sd[diffusers_keys[k]] = controlnet_data.pop(k)

    leftover_keys = controlnet_data.keys()
    if len(leftover_keys) > 0:
        logging.warning("leftover keys: {}".format(leftover_keys))
    return controlnet_config, new_sd


def load_controlnet(ckpt_path, model=None):
    """Load a control module from a checkpoint file.

    Args:
        ckpt_path: path of the checkpoint, read with
            ``comfy.utils.load_torch_file(..., safe_load=True)``.
        model: optional model. Used only for checkpoints containing a
            ``'difference'`` key, whose ``control_model.*`` tensors get the
            matching ``diffusion_model.*`` weights of ``model`` added.

    Returns:
        A ``ControlLora`` when the file has a ``"lora_controlnet"`` key, a
        ``ControlNet`` for controlnet checkpoints, otherwise whatever
        ``load_t2i_adapter`` returns (``None`` when nothing is recognised).
    """
    controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    if "lora_controlnet" in controlnet_data:
        return ControlLora(controlnet_data)

    controlnet_config = None
    supported_inference_dtypes = None

    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
        controlnet_config, controlnet_data = _convert_diffusers_controlnet(controlnet_data)

    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        # Not a controlnet: try to interpret the data as a T2I adapter.
        net = load_t2i_adapter(controlnet_data)
        if net is None:
            logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
        return net

    if controlnet_config is None:
        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = model_config.supported_inference_dtypes
        controlnet_config = model_config.unet_config

    load_device = comfy.model_management.get_torch_device()
    if supported_inference_dtypes is None:
        unet_dtype = comfy.model_management.unet_dtype()
    else:
        unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        controlnet_config["operations"] = comfy.ops.manual_cast
    controlnet_config["dtype"] = unet_dtype
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
    control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                c_m = "control_model."
                for x in controlnet_data:
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            # In-place add: the checkpoint tensor becomes diff + base weight.
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        # The keys carry a "control_model." prefix, so load through a wrapper
        # module that exposes the model under that attribute name.
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("unexpected controlnet keys: {}".format(unexpected))

    filename = os.path.splitext(ckpt_path)[0]
    #TODO: smarter way of enabling global_average_pooling
    global_average_pooling = filename.endswith(("_shuffle", "_shuffle_fp16"))

    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    return control

class T2IAdapter(ControlBase):
    """Control module backed by an adapter model whose output is cached.

    The adapter is run on the hint image once and the result is kept in
    ``control_input`` until the hint has to be rebuilt.
    """

    def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
        """Store the adapter model and its hint settings.

        Args:
            t2i_model: adapter model; must expose ``unshuffle_amount``.
            channels_in: when 1, a multi-channel hint is averaged over dim 1.
            compression_ratio: factor between the input's dims 2/3 and the
                hint's.
            upscale_algorithm: passed to ``comfy.utils.common_upscale``.
            device: forwarded to ``ControlBase.__init__``.
        """
        super().__init__(device)
        self.t2i_model = t2i_model
        self.channels_in = channels_in
        self.control_input = None
        self.compression_ratio = compression_ratio
        self.upscale_algorithm = upscale_algorithm

    def scale_image_to(self, width, height):
        """Round both sizes up to a multiple of the model's ``unshuffle_amount``."""
        unshuffle_amount = self.t2i_model.unshuffle_amount
        width = math.ceil(width / unshuffle_amount) * unshuffle_amount
        height = math.ceil(height / unshuffle_amount) * unshuffle_amount
        return width, height

    def get_control(self, x_noisy, t, cond, batched_number):
        """Return the adapter's control merged with the previous link's.

        Args:
            x_noisy: input tensor; its shape and dtype drive the hint size
                and the dtype the adapter runs in.
            t: timesteps; only ``t[0]`` is compared against the active range.
            cond: forwarded to the previous link only.
            batched_number: forwarded to ``broadcast_image_to``.

        Returns:
            The previous link's result (possibly ``None``) when ``t[0]`` lies
            outside the active range, otherwise the result of
            ``control_merge``.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                # Inactive at this timestep: pass through whatever the chain produced.
                return control_prev

        # NOTE(review): the hint is built at the size rounded by scale_image_to
        # but compared against the unrounded size, so when the two differ this
        # branch presumably runs on every call - confirm whether that is intended.
        if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
            # Drop the stale hint and the adapter output derived from it.
            self.control_input = None
            self.cond_hint = None
            width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
        if self.control_input is None:
            # Move the adapter to the device only for this one call, then back to CPU.
            self.t2i_model.to(x_noisy.dtype)
            self.t2i_model.to(self.device)
            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
            self.t2i_model.cpu()

        # control_merge scales tensors in place, so give it clones and keep the cache intact.
        control_input = {
            k: [None if a is None else a.clone() for a in self.control_input[k]]
            for k in self.control_input
        }

        return self.control_merge(control_input, control_prev, x_noisy.dtype)

    def copy(self):
        """Return a new T2IAdapter sharing the same adapter model."""
        c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
        self.copy_to(c)
        return c

def load_t2i_adapter(t2i_data):
    """Build a ``T2IAdapter`` from an adapter state dict.

    Args:
        t2i_data: state dict, optionally nested under an ``'adapter'`` key.
            Diffusers-style key names are renamed before detection.

    Returns:
        A ``T2IAdapter`` wrapping the detected model, or ``None`` when no
        known key layout is found.
    """
    compression_ratio = 8
    upscale_algorithm = 'nearest-exact'

    if 'adapter' in t2i_data:
        t2i_data = t2i_data['adapter']
    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
        # Insertion order is kept: the resnet prefixes are registered before
        # the shorter "adapter.body.N." prefix they start with.
        prefix_replace = {}
        for i in range(4):
            for j in range(2):
                prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
            prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
        prefix_replace["adapter."] = ""
        t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
    keys = t2i_data.keys()

    if "body.0.in_conv.weight" in keys:
        cin = t2i_data['body.0.in_conv.weight'].shape[1]
        model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
    elif 'conv_in.weight' in keys:
        cin = t2i_data['conv_in.weight'].shape[1]
        channel = t2i_data['conv_in.weight'].shape[0]
        ksize = t2i_data['body.0.block2.weight'].shape[2]
        use_conv = any(k.endswith("down_opt.op.weight") for k in keys)
        xl = cin in (256, 768)
        model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
    elif "backbone.0.0.weight" in keys:
        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 32
        upscale_algorithm = 'bilinear'
    elif "backbone.10.blocks.0.weight" in keys:
        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 1
        upscale_algorithm = 'nearest-exact'
    else:
        return None

    # NOTE(review): load_state_dict is called with its default strictness; if
    # that default is strict, a key mismatch raises before the logging below
    # can run - confirm whether strict=False was intended.
    missing, unexpected = model_ad.load_state_dict(t2i_data)
    if len(missing) > 0:
        logging.warning("t2i missing {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("t2i unexpected {}".format(unexpected))

    return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)