# controlnet.py 25.8 KB
# Newer Older  (page header; apparently left over from the web blame view this text was copied from)
1
2
import torch
import math
3
import os
4
import logging
5
6
7
import comfy.utils
import comfy.model_management
import comfy.model_detection
8
import comfy.model_patcher
9
import comfy.ops
10
import comfy.latent_formats
11
12
13

import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
14
import comfy.ldm.cascade.controlnet
15
import comfy.cldm.mmdit
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40


def broadcast_image_to(tensor, target_batch_size, batched_number):
    """Repeat ``tensor`` along dim 0 so its batch size fits the target batch.

    A tensor whose first dimension is 1 is returned unchanged. Otherwise the
    tensor is cut/tiled to ``target_batch_size // batched_number`` entries and
    then, unless it already has ``target_batch_size`` entries, concatenated
    ``batched_number`` times.
    """
    if tensor.shape[0] == 1:
        return tensor

    chunk = target_batch_size // batched_number
    tensor = tensor[:chunk]
    available = tensor.shape[0]

    if chunk > available:
        # Tile whole copies, then top up with the leading remainder.
        whole, remainder = divmod(chunk, available)
        tensor = torch.cat([tensor] * whole + [tensor[:remainder]], dim=0)

    if tensor.shape[0] == target_batch_size:
        return tensor
    return torch.cat([tensor] * batched_number, dim=0)

class ControlBase:
    """Shared state and output-merging logic for chainable control objects.

    Stores the conditioning hint, strength, timestep window and an optional
    link to a previous control object. The subclasses in this file add
    ``get_control`` and ``copy``.
    """

    def __init__(self, device=None):
        """Set defaults; ``device`` falls back to ``comfy.model_management.get_torch_device()``."""
        self.cond_hint_original = None
        self.cond_hint = None
        self.strength = 1.0
        self.timestep_percent_range = (0.0, 1.0)
        self.latent_format = None
        self.vae = None
        self.global_average_pooling = False
        self.timestep_range = None
        self.compression_ratio = 8
        self.upscale_algorithm = 'nearest-exact'
        self.extra_args = {}

        if device is None:
            device = comfy.model_management.get_torch_device()
        self.device = device
        self.previous_controlnet = None

    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
        """Store the hint, strength and percent range; returns ``self``.

        ``vae`` is only kept when a ``latent_format`` has been set on this object.
        """
        self.cond_hint_original = cond_hint
        self.strength = strength
        self.timestep_percent_range = timestep_percent_range
        if self.latent_format is not None:
            self.vae = vae
        return self

    def pre_run(self, model, percent_to_timestep_function):
        """Convert the percent range to ``timestep_range`` here and on every previous control."""
        start_percent, end_percent = self.timestep_percent_range[0], self.timestep_percent_range[1]
        self.timestep_range = (percent_to_timestep_function(start_percent), percent_to_timestep_function(end_percent))
        if self.previous_controlnet is not None:
            self.previous_controlnet.pre_run(model, percent_to_timestep_function)

    def set_previous_controlnet(self, controlnet):
        """Link ``controlnet`` as the previous element of the chain; returns ``self``."""
        self.previous_controlnet = controlnet
        return self

    def cleanup(self):
        """Drop the cached hint and timestep range, after cleaning up the previous control."""
        if self.previous_controlnet is not None:
            self.previous_controlnet.cleanup()
        self.cond_hint = None
        self.timestep_range = None

    def get_models(self):
        """Return the models reported by the previous control (empty list when there is none)."""
        out = []
        if self.previous_controlnet is not None:
            out += self.previous_controlnet.get_models()
        return out

    def copy_to(self, c):
        """Copy the user-facing settings of this object onto ``c``."""
        c.cond_hint_original = self.cond_hint_original
        c.strength = self.strength
        c.timestep_percent_range = self.timestep_percent_range
        c.global_average_pooling = self.global_average_pooling
        c.compression_ratio = self.compression_ratio
        c.upscale_algorithm = self.upscale_algorithm
        c.latent_format = self.latent_format
        # Shallow copy so set_extra_arg on the copy does not affect this object.
        c.extra_args = self.extra_args.copy()
        c.vae = self.vae

    def inference_memory_requirements(self, dtype):
        """Return the previous control's memory requirement, or 0 when there is none."""
        if self.previous_controlnet is not None:
            return self.previous_controlnet.inference_memory_requirements(dtype)
        return 0

    def control_merge(self, control, control_prev, output_dtype):
        """Scale and cast this control's outputs, then add the previous control's outputs.

        ``control`` and ``control_prev`` map ``'input'``/``'middle'``/``'output'``
        to lists whose entries are tensors or ``None``. Returns a dict with the
        same three keys.
        """
        out = {'input':[], 'middle':[], 'output': []}

        for key in control:
            applied_to = set()
            for x in control[key]:
                if x is not None:
                    if self.global_average_pooling:
                        # Replace each spatial map by its mean, kept at the original size.
                        x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

                    if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
                        applied_to.add(x)
                        x *= self.strength

                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)

                out[key].append(x)

        if control_prev is not None:
            for x in ['input', 'middle', 'output']:
                o = out[x]
                for i, prev_val in enumerate(control_prev[x]):
                    if i >= len(o):
                        o.append(prev_val)
                    elif prev_val is not None:
                        if o[i] is None:
                            o[i] = prev_val
                        else:
                            o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
        return out

    def set_extra_arg(self, argument, value=None):
        """Record an extra keyword argument in ``extra_args``."""
        self.extra_args[argument] = value


144
class ControlNet(ControlBase):
    """Control object that runs ``control_model`` on the noisy input and a hint."""

    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
        """Store the model and settings.

        ``control_model_wrapped`` is only created when ``control_model`` is given;
        ``copy`` assigns it afterwards on copies built with ``control_model=None``.
        """
        super().__init__(device)
        self.control_model = control_model
        self.load_device = load_device
        if control_model is not None:
            self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

        self.compression_ratio = compression_ratio
        self.global_average_pooling = global_average_pooling
        self.model_sampling_current = None
        self.manual_cast_dtype = manual_cast_dtype
        self.latent_format = latent_format

    def get_control(self, x_noisy, t, cond, batched_number):
        """Return the merged control outputs for this step.

        The previous control in the chain is evaluated first. When ``t[0]`` is
        outside ``timestep_range`` only the previous result (possibly ``None``)
        is returned.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            # Skip when t[0] is above the first bound or below the second one.
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        dtype = self.control_model.dtype
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        output_dtype = x_noisy.dtype
        if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
            # Cached hint is missing or has the wrong spatial size: rebuild it.
            self.cond_hint = None
            compression_ratio = self.compression_ratio
            if self.vae is not None:
                compression_ratio *= self.vae.downscale_ratio
            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
            if self.vae is not None:
                # Remember the currently used models and load them again after
                # the VAE encode.
                loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
                self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
                comfy.model_management.load_models_gpu(loaded_models)
            if self.latent_format is not None:
                self.cond_hint = self.latent_format.process_in(self.cond_hint)
            self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        # Look up the fallback lazily: dict.get(key, cond['c_crossattn']) would
        # evaluate cond['c_crossattn'] (and may raise KeyError) even when
        # 'crossattn_controlnet' is present.
        if 'crossattn_controlnet' in cond:
            context = cond['crossattn_controlnet']
        else:
            context = cond['c_crossattn']
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
        return self.control_merge(control, control_prev, output_dtype)

    def copy(self):
        """Return a new ControlNet sharing the same model objects and settings."""
        c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        c.control_model = self.control_model
        c.control_model_wrapped = self.control_model_wrapped
        self.copy_to(c)
        return c

    def get_models(self):
        """Return the chain's models plus this object's wrapped control model."""
        out = super().get_models()
        out.append(self.control_model_wrapped)
        return out

    def pre_run(self, model, percent_to_timestep_function):
        """Run the base setup and remember ``model.model_sampling`` for this run."""
        super().pre_run(model, percent_to_timestep_function)
        self.model_sampling_current = model.model_sampling

    def cleanup(self):
        """Forget the sampling object, then perform the base cleanup."""
        self.model_sampling_current = None
        super().cleanup()

223
class ControlLoraOps:
    """Layer classes whose weight can be extended by an ``up @ down`` product.

    ``weight``, ``bias``, ``up`` and ``down`` all start as ``None`` and are
    expected to be assigned from outside after construction.
    """

    class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        """Linear layer that adds ``up @ down`` to its weight when ``up`` is set."""

        def __init__(self, in_features: int, out_features: int, bias: bool = True,
                    device=None, dtype=None) -> None:
            # bias/device/dtype are accepted for signature compatibility but
            # not used here; no parameters are allocated.
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.up = None
            self.down = None
            self.bias = None

        def forward(self, input):
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            if self.up is None:
                return torch.nn.functional.linear(input, weight, bias)
            # Low-rank delta: flatten both factors to 2D, multiply, then give
            # the product the shape of the stored weight and the input's dtype.
            delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))
            delta = delta.reshape(self.weight.shape).type(input.dtype)
            return torch.nn.functional.linear(input, weight + delta, bias)

    class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        """2D convolution that adds ``up @ down`` to its weight when ``up` is set.""".replace("`up` is", "``up`` is") if False else None

        def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            padding_mode='zeros',
            device=None,
            dtype=None
        ):
            # bias/device/dtype are accepted for signature compatibility but
            # not used here; no parameters are allocated.
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.stride = stride
            self.padding = padding
            self.dilation = dilation
            self.transposed = False
            self.output_padding = 0
            self.groups = groups
            self.padding_mode = padding_mode

            self.weight = None
            self.bias = None
            self.up = None
            self.down = None

        def forward(self, input):
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            if self.up is None:
                return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
            # Same low-rank delta as in Linear, reshaped to the conv kernel shape.
            delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))
            delta = delta.reshape(self.weight.shape).type(input.dtype)
            return torch.nn.functional.conv2d(input, weight + delta, bias, self.stride, self.padding, self.dilation, self.groups)
282

283
284
285
286
287
288
289
290
291
292
293
294

class ControlLora(ControlNet):
    """ControlNet whose model is built in ``pre_run`` from the base model plus ``control_weights``."""

    def __init__(self, control_weights, global_average_pooling=False, device=None):
        """Keep the weights; only ``ControlBase.__init__`` runs, so no model exists yet."""
        ControlBase.__init__(self, device)
        self.control_weights = control_weights
        self.global_average_pooling = global_average_pooling

    def pre_run(self, model, percent_to_timestep_function):
        """Build ``control_model`` from ``model``'s unet config and load weights into it."""
        super().pre_run(model, percent_to_timestep_function)
        controlnet_config = model.model_config.unet_config.copy()
        controlnet_config.pop("out_channels")
        controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
        self.manual_cast_dtype = model.manual_cast_dtype
        dtype = model.get_dtype()
        if self.manual_cast_dtype is None:
            class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
                pass
        else:
            class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
                pass
            dtype = self.manual_cast_dtype

        controlnet_config["operations"] = control_lora_ops
        controlnet_config["dtype"] = dtype
        self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
        self.control_model.to(comfy.model_management.get_torch_device())
        diffusion_model = model.diffusion_model
        sd = diffusion_model.state_dict()

        for k, weight in sd.items():
            try:
                comfy.utils.set_attr_param(self.control_model, k, weight)
            except Exception:
                # Best effort: presumably keys the control model does not have
                # fail here and are meant to be skipped -- TODO confirm.
                pass

        for k in self.control_weights:
            if k != "lora_controlnet":
                comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))

    def copy(self):
        """Return a new ControlLora sharing ``control_weights`` with copied settings."""
        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
        self.copy_to(c)
        return c

    def cleanup(self):
        """Release the model built by ``pre_run`` and run the inherited cleanup."""
        # Plain assignment (no ``del``) so cleanup also works when pre_run never
        # created the attribute.
        self.control_model = None
        super().cleanup()

    def get_models(self):
        """Return only the chain's models; uses ControlBase, not ControlNet.get_models."""
        out = ControlBase.get_models(self)
        return out

    def inference_memory_requirements(self, dtype):
        """Parameter count of ``control_weights`` times the dtype size, plus the chain's requirement."""
        return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
def load_controlnet_mmdit(sd):
    """Build a ControlNet wrapping ``comfy.cldm.mmdit.ControlNet`` from state dict ``sd``.

    Called from ``load_controlnet`` for checkpoints it identifies as SD3
    diffusers format. Returns the ControlNet object.
    """
    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
    model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
    num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
    # Every original key is also added next to the converted ones.
    for k in sd:
        new_sd[k] = sd[k]

    supported_inference_dtypes = model_config.supported_inference_dtypes

    controlnet_config = model_config.unet_config
    unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
    load_device = comfy.model_management.get_torch_device()
    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        operations = comfy.ops.manual_cast
    else:
        operations = comfy.ops.disable_weight_init

    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
    missing, unexpected = control_model.load_state_dict(new_sd, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("unexpected controlnet keys: {}".format(unexpected))

    latent_format = comfy.latent_formats.SD3()
    latent_format.shift_factor = 0 #SD3 controlnet weirdness
    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    return control


374
375
376
377
378
379
def load_controlnet(ckpt_path, model=None):
    """Load a control model from ``ckpt_path`` and return the matching control object.

    Depending on the keys found in the checkpoint this returns a ControlLora,
    the result of ``load_controlnet_mmdit``, a ControlNet, or whatever
    ``load_t2i_adapter`` returns (``None`` when nothing is recognised).
    ``model`` is only used for "difference" checkpoints, whose weights are
    added to the matching weights of ``model``.
    """
    controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    if "lora_controlnet" in controlnet_data:
        return ControlLora(controlnet_data)

    controlnet_config = None
    supported_inference_dtypes = None

    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
        diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"

        suffix = [".weight", ".bias"]

        # Map controlnet_down_blocks.N -> zero_convs.N.0 until a key is missing.
        count = 0
        loop = True
        while loop:
            for s in suffix:
                k_in = "controlnet_down_blocks.{}{}".format(count, s)
                k_out = "zero_convs.{}.0{}".format(count, s)
                if k_in not in controlnet_data:
                    loop = False
                    break
                diffusers_keys[k_in] = k_out
            count += 1

        # Map the conditioning embedding: conv_in, blocks.N and finally conv_out
        # go to input_hint_block.{0, 2, 4, ...}.
        count = 0
        loop = True
        while loop:
            for s in suffix:
                if count == 0:
                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
                else:
                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
                k_out = "input_hint_block.{}{}".format(count * 2, s)
                if k_in not in controlnet_data:
                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
                    loop = False
                diffusers_keys[k_in] = k_out
            count += 1

        new_sd = {}
        for k in diffusers_keys:
            if k in controlnet_data:
                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)

        if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
            controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
            for k in list(controlnet_data.keys()):
                new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
                new_sd[new_k] = controlnet_data.pop(k)

        leftover_keys = controlnet_data.keys()
        if len(leftover_keys) > 0:
            logging.warning("leftover keys: {}".format(leftover_keys))
        controlnet_data = new_sd
    elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
        return load_controlnet_mmdit(controlnet_data)

    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        # Neither controlnet layout matched: try the T2I adapter formats.
        net = load_t2i_adapter(controlnet_data)
        if net is None:
            logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
        return net

    if controlnet_config is None:
        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = model_config.supported_inference_dtypes
        controlnet_config = model_config.unet_config

    load_device = comfy.model_management.get_torch_device()
    if supported_inference_dtypes is None:
        unet_dtype = comfy.model_management.unet_dtype()
    else:
        unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        controlnet_config["operations"] = comfy.ops.manual_cast
    controlnet_config["dtype"] = unet_dtype
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
    control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                c_m = "control_model."
                for x in controlnet_data:
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            # In-place add of the base model weight to the stored difference.
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        # Wrapper module so the "control_model."-prefixed keys line up with
        # control_model's own parameters.
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("unexpected controlnet keys: {}".format(unexpected))

    global_average_pooling = False
    filename = os.path.splitext(ckpt_path)[0]
    if filename.endswith(("_shuffle", "_shuffle_fp16")): #TODO: smarter way of enabling global_average_pooling
        global_average_pooling = True

    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    return control

class T2IAdapter(ControlBase):
    """Control object that runs ``t2i_model`` on the hint alone and caches its output."""

    def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
        """Store the adapter model and the settings used to resize the hint."""
        super().__init__(device)
        self.t2i_model = t2i_model
        self.channels_in = channels_in
        self.control_input = None
        self.compression_ratio = compression_ratio
        self.upscale_algorithm = upscale_algorithm

    def scale_image_to(self, width, height):
        """Round ``width`` and ``height`` up to multiples of the model's ``unshuffle_amount``."""
        unshuffle_amount = self.t2i_model.unshuffle_amount
        width = math.ceil(width / unshuffle_amount) * unshuffle_amount
        height = math.ceil(height / unshuffle_amount) * unshuffle_amount
        return width, height

    def get_control(self, x_noisy, t, cond, batched_number):
        """Return the merged control outputs for this step.

        The previous control in the chain is evaluated first. When ``t[0]`` is
        outside ``timestep_range`` only the previous result (possibly ``None``)
        is returned.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            # Skip when t[0] is above the first bound or below the second one.
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
            # Hint missing or wrong size: drop it together with the cached model
            # output and rebuild.
            self.control_input = None
            self.cond_hint = None
            width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
                # Single-channel model: average the hint's channels.
                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
        if self.control_input is None:
            # Run the adapter once, then move it back to the CPU.
            self.t2i_model.to(x_noisy.dtype)
            self.t2i_model.to(self.device)
            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
            self.t2i_model.cpu()

        # control_merge scales tensors in place, so hand it clones and keep the
        # cached outputs untouched.
        control_input = {}
        for k in self.control_input:
            control_input[k] = [None if a is None else a.clone() for a in self.control_input[k]]

        return self.control_merge(control_input, control_prev, x_noisy.dtype)

    def copy(self):
        """Return a new T2IAdapter sharing the same model with copied settings."""
        c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
        self.copy_to(c)
        return c

def load_t2i_adapter(t2i_data):
    """Build a T2IAdapter from state dict ``t2i_data``.

    The model class is chosen from the keys present in the state dict.
    Returns ``None`` when no known layout is recognised.
    """
    compression_ratio = 8
    upscale_algorithm = 'nearest-exact'

    if 'adapter' in t2i_data:
        t2i_data = t2i_data['adapter']
    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
        # Rename diffusers prefixes: body.i.resnets.j -> body.(2*i + j), and the
        # remaining body.i keys -> body.(2*i).
        prefix_replace = {}
        for i in range(4):
            for j in range(2):
                prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
            prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
        prefix_replace["adapter."] = ""
        t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
    keys = t2i_data.keys()

    if "body.0.in_conv.weight" in keys:
        cin = t2i_data['body.0.in_conv.weight'].shape[1]
        model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
    elif 'conv_in.weight' in keys:
        cin = t2i_data['conv_in.weight'].shape[1]
        channel = t2i_data['conv_in.weight'].shape[0]
        ksize = t2i_data['body.0.block2.weight'].shape[2]
        use_conv = False
        down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
        if len(down_opts) > 0:
            use_conv = True
        xl = False
        if cin == 256 or cin == 768:
            xl = True
        model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
    elif "backbone.0.0.weight" in keys:
        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 32
        upscale_algorithm = 'bilinear'
    elif "backbone.10.blocks.0.weight" in keys:
        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 1
        upscale_algorithm = 'nearest-exact'
    else:
        return None

    missing, unexpected = model_ad.load_state_dict(t2i_data)
    if len(missing) > 0:
        logging.warning("t2i missing {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("t2i unexpected {}".format(unexpected))

    return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)