# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from typing import Optional, Any

from ..attention import MemoryEfficientCrossAttention
from comfy import model_management
import comfy.ops

if model_management.xformers_enabled_vae():
    # xformers is only imported when model_management reports it as enabled
    # for the VAE; MemoryEfficientAttnBlock relies on this import.
    import xformers
    import xformers.ops

def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    Every timestep is multiplied by ``embedding_dim // 2`` geometrically
    spaced frequencies; the sines of all products are concatenated with the
    cosines along dim 1. When ``embedding_dim`` is odd, one zero column is
    appended so the result has exactly ``embedding_dim`` columns.

    ``timesteps`` must be a 1-D tensor (asserted). The result is a float32
    tensor of shape ``(len(timesteps), embedding_dim)`` on the same device
    as ``timesteps``.
    """
    assert len(timesteps.shape) == 1

    n_freqs = embedding_dim // 2
    log_step = math.log(10000) / (n_freqs - 1)
    freqs = torch.exp(torch.arange(n_freqs, dtype=torch.float32) * -log_step)
    freqs = freqs.to(device=timesteps.device)

    # outer product: one row per timestep, one column per frequency
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    needs_padding = embedding_dim % 2 == 1
    if needs_padding:
        # pad a single zero column on the right of the last dimension
        emb = torch.nn.functional.pad(emb, (0,1,0,0))
    return emb


def nonlinearity(x):
    """Swish activation: the input scaled by its own sigmoid (out of place)."""
    gate = torch.sigmoid(x)
    return x*gate


def Normalize(in_channels, num_groups=32):
    """Return an affine GroupNorm over ``in_channels`` channels with eps=1e-6."""
    norm_kwargs = dict(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    return torch.nn.GroupNorm(**norm_kwargs)


class Upsample(nn.Module):
    """
    Doubles the spatial size with nearest-neighbour interpolation, optionally
    followed by a 3x3 convolution that keeps the channel count.

    in_channels: channel count of the optional convolution.
    with_conv: when true, apply the convolution after interpolating.
    """
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = comfy.ops.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        try:
            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        except Exception: # operation not implemented for this dtype (e.g. bf16)
            # Fallback: interpolate in float32, a few channels at a time, and
            # cast each chunk back to the input dtype.
            b, c, h, w = x.shape
            out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
            split = 8
            # Keep at least one channel per chunk: c // split is 0 when there
            # are fewer than `split` channels, and range() rejects a zero step.
            l = max(1, out.shape[1] // split)
            for i in range(0, out.shape[1], l):
                out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
            del x
            x = out

        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    Halves the spatial size, either with a stride-2 3x3 convolution (after
    padding one zero row/column on the bottom/right) or with 2x2 average
    pooling.

    in_channels: channel count of the optional convolution.
    with_conv: when true, use the convolution; otherwise average pooling.
    """
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = comfy.ops.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            # (left, right, top, bottom): pad only the right and bottom edges
            pad = (0,1,0,1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    """
    Residual block: norm -> swish -> conv, optional timestep-embedding bias,
    norm -> swish -> dropout -> conv, added to a shortcut of the input.

    in_channels / out_channels: channel counts; out_channels defaults to
        in_channels. When they differ, the shortcut is a 3x3 convolution
        (conv_shortcut=True) or a 1x1 convolution (default).
    dropout: dropout probability applied before the second convolution.
    temb_channels: width of the timestep embedding; the projection layer is
        only created when this is > 0.
    """
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # in-place activation; only ever applied to tensors this block owns
        self.swish = torch.nn.SiLU(inplace=True)
        self.norm1 = Normalize(in_channels)
        self.conv1 = comfy.ops.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = comfy.ops.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout, inplace=True)
        self.conv2 = comfy.ops.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = comfy.ops.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = comfy.ops.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = self.swish(h)
        h = self.conv1(h)

        if temb is not None:
            # temb is shared by every block of the network, so it must not be
            # activated in place (self.swish would overwrite the caller's
            # tensor and later blocks would see an already-activated temb).
            h = h + self.temb_proj(torch.nn.functional.silu(temb))[:,:,None,None]

        h = self.norm2(h)
        h = self.swish(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h

def slice_attention(q, k, v):
    """
    Single-head attention computed over slices of the query axis.

    As used by the attention blocks in this file, q is (b, hw, c) while k and
    v are (b, c, hw); the result has the shape of k. The number of slices
    starts from an estimate based on the free memory reported by
    model_management and is doubled after every out-of-memory error; once it
    would exceed 128 the error is re-raised.
    """
    result = torch.zeros_like(k, device=q.device)
    scale = (int(q.shape[-1])**(-0.5))

    free_bytes = model_management.get_free_memory(q.device)

    # rough size of the full attention score matrix, with a safety factor
    elem_bytes = q.element_size()
    score_bytes = q.shape[0] * q.shape[1] * k.shape[2] * elem_bytes
    safety = 3 if elem_bytes == 2 else 2.5
    needed_bytes = score_bytes * safety

    steps = 1
    if needed_bytes > free_bytes:
        steps = 2**(math.ceil(math.log(needed_bytes / free_bytes, 2)))

    finished = False
    while not finished:
        try:
            n_queries = q.shape[1]
            # fall back to a single slice when steps does not divide evenly
            if (n_queries % steps) == 0:
                chunk = n_queries // steps
            else:
                chunk = n_queries
            for start in range(0, n_queries, chunk):
                stop = start + chunk
                scores = torch.bmm(q[:, start:stop], k) * scale

                weights = torch.nn.functional.softmax(scores, dim=2).permute(0,2,1)
                del scores

                result[:, :, start:stop] = torch.bmm(v, weights)
                del weights
            finished = True
        except model_management.OOM_EXCEPTION as e:
            steps *= 2
            if steps > 128:
                raise e
            print("out of memory error, increasing steps and trying again", steps)

    return result
class AttnBlock(nn.Module):
    """
    Single-head spatial self-attention with a residual connection.

    The input is normalized, projected to q/k/v with 1x1 convolutions,
    flattened over the spatial axes and handed to slice_attention; the result
    goes through a 1x1 output projection and is added to the input.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = comfy.ops.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = comfy.ops.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = comfy.ops.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = comfy.ops.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape

        q = q.reshape(b,c,h*w)
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        v = v.reshape(b,c,h*w)

        r1 = slice_attention(q, k, v)
        h_ = r1.reshape(b,c,h,w)
        del r1
        h_ = self.proj_out(h_)

        return x+h_

class MemoryEfficientAttnBlock(nn.Module):
    """
        Uses xformers efficient implementation,
        see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
        Note: this is a single-head self-attention operation

        Falls back to slice_attention when xformers raises
        NotImplementedError for the given inputs.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = comfy.ops.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = comfy.ops.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = comfy.ops.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = comfy.ops.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        # forwarded as `op=` to xformers; None is the value set here
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        # flatten the spatial axes: (B, C, H, W) -> (B, H*W, C)
        q, k, v = map(
            lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
            (q, k, v),
        )

        try:
            out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
            out = out.transpose(1, 2).reshape(B, C, H, W)
        except NotImplementedError:
            # slice_attention expects q as (B, HW, C) and k, v as (B, C, HW)
            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)

        out = self.proj_out(out)
        return x+out

class MemoryEfficientAttnBlockPytorch(nn.Module):
    """
    Single-head spatial self-attention with a residual connection, computed
    with torch.nn.functional.scaled_dot_product_attention.

    Falls back to slice_attention when the fused call raises the
    out-of-memory exception exposed by model_management.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = comfy.ops.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = comfy.ops.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = comfy.ops.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = comfy.ops.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        # not read by forward() in this class
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        # flatten the spatial axes and add a head axis: (B, C, H, W) -> (B, 1, H*W, C)
        q, k, v = map(
            lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
            (q, k, v),
        )

        try:
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
            out = out.transpose(2, 3).reshape(B, C, H, W)
        except model_management.OOM_EXCEPTION:
            print("scaled_dot_product_attention OOMed: switched to slice attention")
            # slice_attention expects q as (B, HW, C) and k, v as (B, C, HW)
            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)

        out = self.proj_out(out)
        return x+out
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    """
    Adapts MemoryEfficientCrossAttention to image-shaped input: the
    (b, c, h, w) input is flattened to (b, h*w, c) tokens for the parent
    forward, the result is reshaped back to (b, c, h, w) and added to the
    input as a residual.
    """
    def forward(self, x, context=None, mask=None):
        b, c, h, w = x.shape
        # Keep `x` in its original layout for the residual; adding the
        # flattened tokens to the reshaped output would mix two layouts.
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        out = super().forward(tokens, context=context, mask=mask)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
        return x + out


def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    """
    Build the attention module for ``in_channels`` channels.

    A requested "vanilla" type is upgraded to the xformers block when
    model_management reports xformers enabled for the VAE, otherwise to the
    PyTorch scaled-dot-product block when PyTorch attention is enabled.

    attn_kwargs: extra keyword arguments for the "memory-efficient-cross-attn"
        type (``query_dim`` is always set to ``in_channels``); must be None
        for plain "vanilla".
    Raises NotImplementedError for accepted types without an implementation
    here (e.g. "linear").
    """
    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
    if model_management.xformers_enabled_vae() and attn_type == "vanilla":
        attn_type = "vanilla-xformers"
    if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
        attn_type = "vanilla-pytorch"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "vanilla-pytorch":
        return MemoryEfficientAttnBlockPytorch(in_channels)
    elif attn_type == "memory-efficient-cross-attn":
        # Compare attn_type (not the builtin `type`), and work on a copy so
        # the caller's dict is not modified and None is accepted.
        kwargs = dict(attn_kwargs) if attn_kwargs is not None else {}
        kwargs["query_dim"] = in_channels
        return MemoryEfficientCrossAttentionWrapper(**kwargs)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        raise NotImplementedError()


class Model(nn.Module):
    """
    U-Net style network: a downsampling path whose activations are kept and
    concatenated into the matching blocks of the upsampling path, with a
    middle stage (ResnetBlock, attention, ResnetBlock) in between.

    ch: base channel count; level i uses ch * ch_mult[i] channels.
    num_res_blocks: ResnetBlocks per level on the way down (one more per
        level on the way up).
    attn_resolutions: resolutions at which attention follows each ResnetBlock.
    use_timestep: when true, forward() requires ``t`` and feeds a timestep
        embedding of width ch*4 to every ResnetBlock.
    use_linear_attn: when true, overrides attn_type with "linear".
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                comfy.ops.Linear(self.ch,
                                self.temb_ch),
                comfy.ops.Linear(self.temb_ch,
                                self.temb_ch),
            ])

        # downsampling
        self.conv_in = comfy.ops.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                # the last block of a level consumes the skip that entered
                # the matching down level, which has in_ch_mult channels
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = comfy.ops.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t=None, context=None):
        #assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling; hs collects every activation for the skip connections
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling; skips are consumed in reverse order
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        """Return the weight of the final output convolution."""
        return self.conv_out.weight


class Encoder(nn.Module):
    """
    Convolutional encoder: an input convolution, a stack of downsampling
    levels (ResnetBlocks with optional attention), a middle stage and an
    output convolution producing 2*z_channels channels when double_z is
    true, otherwise z_channels.

    No timestep embedding is used (temb_ch is 0 and temb is always None).
    out_ch is accepted but not read here; unknown keyword arguments are
    ignored.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = comfy.ops.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = comfy.ops.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # timestep embedding
        temb = None
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions-1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    """
    Convolutional decoder: an input convolution from z_channels, a middle
    stage, a stack of upsampling levels (ResnetBlocks with optional
    attention) and an output convolution to out_ch channels.

    give_pre_end: when true, forward() returns the activation before the
        final norm / nonlinearity / convolution.
    tanh_out: when true, tanh is applied to the final output.
    No timestep embedding is used (temb_ch is 0 and temb is always None);
    unknown keyword arguments are ignored.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = comfy.ops.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = comfy.ops.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]
        # remember the shape of the most recent latent input
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h