import torch
import math
import os
import logging
import comfy.utils
import comfy.model_management
import comfy.model_detection
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats

import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit


def broadcast_image_to(tensor, target_batch_size, batched_number):
    """Tile ``tensor`` along dim 0 so it can be paired with a larger batch.

    A tensor whose first dimension is 1 is returned untouched.  Otherwise the
    tensor is cut (or cyclically repeated) to ``target_batch_size //
    batched_number`` entries and, unless that already equals
    ``target_batch_size``, the result is concatenated ``batched_number`` times.
    """
    if tensor.shape[0] == 1:
        return tensor

    per_batch = target_batch_size // batched_number
    tensor = tensor[:per_batch]

    available = tensor.shape[0]
    if available < per_batch:
        # Not enough entries: repeat whole copies, then top up with a prefix.
        whole_copies, remainder = divmod(per_batch, available)
        pieces = [tensor] * whole_copies
        pieces.append(tensor[:remainder])
        tensor = torch.cat(pieces, dim=0)

    if tensor.shape[0] == target_batch_size:
        return tensor
    return torch.cat([tensor] * batched_number, dim=0)

class ControlBase:
    """Common state and chaining logic shared by the control objects below.

    An instance stores the original conditioning hint, a strength multiplier,
    the percent/timestep window in which it is active and an optional link to
    a previous control object.  Most methods recurse into that previous object
    so a chain behaves like a single unit.
    """

    def __init__(self, device=None):
        """Initialise defaults; ``device`` falls back to the comfy torch device."""
        self.cond_hint_original = None
        self.cond_hint = None
        self.strength = 1.0
        self.timestep_percent_range = (0.0, 1.0)
        self.latent_format = None
        self.vae = None
        self.global_average_pooling = False
        self.timestep_range = None
        self.compression_ratio = 8
        self.upscale_algorithm = 'nearest-exact'

        if device is None:
            device = comfy.model_management.get_torch_device()
        self.device = device
        self.previous_controlnet = None

    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
        """Store the hint, strength and percent window; returns ``self``.

        ``vae`` is only kept when this object has a ``latent_format`` set.
        """
        self.cond_hint_original = cond_hint
        self.strength = strength
        self.timestep_percent_range = timestep_percent_range
        if self.latent_format is not None:
            self.vae = vae
        return self

    def pre_run(self, model, percent_to_timestep_function):
        """Convert the percent window to ``timestep_range`` for the whole chain."""
        start_percent, end_percent = self.timestep_percent_range[0], self.timestep_percent_range[1]
        self.timestep_range = (percent_to_timestep_function(start_percent), percent_to_timestep_function(end_percent))
        if self.previous_controlnet is not None:
            self.previous_controlnet.pre_run(model, percent_to_timestep_function)

    def set_previous_controlnet(self, controlnet):
        """Link ``controlnet`` as the previous element of the chain; returns ``self``."""
        self.previous_controlnet = controlnet
        return self

    def cleanup(self):
        """Drop the cached hint and timestep range on this object and its chain."""
        if self.previous_controlnet is not None:
            self.previous_controlnet.cleanup()
        # Rebinding releases the reference to the cached hint tensor.
        self.cond_hint = None
        self.timestep_range = None

    def get_models(self):
        """Return the models reported by the previous objects in the chain."""
        out = []
        if self.previous_controlnet is not None:
            out += self.previous_controlnet.get_models()
        return out

    def copy_to(self, c):
        """Copy the hint and configuration attributes onto ``c``.

        ``previous_controlnet`` and the cached ``cond_hint`` are not copied.
        """
        c.cond_hint_original = self.cond_hint_original
        c.strength = self.strength
        c.timestep_percent_range = self.timestep_percent_range
        c.global_average_pooling = self.global_average_pooling
        c.compression_ratio = self.compression_ratio
        c.upscale_algorithm = self.upscale_algorithm
        c.latent_format = self.latent_format
        c.vae = self.vae

    def inference_memory_requirements(self, dtype):
        """Return the requirement of the previous chain element, or 0 if none."""
        if self.previous_controlnet is not None:
            return self.previous_controlnet.inference_memory_requirements(dtype)
        return 0

    def control_merge(self, control, control_prev, output_dtype):
        """Scale ``control`` by ``strength`` and add it onto ``control_prev``.

        ``control`` maps 'input' / 'middle' / 'output' to lists of tensors or
        ``None``.  Tensors are multiplied by ``self.strength`` IN PLACE and
        converted to ``output_dtype`` when their dtype differs.  When
        ``control_prev`` is given, entries are summed position by position;
        a ``None`` on one side yields the other side's value, and extra
        entries of ``control_prev`` are appended.

        Returns a dict with the keys 'input', 'middle' and 'output'.
        """
        out = {'input': [], 'middle': [], 'output': []}

        for key in control:
            applied_to = set()
            for x in control[key]:
                if x is not None:
                    if self.global_average_pooling:
                        x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

                    if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
                        applied_to.add(x)
                        x *= self.strength

                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)

                out[key].append(x)

        if control_prev is not None:
            for x in ['input', 'middle', 'output']:
                o = out[x]
                for i, prev_val in enumerate(control_prev[x]):
                    if i >= len(o):
                        o.append(prev_val)
                    elif prev_val is not None:
                        if o[i] is None:
                            o[i] = prev_val
                        else:
                            o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
        return out

class ControlNet(ControlBase):
    """Control object that evaluates ``control_model`` on a resized hint.

    The hint is cached in ``self.cond_hint`` and rebuilt whenever its spatial
    size no longer matches ``x_noisy`` times ``compression_ratio``.
    """

    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
        """Store the model and settings.

        ``control_model_wrapped`` is only created when ``control_model`` is not
        ``None``; ``copy()`` assigns it afterwards for copies.
        """
        super().__init__(device)
        self.control_model = control_model
        self.load_device = load_device
        if control_model is not None:
            self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

        self.compression_ratio = compression_ratio
        self.global_average_pooling = global_average_pooling
        self.model_sampling_current = None
        self.manual_cast_dtype = manual_cast_dtype
        self.latent_format = latent_format

    def get_control(self, x_noisy, t, cond, batched_number):
        """Run the control model and merge its output with the previous chain.

        Returns the previous chain's result (possibly ``None``) unchanged when
        ``t[0]`` lies outside ``self.timestep_range``; otherwise returns the
        dict produced by ``control_merge``.  Requires ``pre_run`` to have set
        ``model_sampling_current``.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            # Active only while timestep_range[1] <= t[0] <= timestep_range[0].
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        dtype = self.control_model.dtype
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        output_dtype = x_noisy.dtype
        if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
            # Release the stale hint before building the replacement.
            self.cond_hint = None
            compression_ratio = self.compression_ratio
            if self.vae is not None:
                # The hint is upscaled further so that encoding it brings it
                # back down by the VAE's downscale ratio.
                compression_ratio *= self.vae.downscale_ratio
            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
            if self.vae is not None:
                # Remember what was loaded so it can be reloaded after encoding.
                loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
                self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
                comfy.model_management.load_models_gpu(loaded_models)
            if self.latent_format is not None:
                self.cond_hint = self.latent_format.process_in(self.cond_hint)
            self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        # Only fall back to (and require) 'c_crossattn' when the dedicated
        # controlnet context is absent.
        if 'crossattn_controlnet' in cond:
            context = cond['crossattn_controlnet']
        else:
            context = cond['c_crossattn']
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
        return self.control_merge(control, control_prev, output_dtype)

    def copy(self):
        """Return a new ControlNet sharing this object's model and wrapper."""
        c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        c.control_model = self.control_model
        c.control_model_wrapped = self.control_model_wrapped
        self.copy_to(c)
        return c

    def get_models(self):
        """Return the chain's models followed by this object's wrapped model."""
        out = super().get_models()
        out.append(self.control_model_wrapped)
        return out

    def pre_run(self, model, percent_to_timestep_function):
        """Prepare the chain and remember ``model.model_sampling`` for this run."""
        super().pre_run(model, percent_to_timestep_function)
        self.model_sampling_current = model.model_sampling

    def cleanup(self):
        """Forget the sampling object, then run the base cleanup."""
        self.model_sampling_current = None
        super().cleanup()

class ControlLoraOps:
    """Layer classes whose weight can be patched by a low-rank ``up @ down`` term.

    Both layers start with ``weight``, ``bias``, ``up`` and ``down`` set to
    ``None``; they are expected to be assigned after construction (presumably
    by ``ControlLora.pre_run`` through ``set_attr_param`` -- verify).
    """

    class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        def __init__(self, in_features: int, out_features: int, bias: bool = True,
                    device=None, dtype=None) -> None:
            """Record the feature sizes; ``bias``, ``device`` and ``dtype`` are accepted but unused."""
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.up = None
            self.down = None
            self.bias = None

        def forward(self, input):
            """Apply the linear layer, adding ``up @ down`` to the weight when ``up`` is set."""
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            if self.up is not None:
                # Flatten both factors to 2-D, multiply, and reshape the
                # product back to the weight's shape before adding it.
                return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
            else:
                return torch.nn.functional.linear(input, weight, bias)

    class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            padding_mode='zeros',
            device=None,
            dtype=None
        ):
            """Record the convolution settings; ``bias``, ``device`` and ``dtype`` are accepted but unused."""
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.stride = stride
            self.padding = padding
            self.dilation = dilation
            self.transposed = False
            self.output_padding = 0
            self.groups = groups
            self.padding_mode = padding_mode

            self.weight = None
            self.bias = None
            self.up = None
            self.down = None

        def forward(self, input):
            """Apply the convolution, adding ``up @ down`` to the weight when ``up`` is set."""
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            if self.up is not None:
                # Same low-rank patch as in Linear, reshaped to the conv weight.
                return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
            else:
                return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)


class ControlLora(ControlNet):
    """ControlNet variant whose control model is built in ``pre_run``.

    The control model is created from the sampled model's unet config, filled
    with the sampled model's diffusion weights where the names match, and then
    with the entries of ``control_weights``.  ``cleanup`` drops it again.
    """

    def __init__(self, control_weights, global_average_pooling=False, device=None):
        """Store the weights; ``ControlNet.__init__`` is deliberately bypassed."""
        ControlBase.__init__(self, device)
        self.control_weights = control_weights
        self.global_average_pooling = global_average_pooling
        # Defined up front so cleanup() and get_control() never hit a missing
        # attribute when they run before pre_run().
        self.control_model = None
        self.manual_cast_dtype = None
        self.model_sampling_current = None

    def pre_run(self, model, percent_to_timestep_function):
        """Build ``self.control_model`` for ``model`` and load all weights into it."""
        super().pre_run(model, percent_to_timestep_function)
        controlnet_config = model.model_config.unet_config.copy()
        controlnet_config.pop("out_channels")
        controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
        self.manual_cast_dtype = model.manual_cast_dtype
        dtype = model.get_dtype()
        if self.manual_cast_dtype is None:
            class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
                pass
        else:
            class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
                pass
            dtype = self.manual_cast_dtype

        controlnet_config["operations"] = control_lora_ops
        controlnet_config["dtype"] = dtype
        self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
        self.control_model.to(comfy.model_management.get_torch_device())
        diffusion_model = model.diffusion_model
        sd = diffusion_model.state_dict()

        for k in sd:
            weight = sd[k]
            try:
                comfy.utils.set_attr_param(self.control_model, k, weight)
            except Exception:
                # Best effort: keys that cannot be set on the control model
                # are skipped on purpose.
                pass

        for k in self.control_weights:
            if k not in {"lora_controlnet"}:
                comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))

    def copy(self):
        """Return a new ControlLora sharing the same ``control_weights``."""
        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
        self.copy_to(c)
        return c

    def cleanup(self):
        """Release the control model built by ``pre_run``, then run the parent cleanup."""
        self.control_model = None
        super().cleanup()

    def get_models(self):
        """Return only the chain's models; no wrapped model is reported here."""
        out = ControlBase.get_models(self)
        return out

    def inference_memory_requirements(self, dtype):
        """Return the size of ``control_weights`` in ``dtype`` plus the chain's requirement."""
        return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

def load_controlnet_mmdit(sd):
    """Build a ``ControlNet`` from a diffusers-format MMDiT state dict.

    The state dict is converted with ``convert_diffusers_mmdit``, then the
    original entries of ``sd`` are copied on top before loading (non-strict).
    Missing keys are logged as warnings, unexpected keys at debug level.
    """
    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
    model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
    num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
    for k in sd:
        new_sd[k] = sd[k]

    supported_inference_dtypes = model_config.supported_inference_dtypes

    controlnet_config = model_config.unet_config
    unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
    load_device = comfy.model_management.get_torch_device()
    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        operations = comfy.ops.manual_cast
    else:
        operations = comfy.ops.disable_weight_init

    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
    missing, unexpected = control_model.load_state_dict(new_sd, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("unexpected controlnet keys: {}".format(unexpected))

    latent_format = comfy.latent_formats.SD3()
    latent_format.shift_factor = 0 #SD3 controlnet weirdness
    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    return control


def load_controlnet(ckpt_path, model=None):
    """Load a control object from the checkpoint at ``ckpt_path``.

    The checkpoint layout is detected from its keys:

    * ``"lora_controlnet"`` present -> ``ControlLora``.
    * diffusers controlnet keys -> renamed to the native layout, then loaded.
    * ``"controlnet_blocks.0.weight"`` -> ``load_controlnet_mmdit``.
    * native keys, with or without the ``control_model.`` prefix -> loaded.
    * anything else is handed to ``load_t2i_adapter``; its result (possibly
      ``None``, after logging an error) is returned.

    ``model`` is only used for "difference" checkpoints, whose weights get the
    matching weights of ``model`` added in place.
    """
    controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    if "lora_controlnet" in controlnet_data:
        return ControlLora(controlnet_data)

    controlnet_config = None
    supported_inference_dtypes = None

    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
        diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"

        suffix = [".weight", ".bias"]

        # Map controlnet_down_blocks.N -> zero_convs.N.0 until a key is missing.
        count = 0
        loop = True
        while loop:
            for s in suffix:
                k_in = "controlnet_down_blocks.{}{}".format(count, s)
                k_out = "zero_convs.{}.0{}".format(count, s)
                if k_in not in controlnet_data:
                    loop = False
                    break
                diffusers_keys[k_in] = k_out
            count += 1

        # Map the cond embedding (conv_in, blocks.N, conv_out) onto the even
        # indices of input_hint_block; the first missing block key is treated
        # as conv_out and ends the loop.
        count = 0
        loop = True
        while loop:
            for s in suffix:
                if count == 0:
                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
                else:
                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
                k_out = "input_hint_block.{}{}".format(count * 2, s)
                if k_in not in controlnet_data:
                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
                    loop = False
                diffusers_keys[k_in] = k_out
            count += 1

        new_sd = {}
        for k in diffusers_keys:
            if k in controlnet_data:
                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)

        if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
            controlnet_config["union_controlnet"] = True
            for k in list(controlnet_data.keys()):
                new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
                new_sd[new_k] = controlnet_data.pop(k)

        leftover_keys = controlnet_data.keys()
        if len(leftover_keys) > 0:
            logging.warning("leftover keys: {}".format(leftover_keys))
        controlnet_data = new_sd
    elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
        return load_controlnet_mmdit(controlnet_data)

    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        net = load_t2i_adapter(controlnet_data)
        if net is None:
            logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
        return net

    if controlnet_config is None:
        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = model_config.supported_inference_dtypes
        controlnet_config = model_config.unet_config

    load_device = comfy.model_management.get_torch_device()
    if supported_inference_dtypes is None:
        unet_dtype = comfy.model_management.unet_dtype()
    else:
        unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        controlnet_config["operations"] = comfy.ops.manual_cast
    controlnet_config["dtype"] = unet_dtype
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
    control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                c_m = "control_model."
                for x in controlnet_data:
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            # Stored values are differences: add the base
                            # model weight in place.
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        # Wrapper module so that the "control_model."-prefixed keys resolve.
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("unexpected controlnet keys: {}".format(unexpected))

    global_average_pooling = False
    filename = os.path.splitext(ckpt_path)[0]
    if filename.endswith(("_shuffle", "_shuffle_fp16")): #TODO: smarter way of enabling global_average_pooling
        global_average_pooling = True

    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    return control

class T2IAdapter(ControlBase):
    """Control object backed by an adapter model run on the hint image.

    The adapter output is cached in ``self.control_input`` and recomputed only
    after the hint has been rebuilt; every call hands clones of the cached
    tensors to ``control_merge``.
    """

    def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
        """Store the adapter model and its hint-scaling settings."""
        super().__init__(device)
        self.t2i_model = t2i_model
        self.channels_in = channels_in
        self.control_input = None
        self.compression_ratio = compression_ratio
        self.upscale_algorithm = upscale_algorithm

    def scale_image_to(self, width, height):
        """Round ``width`` and ``height`` up to multiples of the model's ``unshuffle_amount``."""
        unshuffle_amount = self.t2i_model.unshuffle_amount
        width = math.ceil(width / unshuffle_amount) * unshuffle_amount
        height = math.ceil(height / unshuffle_amount) * unshuffle_amount
        return width, height

    def get_control(self, x_noisy, t, cond, batched_number):
        """Return the merged control dict for this step.

        Returns the previous chain's result (possibly ``None``) unchanged when
        ``t[0]`` lies outside ``self.timestep_range``.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            # Active only while timestep_range[1] <= t[0] <= timestep_range[0].
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
            # Drop the stale hint and the adapter output derived from it.
            self.control_input = None
            self.cond_hint = None
            width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
                # Single-channel adapter: average the hint's channels.
                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
        if self.control_input is None:
            # Move the adapter to the device for one pass, then back to CPU.
            self.t2i_model.to(x_noisy.dtype)
            self.t2i_model.to(self.device)
            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
            self.t2i_model.cpu()

        # control_merge scales tensors in place, so hand it clones and keep
        # the cached adapter output untouched.
        control_input = {}
        for k in self.control_input:
            control_input[k] = [None if a is None else a.clone() for a in self.control_input[k]]

        return self.control_merge(control_input, control_prev, x_noisy.dtype)

    def copy(self):
        """Return a new T2IAdapter sharing the same adapter model."""
        c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
        self.copy_to(c)
        return c

def load_t2i_adapter(t2i_data):
    """Build a ``T2IAdapter`` from a state dict, or return ``None``.

    A top-level ``'adapter'`` entry is unwrapped first, and diffusers-style
    key prefixes are rewritten.  The model class is then chosen from the keys
    that are present; ``None`` is returned when no known layout matches.
    Missing keys are logged as warnings, unexpected keys at debug level.
    """
    compression_ratio = 8
    upscale_algorithm = 'nearest-exact'

    if 'adapter' in t2i_data:
        t2i_data = t2i_data['adapter']
    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
        prefix_replace = {}
        for i in range(4):
            # The more specific resnet prefixes are registered before the
            # generic body prefix, as in the original insertion order.
            for j in range(2):
                prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
            prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
        prefix_replace["adapter."] = ""
        t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
    keys = t2i_data.keys()

    if "body.0.in_conv.weight" in keys:
        cin = t2i_data['body.0.in_conv.weight'].shape[1]
        model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
    elif 'conv_in.weight' in keys:
        cin = t2i_data['conv_in.weight'].shape[1]
        channel = t2i_data['conv_in.weight'].shape[0]
        ksize = t2i_data['body.0.block2.weight'].shape[2]
        # Convolutional downsampling is enabled when any down_opt weight exists.
        use_conv = any(k.endswith("down_opt.op.weight") for k in keys)
        xl = cin == 256 or cin == 768
        model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
    elif "backbone.0.0.weight" in keys:
        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 32
        upscale_algorithm = 'bilinear'
    elif "backbone.10.blocks.0.weight" in keys:
        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 1
        upscale_algorithm = 'nearest-exact'
    else:
        return None

    missing, unexpected = model_ad.load_state_dict(t2i_data)
    if len(missing) > 0:
        logging.warning("t2i missing {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("t2i unexpected {}".format(unexpected))

    return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)