#taken from: https://github.com/lllyasviel/ControlNet
#and modified

import torch
import torch as th
import torch.nn as nn

from ..ldm.modules.diffusionmodules.util import (
    zero_module,
    timestep_embedding,
)

from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
from collections import OrderedDict
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention

class OptimizedAttention(nn.Module):
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead
        self.c = c

        self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        x = self.in_proj(x)
        q, k, v = x.split(self.c, dim=2)
        out = optimized_attention(q, k, v, self.heads)
        return self.out_proj(out)
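
# A minimal shape sketch for the fused QKV attention above (illustrative values,
# assuming the same comfy.ops operations used elsewhere in this file):
#   attn = OptimizedAttention(c=320, nhead=8, operations=comfy.ops.disable_weight_init)
#   y = attn(torch.randn(2, 77, 320))  # in_proj -> (2, 77, 960), split into q/k/v
#                                      # of (2, 77, 320) each; y is (2, 77, 320)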

class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation, x * sigmoid(1.702 * x), as used in CLIP."""
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

class ResBlockUnionControlnet(nn.Module):
    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
        self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
                         ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
        self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)

    def attention(self, x: torch.Tensor):
        return self.attn(x)

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
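
# ResBlockUnionControlnet is a standard pre-LayerNorm transformer encoder block
# (self-attention plus a 4x-expansion MLP, each with a residual connection); it
# is used below to mix per-task condition tokens for the union controlnet.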

class ControlledUnetModel(UNetModel):
    #implemented in the ldm unet
    pass

class ControlNet(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        dtype=torch.float32,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
        attn_precision=None,
        union_controlnet_num_control_type=None,
        device=None,
        operations=comfy.ops.disable_weight_init,
        **kwargs,
    ):
        super().__init__()
        assert use_spatial_transformer, "use_spatial_transformer has to be true"
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            # from omegaconf.listconfig import ListConfig
            # if type(context_dim) == ListConfig:
            #     context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks

        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))

        transformer_depth = transformer_depth[:] # copied, since entries are popped off per res block below

        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        # Standard diffusion timestep embedding: sinusoidal features of width
        # model_channels, projected to 4x width by a two-layer SiLU MLP.
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
        # One zero conv per input block; their outputs are the controlnet residuals.
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])

        # Encodes the conditioning image from pixel space down to latent
        # resolution: three stride-2 convs give an 8x spatial downsample.
        self.input_hint_block = TimestepEmbedSequential(
                    operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
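        # Mirror the UNet encoder: at each resolution level, stack res blocks
        # (each optionally followed by a spatial transformer) and register a
        # matching zero conv that taps the features for the controlled UNet.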
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations,
                    )
                ]
                ch = mult * model_channels
                num_transformers = transformer_depth.pop(0)
                if num_transformers > 0:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer(
                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                            dtype=self.dtype,
                            device=device,
                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
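
        # Middle block: res block -> (optional) spatial transformer -> res block,
        # tapped by its own zero conv.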
        mid_block = [
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            )]
        if transformer_depth_middle >= 0:
            mid_block += [SpatialTransformer(  # always uses a self-attn
                            ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                            use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                        ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            )]
        self.middle_block = TimestepEmbedSequential(*mid_block)
        self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
        self._feature_size += ch

        if union_controlnet_num_control_type is not None:
            self.num_control_type = union_controlnet_num_control_type
            num_trans_channel = 320
            num_trans_head = 8
            num_trans_layer = 1
            num_proj_channel = 320
            # task_scale_factor = num_trans_channel ** 0.5
            self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))

            # Note: the "transformer_layes" spelling is kept as-is so state_dict
            # keys match existing union controlnet checkpoints.
            self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
            self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
            #-----------------------------------------------------------------------------------------------------

            control_add_embed_dim = 256
            class ControlAddEmbedding(nn.Module):
                def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
                    super().__init__()
                    self.num_control_type = num_control_type
                    self.in_dim = in_dim
                    self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
                    self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
                def forward(self, control_type, dtype, device):
                    # Multi-hot vector over control types; each slot is embedded with
                    # sinusoidal features, concatenated, then passed through a 2-layer MLP.
                    c_type = torch.zeros((self.num_control_type,), device=device)
                    c_type[control_type] = 1.0
                    c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
                    return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))

            self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
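            # Illustrative call (mirrors forward() below); which index means which
            # task is defined by the union checkpoint, not by this module:
            #   emb = emb + self.control_add_embedding([0, 3], emb.dtype, emb.device)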
        else:
            self.task_embedding = None
            self.control_add_embedding = None

    def union_controlnet_merge(self, hint, control_type, emb, context):
        # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
        inputs = []
        condition_list = []

        # Note: the range is capped at 1, so only the first hint image is encoded
        # here, matching the upstream implementation linked above.
        for idx in range(min(1, len(control_type))):
            controlnet_cond = self.input_hint_block(hint[idx], emb, context)
            feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
            if idx < len(control_type):
                feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)

            inputs.append(feat_seq.unsqueeze(1))
            condition_list.append(controlnet_cond)

        x = torch.cat(inputs, dim=1)
        x = self.transformer_layes(x)
        controlnet_cond_fuser = None
        for idx in range(len(control_type)):
            alpha = self.spatial_ch_projs(x[:, idx])
            alpha = alpha.unsqueeze(-1).unsqueeze(-1)
            o = condition_list[idx] + alpha
            if controlnet_cond_fuser is None:
                controlnet_cond_fuser = o
            else:
                controlnet_cond_fuser += o
        return controlnet_cond_fuser
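
    # The merge above mirrors ControlNetPlus: each encoded hint contributes one
    # pooled feature token, task_embedding biases that token by control type, the
    # small transformer mixes the tokens, and spatial_ch_projs maps each mixed
    # token back to a per-channel bias added onto its condition map.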

    def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
        # 1x1 conv tap; in the original ControlNet these are zero-initialized, but
        # here the weights are expected to come from a loaded checkpoint.
        return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        """Return the controlnet residuals: "output" holds one tensor per input
        block (added to the controlled UNet's skip connections) and "middle"
        the middle-block residual."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        guided_hint = None
        if self.control_add_embedding is not None: #Union Controlnet
            control_type = kwargs.get("control_type", [])

            emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
            if len(control_type) > 0:
                if len(hint.shape) < 5:
                    hint = hint.unsqueeze(dim=0) # add a hint-count dim: (num_hints, B, C, H, W)
                guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)

        if guided_hint is None:
            guided_hint = self.input_hint_block(hint, emb, context)

        out_output = []
        out_middle = []

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            h = module(h, emb, context)
            if guided_hint is not None:
                # the encoded hint is injected once, after the first input block
                h += guided_hint
                guided_hint = None
            out_output.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        out_middle.append(self.middle_block_out(h, emb, context))

        return {"middle": out_middle, "output": out_output}