#taken from: https://github.com/lllyasviel/ControlNet
#and modified

import torch
import torch as th
import torch.nn as nn

from ..ldm.modules.diffusionmodules.util import (
    zero_module,
    timestep_embedding,
)

from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
from collections import OrderedDict
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention

class OptimizedAttention(nn.Module):
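    # Multi-head self-attention with a fused QKV projection, dispatched through
    # comfy's optimized_attention kernel; expects x of shape (batch, tokens, c).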
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead
        self.c = c

        self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        x = self.in_proj(x)
        q, k, v = x.split(self.c, dim=2)
        out = optimized_attention(q, k, v, self.heads)
        return self.out_proj(out)

class QuickGELU(nn.Module):
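    # Sigmoid-based GELU approximation (x * sigmoid(1.702 * x)), as used in CLIP.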
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

class ResBlockUnionControlnet(nn.Module):
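    # Pre-LayerNorm transformer block (self-attention + 4x-wide MLP), used below to
    # fuse the per-task condition embeddings of the union controlnet.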
    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
        self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
                         ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
        self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)

    def attention(self, x: torch.Tensor):
        return self.attn(x)

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class ControlledUnetModel(UNetModel):
    #implemented in the ldm unet
    pass

class ControlNet(nn.Module):
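    # Trainable copy of the UNet encoder: every level's output is passed through a
    # 1x1 "zero conv" and returned as a residual for the base diffusion model.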
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        dtype=torch.float32,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
        attn_precision=None,
        union_controlnet=False,
        device=None,
        operations=comfy.ops.disable_weight_init,
        **kwargs,
    ):
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            # from omegaconf.listconfig import ListConfig
            # if type(context_dim) == ListConfig:
            #     context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks

        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))

        transformer_depth = transformer_depth[:] # copy: one entry is consumed with pop(0) per res block below

        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
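        # Sinusoidal timestep features are lifted to 4 * model_channels by an MLP.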
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
                raise ValueError("num_classes must be an int, 'continuous' or 'sequential'")

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])

        self.input_hint_block = TimestepEmbedSequential(
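                    # Conv stack over the conditioning image: three stride-2 convs
                    # downsample the hint 8x to latent resolution while projecting
                    # it up to model_channels.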
                    operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
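        # Mirror the UNet encoder: ResBlocks (optionally followed by a
        # SpatialTransformer) at each level, each tap paired with its own zero conv.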
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations,
                    )
                ]
                ch = mult * model_channels
                num_transformers = transformer_depth.pop(0)
                if num_transformers > 0:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer(
                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                            dtype=self.dtype,
                            device=device,
                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
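
        # Middle block mirrors the base UNet (ResBlock -> SpatialTransformer ->
        # ResBlock) and gets its own zero conv on the output.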
        mid_block = [
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            )]
        if transformer_depth_middle >= 0:
            mid_block += [SpatialTransformer(  # always uses a self-attn
                            ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                            use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                        ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            )]
        self.middle_block = TimestepEmbedSequential(*mid_block)
        self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
        self._feature_size += ch

        if union_controlnet:
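            # Union controlnet (ControlNetPlus-style) extras: a learned embedding per
            # control task, a small transformer to mix the per-task tokens, and an
            # embedding of the active control types added to the timestep embedding.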
            self.num_control_type = 6
            num_trans_channel = 320
            num_trans_head = 8
            num_trans_layer = 1
            num_proj_channel = 320
            # task_scale_factor = num_trans_channel ** 0.5
            self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))

            # "layes" (sic) is kept: renaming it would change the module's state-dict keys.
            self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
            self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
            #-----------------------------------------------------------------------------------------------------

            control_add_embed_dim = 256
            class ControlAddEmbedding(nn.Module):
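                # Embeds the multi-hot vector of active control types with sinusoidal
                # features followed by a two-layer MLP.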
                def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
                    super().__init__()
                    self.num_control_type = num_control_type
                    self.in_dim = in_dim
                    self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
                    self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
                def forward(self, control_type, dtype, device):
                    c_type = torch.zeros((self.num_control_type,), device=device)
                    c_type[control_type] = 1.0
                    c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
                    return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))

            self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
        else:
            self.task_embedding = None
            self.control_add_embedding = None

    def union_controlnet_merge(self, hint, control_type, emb, context):
        # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
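        # Encode each hint, tag it with its task embedding, mix the per-task tokens
        # with the transformer, then add each projected token back onto its hint
        # features as a per-channel bias and sum the results. Note that the range
        # below only consumes the first hint.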
        inputs = []
        condition_list = []

        for idx in range(min(1, len(control_type))):
            controlnet_cond = self.input_hint_block(hint[idx], emb, context)
            feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
            if idx < len(control_type):
                feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)

            inputs.append(feat_seq.unsqueeze(1))
            condition_list.append(controlnet_cond)

        x = torch.cat(inputs, dim=1)
        x = self.transformer_layes(x)
        controlnet_cond_fuser = None
        for idx in range(len(control_type)):
            alpha = self.spatial_ch_projs(x[:, idx])
            alpha = alpha.unsqueeze(-1).unsqueeze(-1)
            o = condition_list[idx] + alpha
            if controlnet_cond_fuser is None:
                controlnet_cond_fuser = o
            else:
                controlnet_cond_fuser += o
        return controlnet_cond_fuser

    def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
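        # 1x1 convolution producing the residual for one encoder level.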
        return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
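        # Returns the per-level zero-conv features ("output") and the middle-block
        # feature ("middle"), which are added as residuals to the base UNet.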
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        guided_hint = None
        if self.control_add_embedding is not None:
            control_type = kwargs.get("control_type", [])

            emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
            if len(control_type) > 0:
                if len(hint.shape) < 5:
                    hint = hint.unsqueeze(dim=0)
                guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)

        if guided_hint is None:
            guided_hint = self.input_hint_block(hint, emb, context)

        out_output = []
        out_middle = []

        hs = []
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            out_output.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        out_middle.append(self.middle_block_out(h, emb, context))

        return {"middle": out_middle, "output": out_output}