resnet.py 23.5 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
from abc import abstractmethod

import numpy as np
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
import torch.nn.functional as F


def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average-pooling module for a 1D, 2D or 3D signal.

    ``dims`` selects ``nn.AvgPool1d`` / ``nn.AvgPool2d`` / ``nn.AvgPool3d``; every other
    argument is forwarded unchanged to that class. Any other ``dims`` raises ``ValueError``.
    """
    for n, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == n:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution module for a 1D, 2D or 3D signal.

    ``dims`` picks ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d``; the remaining positional and
    keyword arguments go straight to that constructor. Unsupported ``dims`` raise ``ValueError``.
    """
    for n, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == n:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

patil-suraj's avatar
patil-suraj committed
34

35
36
37
38
39
40
41
42
43
44
45
46
47
def conv_transpose_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D transposed convolution module.

    ``dims`` selects ``nn.ConvTranspose1d`` / ``nn.ConvTranspose2d`` / ``nn.ConvTranspose3d``;
    all other arguments are forwarded to that constructor.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    if dims == 1:
        return nn.ConvTranspose1d(*args, **kwargs)
    elif dims == 2:
        return nn.ConvTranspose2d(*args, **kwargs)
    elif dims == 3:
        return nn.ConvTranspose3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


48
49
def Normalize(in_channels, num_groups=32, eps=1e-6):
    """Return an affine ``GroupNorm`` over ``in_channels`` channels split into ``num_groups`` groups."""
    norm = torch.nn.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=eps,
        affine=True,
    )
    return norm
50
51
52
53
54
55
56
57
58
59


def nonlinearity(x, swish=1.0):
    """
    Swish-family activation ``x * sigmoid(swish * x)``.

    The default ``swish == 1.0`` takes the ``F.silu`` path; any other value uses the explicit
    product with ``swish`` converted to ``float``.
    """
    if swish == 1.0:
        return F.silu(x)
    beta = float(swish)
    return x * torch.sigmoid(x * beta)


Patrick von Platen's avatar
Patrick von Platen committed
60
61
62
63
64
65
66
67
68
69
70
71
class TimestepBlock(nn.Module):
    """
    Base class for modules whose ``forward`` takes timestep embeddings as its second argument.

    NOTE: ``nn.Module`` is not built on ``ABCMeta``, so the ``@abstractmethod`` below only documents
    intent; it does not stop a subclass without ``forward`` from being instantiated.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Process ``x`` conditioned on the timestep embeddings ``emb`` (implemented by subclasses).
        """
        ...


72
73
74
75
class Upsample(nn.Module):
    """
    A 2x upsampling layer with an optional convolution.

    :param channels: channels in the inputs (checked in ``forward``).
    :param use_conv: if True, apply a 3x3 convolution after nearest-neighbor upsampling.
    :param use_conv_transpose: if True, upsample with a learned transposed convolution
        (kernel 4, stride 2, padding 1) instead of interpolation; takes precedence over ``use_conv``.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the (transposed) convolution; defaults to ``channels``.
        Without any convolution the channel count is unchanged.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.use_conv_transpose = use_conv_transpose

        if use_conv_transpose:
            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)

        if self.dims == 3:
            # Keep the first spatial axis, double the inner two.
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            x = self.conv(x)

        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs (checked in ``forward``).
    :param use_conv: if True, downsample with a strided 3x3 convolution; otherwise with average
        pooling (which requires ``out_channels == channels``).
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions (stride ``(1, 2, 2)``).
    :param out_channels: output channels of the convolution; defaults to ``channels``.
    :param padding: convolution padding. With ``use_conv``, ``padding == 0`` and ``dims == 2``,
        ``forward`` first zero-pads one row/column on the bottom/right.
    :param name: attribute under which the op is registered: ``"conv"`` -> ``self.conv``,
        anything else -> ``self.op`` (presumably to match existing state-dict keys).
    """

    def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.padding = padding
        stride = 2 if dims != 3 else (1, 2, 2)
        self.name = name

        if use_conv:
            conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)

        if name == "conv":
            self.conv = conv
        else:
            self.op = conv

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv and self.padding == 0 and self.dims == 2:
            # Asymmetric (right/bottom) zero padding in place of symmetric conv padding.
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)

        if self.name == "conv":
            return self.conv(x)
        else:
            return self.op(x)
150
151


Patrick von Platen's avatar
Patrick von Platen committed
152
153
154
155
156
157
158
159
# TODO (patil-suraj): needs test
# class Upsample1d(nn.Module):
#    def __init__(self, dim):
#        super().__init__()
#        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
#
#    def forward(self, x):
#        return self.conv(x)
160
161


Patrick von Platen's avatar
Patrick von Platen committed
162
# RESNETS
Patrick von Platen's avatar
Patrick von Platen committed
163

Patrick von Platen's avatar
Patrick von Platen committed
164
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
165
class ResnetBlock(nn.Module):
    """
    Residual block conditioned on a time embedding, shared by several UNet variants.

    Main path: [norm1 -> act] -> optional 2x up/down resampling -> conv1 -> [norm1 -> act when not
    pre_norm] -> time-embedding injection -> dropout -> conv2 -> [norm2 -> act when not pre_norm].
    The skip path is resampled the same way and projected by a 1x1 ``nin_shortcut`` when the channel
    count changes. ``mask`` (default 1.0) is multiplied into the activations at several points.

    Keyword-only args:
        in_channels: channels of the input feature map.
        out_channels: output channels; defaults to ``in_channels``.
        conv_shortcut: stored as ``use_conv_shortcut``; not otherwise read in this class.
        dropout: dropout probability applied before ``conv2``.
        temb_channels: size of the time embedding projected by ``temb_proj``.
        groups, eps: GroupNorm settings for ``norm1`` / ``norm2``.
        pre_norm: if True, normalize + activate before each conv; otherwise after it.
        non_linearity: ``"swish"``, ``"mish"`` or ``"silu"``.
        time_embedding_norm: ``"default"`` (add the projected embedding) or ``"scale_shift"``
            (split the projection into scale/shift applied after ``norm2``).
        up, down: resample both paths by 2 (nearest upsampling / average pooling).
        overwrite_for_grad_tts, overwrite_for_ldm, overwrite_for_glide: temporary flags that build
            legacy sub-modules whose weights are copied into the standard layers on the first ``forward``.

    NOTE(review): an unrecognized ``non_linearity`` or ``time_embedding_norm`` leaves
    ``self.nonlinearity`` / ``self.temb_proj`` unset, so the failure only surfaces in ``forward``.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        up=False,
        down=False,
        overwrite_for_grad_tts=False,
        overwrite_for_ldm=False,
        overwrite_for_glide=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down

        # norm1 runs before conv1 (pre_norm) or after it, so its width follows that choice.
        if self.pre_norm:
            self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
        else:
            self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if time_embedding_norm == "default":
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        elif time_embedding_norm == "scale_shift":
            # Twice the width: chunked into (scale, shift) in forward.
            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

        self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = nonlinearity
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        if up:
            self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
            self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
        elif down:
            self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
            self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")

        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

        # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
        self.is_overwritten = False
        self.overwrite_for_glide = overwrite_for_glide
        self.overwrite_for_grad_tts = overwrite_for_grad_tts
        self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
        if self.overwrite_for_grad_tts:
            dim = in_channels
            dim_out = out_channels
            time_emb_dim = temb_channels
            self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
            self.pre_norm = pre_norm

            self.block1 = Block(dim, dim_out, groups=groups)
            self.block2 = Block(dim_out, dim_out, groups=groups)
            if dim != dim_out:
                self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
            else:
                self.res_conv = torch.nn.Identity()
        elif self.overwrite_for_ldm:
            dims = 2
            #            eps = 1e-5
            #            non_linearity = "silu"
            #            overwrite_for_ldm
            channels = in_channels
            emb_channels = temb_channels
            use_scale_shift_norm = False

            self.in_layers = nn.Sequential(
                normalization(channels, swish=1.0),
                nn.Identity(),
                conv_nd(dims, channels, self.out_channels, 3, padding=1),
            )
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(
                    emb_channels,
                    2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
                ),
            )
            self.out_layers = nn.Sequential(
                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
                nn.Dropout(p=dropout),
                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
            )
            if self.out_channels == in_channels:
                self.skip_connection = nn.Identity()
            else:
                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def set_weights_grad_tts(self):
        """Copy weights from the legacy Grad-TTS sub-modules (block1/block2/mlp/res_conv) into the standard layers.

        Assigning ``.data`` makes each standard parameter share storage with its legacy counterpart.
        """
        self.conv1.weight.data = self.block1.block[0].weight.data
        self.conv1.bias.data = self.block1.block[0].bias.data
        self.norm1.weight.data = self.block1.block[1].weight.data
        self.norm1.bias.data = self.block1.block[1].bias.data

        self.conv2.weight.data = self.block2.block[0].weight.data
        self.conv2.bias.data = self.block2.block[0].bias.data
        self.norm2.weight.data = self.block2.block[1].weight.data
        self.norm2.bias.data = self.block2.block[1].bias.data

        self.temb_proj.weight.data = self.mlp[1].weight.data
        self.temb_proj.bias.data = self.mlp[1].bias.data

        if self.in_channels != self.out_channels:
            self.nin_shortcut.weight.data = self.res_conv.weight.data
            self.nin_shortcut.bias.data = self.res_conv.bias.data

    def set_weights_ldm(self):
        """Copy weights from the legacy LDM sub-modules (in_layers/emb_layers/out_layers/skip_connection).

        Assigning ``.data`` makes each standard parameter share storage with its legacy counterpart.
        """
        self.norm1.weight.data = self.in_layers[0].weight.data
        self.norm1.bias.data = self.in_layers[0].bias.data

        self.conv1.weight.data = self.in_layers[-1].weight.data
        self.conv1.bias.data = self.in_layers[-1].bias.data

        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
        self.temb_proj.bias.data = self.emb_layers[-1].bias.data

        self.norm2.weight.data = self.out_layers[0].weight.data
        self.norm2.bias.data = self.out_layers[0].bias.data

        self.conv2.weight.data = self.out_layers[-1].weight.data
        self.conv2.bias.data = self.out_layers[-1].bias.data

        if self.in_channels != self.out_channels:
            self.nin_shortcut.weight.data = self.skip_connection.weight.data
            self.nin_shortcut.bias.data = self.skip_connection.bias.data

    def forward(self, x, temb, mask=1.0):
        """
        Apply the residual block.

        Args:
            x: input feature map (4D, as consumed by ``conv1``).
            temb: time embedding; passed through the activation and ``temb_proj`` and then broadcast
                over the spatial axes (presumably shaped ``(batch, temb_channels)`` — TODO confirm).
            mask: multiplied into the activations and the skip input (default 1.0, i.e. no masking).

        Returns:
            ``skip(x) + h`` with ``out_channels`` channels.
        """
        # TODO(Patrick) eventually this class should be split into multiple classes
        # too many if else statements
        # One-time weight conversion from the legacy sub-modules, on the first call only.
        if self.overwrite_for_grad_tts and not self.is_overwritten:
            self.set_weights_grad_tts()
            self.is_overwritten = True
        elif self.overwrite_for_ldm and not self.is_overwritten:
            self.set_weights_ldm()
            self.is_overwritten = True

        h = x
        h = h * mask
        if self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)

        if self.up or self.down:
            x = self.x_upd(x)
            h = self.h_upd(h)

        h = self.conv1(h)

        if not self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)
        h = h * mask

        temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

        if self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)

            h = self.norm2(h)
            h = h + h * scale + shift
            h = self.nonlinearity(h)
        elif self.time_embedding_norm == "default":
            h = h + temb
            h = h * mask
            if self.pre_norm:
                h = self.norm2(h)
                h = self.nonlinearity(h)

        h = self.dropout(h)
        h = self.conv2(h)

        if not self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)
        h = h * mask

        x = x * mask
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


Patrick von Platen's avatar
finish  
Patrick von Platen committed
373
# TODO(Patrick) - just there to convert the weights; can delete afterward
374
375
376
377
378
class Block(torch.nn.Module):
    """
    Legacy Grad-TTS building block: Conv2d(3x3, padding 1) -> GroupNorm -> Mish, held in ``self.block``.

    In this file only its sub-modules are read (by ``ResnetBlock.set_weights_grad_tts``) to convert
    weights; the class defines no ``forward`` of its own.
    """

    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
        )


# unet_score_estimation.py
class ResnetBlockBigGANpp(nn.Module):
    """
    Residual block with optional FIR-filtered 2x resampling and a time-embedding bias.

    Main path: GroupNorm_0 -> act -> [resample] -> Conv_0 -> (+ Dense_0(act(temb))) -> GroupNorm_1 -> act
    -> Dropout_0 -> Conv_1. The skip path is resampled the same way and projected by the 1x1 ``Conv_2``
    when the channel count changes or resampling is enabled.

    Args:
        act: activation callable applied after each GroupNorm and to ``temb``.
        in_ch: input channels.
        out_ch: output channels (defaults to ``in_ch`` when falsy).
        temb_dim: time-embedding size; when None no ``Dense_0`` exists and ``temb`` must not be passed.
        up, down: FIR upsample / downsample both paths by 2 with ``fir_kernel`` (``up`` wins if both set).
        dropout: dropout probability before ``Conv_1``.
        fir_kernel: FIR filter taps passed to ``upsample_2d`` / ``downsample_2d``.
        skip_rescale: if True, divide the residual sum by sqrt(2).
        init_scale: weight-init scale passed to ``conv2d`` for ``Conv_1``.
    """

    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        up=False,
        down=False,
        dropout=0.1,
        fir_kernel=(1, 3, 3, 1),
        skip_rescale=True,
        init_scale=0.0,
    ):
        super().__init__()

        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir_kernel = fir_kernel

        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
            nn.init.zeros_(self.Dense_0.bias)

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
        if in_ch != out_ch or up or down:
            # 1x1 convolution with DDPM initialization.
            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))

        if self.up:
            h = upsample_2d(h, self.fir_kernel, factor=2)
            x = upsample_2d(x, self.fir_kernel, factor=2)
        elif self.down:
            h = downsample_2d(h, self.fir_kernel, factor=2)
            x = downsample_2d(x, self.fir_kernel, factor=2)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)


450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
    """
    Two ``Conv1dBlock``s with the projected time embedding added after the first one, plus a
    residual connection (1x1 ``Conv1d`` when the channel count changes, identity otherwise).

    ``horizon`` is accepted for interface compatibility but not used here.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        first = Conv1dBlock(inp_channels, out_channels, kernel_size)
        second = Conv1dBlock(out_channels, out_channels, kernel_size)
        self.blocks = nn.ModuleList([first, second])

        # Mish -> Linear to out_channels -> RearrangeDim appends a trailing singleton axis
        # so the embedding broadcasts over the last axis of the conv output.
        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            RearrangeDim(),
        )

        if inp_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)

    def forward(self, x, t):
        """
        x : [ batch_size x inp_channels x horizon ]
        t : [ batch_size x embed_dim ]
        returns : [ batch_size x out_channels x horizon ]
        """
        first, second = self.blocks
        hidden = first(x) + self.time_mlp(t)
        hidden = second(hidden)
        skip = self.residual_conv(x)
        return hidden + skip


Patrick von Platen's avatar
Patrick von Platen committed
483
484
485
486
# HELPER Modules


def normalization(channels, swish=0.0):
    """
    Make a standard normalization layer, with an optional swish activation.

    :param channels: number of input channels.
    :param swish: activation selector forwarded to ``GroupNorm32`` (0 -> none, 1.0 -> SiLU,
        other values -> ``y * sigmoid(swish * y)``).
    :return: a ``GroupNorm32`` with 32 groups.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm evaluated on a float32 copy of the input, with an optional swish activation.

    The result is cast back to the input dtype. ``swish`` selects the activation:
    falsy (e.g. 0.0) -> none, ``1.0`` -> ``F.silu``, any other value -> ``y * sigmoid(swish * y)``.
    """

    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
        self.swish = swish

    def forward(self, x):
        y = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            y = F.silu(y)
        elif self.swish:
            y = y * torch.sigmoid(y * float(self.swish))
        return y


def linear(*args, **kwargs):
    """Factory for ``nn.Linear``; every argument is forwarded unchanged."""
    layer = nn.Linear(*args, **kwargs)
    return layer


def zero_module(module):
    """
    Zero out the parameters of a module and return it.

    Parameters are zeroed in place through detached views, so ``requires_grad`` is unchanged.
    The same module object is returned to allow inline use inside constructors.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class Mish(torch.nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    ``RearrangeDim`` inserts a singleton axis before ``GroupNorm`` and removes it afterwards
    (3D -> 4D -> 3D). Padding ``kernel_size // 2`` keeps the last axis length for odd kernel sizes.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            RearrangeDim(),
            #            Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            RearrangeDim(),
            #            Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class RearrangeDim(nn.Module):
    """
    Insert or drop a singleton axis depending on the input rank:

    * 2D ``(d0, d1)``         -> ``(d0, d1, 1)``
    * 3D ``(d0, d1, d2)``     -> ``(d0, d1, 1, d2)``
    * 4D ``(d0, d1, _, d3)``  -> ``(d0, d1, d3)`` (index 0 of axis 2)

    Any other rank raises ``ValueError``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        ndim = len(tensor.shape)
        if ndim == 2:
            return tensor[:, :, None]
        if ndim == 3:
            return tensor[:, :, None, :]
        if ndim == 4:
            return tensor[:, :, 0, :]
        # Report the rank; ``len(tensor)`` would be the size of dim 0 (and fails on 0-d tensors).
        raise ValueError(f"`len(tensor.shape)`: {ndim} has to be 2, 3 or 4.")


patil-suraj's avatar
patil-suraj committed
567
568
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
    """nXn convolution with DDPM initialization.

    Builds an ``nn.Conv2d``, re-initializes its weight with ``variance_scaling(init_scale)`` and
    zeroes its bias when one exists (``bias=False`` leaves ``conv.bias`` as None).
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv


patil-suraj's avatar
patil-suraj committed
575
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX.

    Return an initializer ``init(shape, dtype=dtype, device=device)`` that samples uniformly from
    ``[-sqrt(3 * v), sqrt(3 * v)]`` with ``v = scale / ((fan_in + fan_out) / 2)``, where
    ``fan_in = shape[in_axis] * r`` and ``fan_out = shape[out_axis] * r`` and ``r`` is the product of
    the remaining dimensions. A ``scale`` of 0 is replaced by 1e-10 so the variance is never exactly zero.
    """
    scale = 1e-10 if scale == 0 else scale

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        return fan_in, fan_out

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        denominator = (fan_in + fan_out) / 2
        variance = scale / denominator
        # U(-a, a) has variance a^2 / 3, hence a = sqrt(3 * variance).
        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

    return init
592
593


Patrick von Platen's avatar
Patrick von Platen committed
594
595
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Run ``upfirdn2d_native`` with the same up/down factor and padding pair on both spatial axes."""
    pad_lo, pad_hi = pad[0], pad[1]
    return upfirdn2d_native(
        input,
        kernel,
        up_x=up,
        up_y=up,
        down_x=down,
        down_y=down,
        pad_x0=pad_lo,
        pad_x1=pad_hi,
        pad_y0=pad_lo,
        pad_y1=pad_hi,
    )
596
597


Patrick von Platen's avatar
Patrick von Platen committed
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    """
    Upsample, pad, FIR-filter and downsample a batch of 2D feature maps in plain PyTorch.

    Steps: zero-insertion upsampling by (up_y, up_x), padding by (pad_y0, pad_y1, pad_x0, pad_x1)
    where negative amounts crop, convolution with ``kernel``, then keeping every ``down_y``-th row
    and ``down_x``-th column.

    Args:
        input: 4D tensor unpacked as ``(N, C, H, W)``; each channel is filtered independently.
        kernel: 2D tensor of shape ``(kernel_h, kernel_w)``. NOTE(review): it is used as the
            ``F.conv2d`` weight, so presumably it must match ``input``'s dtype and device.
        up_x, up_y: integer upsampling factors.
        down_x, down_y: integer downsampling factors.
        pad_x0, pad_x1, pad_y0, pad_y1: padding before/after along width/height.

    Returns:
        Tensor of shape ``(N, C, out_h, out_w)`` with
        ``out_h = (H * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1`` (analogously ``out_w``).
    """
    _, channel, in_h, in_w = input.shape
    # Fold channels into the batch so a single-channel conv filters every channel independently.
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Zero-insertion upsampling: pad (up - 1) zeros after every pixel along each spatial axis.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding adds zeros; negative padding is applied as a crop.
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    # Back to (batch, 1, H', W') for F.conv2d.
    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    # F.conv2d computes cross-correlation; flipping the kernel turns it into a true convolution.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Downsampling: keep every down-th row/column.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    # Unfold the batch back into (N, C, ...).
    return out.view(-1, channel, out_h, out_w)


def upsample_2d(x, k=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given filter.

    The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the upsampling factor.

    Args:
        x: Input tensor of shape `[N, C, H, W]` (channels-first; the underlying `upfirdn2d_native`
            unpacks the shape in that order).
        k: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
            which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1).

    Returns:
        Tensor of shape `[N, C, H * factor, W * factor]`.
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor**2))
    p = k.shape[0] - factor
    # Build the kernel in x's dtype so F.conv2d inside upfirdn2d sees matching input/weight dtypes
    # (float32 inputs behave exactly as before; float16/float64 inputs no longer fail).
    kernel = torch.tensor(k, device=x.device, dtype=x.dtype)
    return upfirdn2d(x, kernel, up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))


def downsample_2d(x, k=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given filter.

    The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the downsampling factor.

    Args:
        x: Input tensor of shape `[N, C, H, W]` (channels-first; the underlying `upfirdn2d_native`
            unpacks the shape in that order).
        k: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
            which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1).

    Returns:
        Tensor of shape `[N, C, H // factor, W // factor]`.
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = k.shape[0] - factor
    # Build the kernel in x's dtype so F.conv2d inside upfirdn2d sees matching input/weight dtypes.
    kernel = torch.tensor(k, device=x.device, dtype=x.dtype)
    return upfirdn2d(x, kernel, down=factor, pad=((p + 1) // 2, p // 2))


def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k