resnet.py 30.8 KB
Newer Older
patil-suraj's avatar
patil-suraj committed
1
from functools import partial
Patrick von Platen's avatar
Patrick von Platen committed
2
3

import numpy as np
4
5
6
7
8
import torch
import torch.nn as nn
import torch.nn.functional as F


9
class Upsample2D(nn.Module):
    """
    A 2x spatial upsampling layer with an optional convolution.

    :param channels: number of channels of the input (dimension 1); checked in ``forward``.
    :param use_conv: if True, apply a 3x3 convolution after nearest-neighbour upsampling.
    :param use_conv_transpose: if True, upsample with a learned 4x4 stride-2 transposed convolution instead of
        interpolation. Takes precedence over ``use_conv``.
    :param out_channels: number of output channels of the convolution; defaults to ``channels``.
    :param name: attribute under which the convolution is registered: ``"conv"`` or, for any other value,
        ``"Conv2d_0"`` (legacy checkpoint naming).
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def _get_conv(self):
        # The layer lives under a name-dependent attribute (see __init__), so every access must go through here;
        # hard-coding ``self.conv`` breaks when ``name != "conv"``.
        return self.conv if self.name == "conv" else self.Conv2d_0

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self._get_conv()(x)

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            x = self._get_conv()(x)

        return x


55
class Downsample2D(nn.Module):
    """
    A 2x spatial downsampling layer: a stride-2 3x3 convolution, or 2x2 average pooling without a convolution.

    :param channels: number of channels of the input (dimension 1); checked in ``forward``.
    :param use_conv: if True, downsample with a learned convolution; otherwise use average pooling, which requires
        ``out_channels == channels``.
    :param out_channels: number of output channels; defaults to ``channels``.
    :param padding: padding of the convolution. With ``use_conv`` and ``padding=0`` the input is first zero-padded
        by one pixel on the right and bottom only.
    :param name: attribute under which the layer is registered: ``"conv"``, ``"Conv2d_0"``, or ``"op"`` for any
        other value (legacy checkpoint naming).
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name
        stride = 2

        if not use_conv:
            assert self.channels == self.out_channels
            layer = nn.AvgPool2d(kernel_size=stride, stride=stride)
        else:
            layer = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        setattr(self, self._layer_attr(), layer)

    def _layer_attr(self):
        # Map ``name`` onto the attribute that actually holds the layer.
        return self.name if self.name in ("conv", "Conv2d_0") else "op"

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            # (left, right, top, bottom): pad only the right and bottom edges
            x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        layer = getattr(self, self._layer_attr())
        return layer(x)
100
101


102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
class Upsample1D(nn.Module):
    """
    A 2x upsampling layer for sequence inputs with an optional convolution.

    :param channels: number of channels of the input (dimension 1); checked in ``forward``.
    :param use_conv: if True, apply a 3-tap ``Conv1d`` after nearest-neighbour upsampling.
    :param use_conv_transpose: if True, upsample with a learned 4-tap stride-2 ``ConvTranspose1d`` instead of
        interpolation. Takes precedence over ``use_conv``.
    :param out_channels: number of output channels of the convolution; defaults to ``channels``.
    :param name: stored on the module; the layer is always registered as ``conv``.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if use_conv_transpose:
            layer = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            layer = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            layer = None
        self.conv = layer

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)

        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled


class Downsample1D(nn.Module):
    """
    A 2x downsampling layer for sequence inputs: a stride-2 ``Conv1d``, or 2-tap average pooling.

    :param channels: number of channels of the input (dimension 1); checked in ``forward``.
    :param use_conv: if True, downsample with a learned 3-tap convolution; otherwise use average pooling, which
        requires ``out_channels == channels``.
    :param out_channels: number of output channels; defaults to ``channels``.
    :param padding: padding passed straight to the ``Conv1d``.
    :param name: stored on the module; the layer is always registered as ``conv``.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name
        stride = 2

        if not use_conv:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
        else:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        downsample = self.conv
        return downsample(x)


class FirUpsample2D(nn.Module):
    """
    2x upsampling with a FIR filter, optionally fused with a learned 3x3 convolution.

    :param channels: number of input channels (only used to build the convolution).
    :param out_channels: number of output channels of the convolution; defaults to ``channels`` when falsy.
    :param use_conv: if True, build ``Conv2d_0`` and use its weights in the fused upsample+conv helper.
    :param fir_kernel: FIR filter taps passed as the kernel to the resampling helpers.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        if not out_channels:
            out_channels = channels
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        if not self.use_conv:
            return upsample_2d(x, self.fir_kernel, factor=2)

        conv = self.Conv2d_0
        upsampled = _upsample_conv_2d(x, conv.weight, k=self.fir_kernel)
        # the fused helper only applies the weights; add the per-channel bias separately
        return upsampled + conv.bias.reshape(1, -1, 1, 1)


class FirDownsample2D(nn.Module):
    """
    2x downsampling with a FIR filter, optionally fused with a learned 3x3 convolution.

    :param channels: number of input channels (only used to build the convolution).
    :param out_channels: number of output channels of the convolution; defaults to ``channels`` when falsy.
    :param use_conv: if True, build ``Conv2d_0`` and use its weights in the fused conv+downsample helper.
    :param fir_kernel: FIR filter taps passed as the kernel to the resampling helpers.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def forward(self, x):
        if self.use_conv:
            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
            # the fused helper only applies the weights; add the per-channel bias separately
            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            x = downsample_2d(x, self.fir_kernel, factor=2)

        return x


def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    """Fused `Conv2d()` followed by `downsample_2d()`.

    Padding is performed only once at the beginning, not between the operations: the input is FIR-filtered and
    padded with `upfirdn2d()`, then convolved with `w` at stride `factor`.

    Args:
        x: Input tensor of shape `[N, C, H, W]`.
        w: Convolution weight in PyTorch layout `[outChannels, inChannels, convH, convW]`; the kernel must be
          square (`convH == convW`).
        k: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
          corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1).

    Returns:
        Tensor of shape `[N, outChannels, H // factor, W // factor]`, same datatype as `x`.
    """
    assert isinstance(factor, int) and factor >= 1
    _out_channels, _in_channels, conv_h, conv_w = w.shape
    assert conv_w == conv_h  # only square convolution kernels are supported

    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain

    # Total padding needed for the FIR taps plus the conv footprint; split between the two sides with the odd
    # pixel (if any) going to the first entry of `pad`.
    p = (k.shape[0] - factor) + (conv_w - 1)
    s = [factor, factor]
    x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
    return F.conv2d(x, w, stride=s, padding=0)


def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    """Fused `upsample_2d()` followed by `Conv2d()`.

    Padding is performed only once at the beginning, not between the operations: a stride-`factor` transposed
    convolution does the upsampling and `upfirdn2d()` then applies the FIR filter.

    Args:
        x: Input tensor of shape `[N, C, H, W]`.
        w: Convolution weight in PyTorch layout `[outChannels, inChannels, convH, convW]`; the kernel must be
          square. Grouped convolution is performed when `C == numGroups * inChannels`.
        k: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
          corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1).

    Returns:
        Tensor of shape `[N, outChannels, H * factor, W * factor]`, same datatype as `x`.
    """
    assert isinstance(factor, int) and factor >= 1

    # Check weight shape.
    assert len(w.shape) == 4
    conv_h = w.shape[2]
    conv_w = w.shape[3]
    in_c = w.shape[1]
    assert conv_w == conv_h

    # Setup filter kernel.
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor**2))
    p = (k.shape[0] - factor) - (conv_w - 1)

    # Determine data dimensions. `conv_transpose2d` takes a 2-element spatial stride; the TensorFlow-style
    # `[1, 1, factor, factor]` NCHW form is rejected by PyTorch.
    stride = (factor, factor)
    output_shape = ((x.shape[2] - 1) * factor + conv_h, (x.shape[3] - 1) * factor + conv_w)
    output_padding = (
        output_shape[0] - (x.shape[2] - 1) * stride[0] - conv_h,
        output_shape[1] - (x.shape[3] - 1) * stride[1] - conv_w,
    )
    assert output_padding[0] >= 0 and output_padding[1] >= 0
    num_groups = x.shape[1] // in_c

    # Transpose weights into the `[inChannels, outChannels // groups, kH, kW]` layout of `conv_transpose2d`,
    # flipping the spatial taps. Torch tensors do not support negative-step slicing (`w[..., ::-1, ::-1]`),
    # hence `torch.flip`.
    w = torch.reshape(w, (num_groups, -1, in_c, conv_h, conv_w))
    w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4)
    w = torch.reshape(w, (num_groups * in_c, -1, conv_h, conv_w))

    # `groups` matches the group-major weight reshape above (a no-op when num_groups == 1).
    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0, groups=num_groups)

    return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))


Patrick von Platen's avatar
Patrick von Platen committed
298
# TODO (patil-suraj): needs test
299
# class Upsample2D1d(nn.Module):
Patrick von Platen's avatar
Patrick von Platen committed
300
301
302
303
304
305
#    def __init__(self, dim):
#        super().__init__()
#        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
#
#    def forward(self, x):
#        return self.conv(x)
306
307


Patrick von Platen's avatar
update  
Patrick von Platen committed
308
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
Patrick von Platen's avatar
Patrick von Platen committed
309
# => All 2D-Resnets are included here now!
Patrick von Platen's avatar
Patrick von Platen committed
310
class ResnetBlock2D(nn.Module):
    """
    A 2D ResNet block: norm -> nonlinearity -> (optional up/down resampling) -> conv, time-embedding injection,
    norm -> nonlinearity -> dropout -> conv, plus a shortcut. The result is ``(shortcut + h) / output_scale_factor``.

    The ``overwrite_for_*`` flags additionally build legacy layer layouts (Grad-TTS, LDM/GLIDE, score-SDE). On the
    first ``forward`` call their weights are copied into the unified layers (``norm1``, ``conv1``, ``temb_proj``,
    ``norm2``, ``conv2``, ``nin_shortcut``).

    :param in_channels: number of input channels (keyword-only, like every other parameter).
    :param out_channels: number of output channels; defaults to ``in_channels``.
    :param conv_shortcut: stored as ``use_conv_shortcut``; not used by this class's forward pass.
    :param dropout: dropout probability before ``conv2``.
    :param temb_channels: size of the time embedding; ``<= 0`` disables ``temb_proj``.
    :param groups: GroupNorm groups for ``norm1``.
    :param groups_out: GroupNorm groups for ``norm2``; defaults to ``groups``.
    :param pre_norm: if True, normalization + nonlinearity run before each conv; otherwise after it.
    :param eps: GroupNorm epsilon.
    :param non_linearity: ``"swish"`` or ``"silu"`` (both SiLU) or ``"mish"``.
    :param time_embedding_norm: ``"default"`` adds the projected embedding; ``"scale_shift"`` projects to a
        (scale, shift) pair applied as ``h * (1 + scale) + shift`` after ``norm2``.
    :param kernel: resampling kernel when ``up``/``down`` is set: ``"fir"``, ``"sde_vp"``, or anything else for
        the default ``Upsample2D`` / ``Downsample2D`` (without convolution).
    :param output_scale_factor: divisor applied to the residual sum.
    :param use_nin_shortcut: whether to use a 1x1 conv on the shortcut; defaults to ``in_channels != out_channels``.
    :param up: upsample both the main path and the shortcut.
    :param down: downsample both the main path and the shortcut (ignored when ``up`` is set).
    :param overwrite_for_grad_tts: build Grad-TTS legacy layers for weight conversion.
    :param overwrite_for_ldm: build LDM legacy layers for weight conversion.
    :param overwrite_for_glide: GLIDE uses the LDM legacy layout (implies ``overwrite_for_ldm``).
    :param overwrite_for_score_vde: build score-SDE legacy layers for weight conversion.
    :raises ValueError: if ``non_linearity`` is not one of the supported names.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_nin_shortcut=None,
        up=False,
        down=False,
        overwrite_for_grad_tts=False,
        overwrite_for_ldm=False,
        overwrite_for_glide=False,
        overwrite_for_score_vde=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        # norm1 runs before conv1 when pre_norm is set and after it otherwise, hence the different channel counts.
        norm1_channels = in_channels if self.pre_norm else out_channels
        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=norm1_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if time_embedding_norm == "default" and temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        elif time_embedding_norm == "scale_shift" and temb_channels > 0:
            # projects to a (scale, shift) pair that forward() splits along the channel dimension
            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = F.silu
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            # Fail fast: otherwise `self.nonlinearity` is never set and forward() dies with an AttributeError.
            raise ValueError(f"Unsupported non_linearity {non_linearity!r}; expected 'swish', 'silu' or 'mish'.")

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut

        self.nin_shortcut = None
        if self.use_nin_shortcut:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

        # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
        self.is_overwritten = False
        self.overwrite_for_glide = overwrite_for_glide
        self.overwrite_for_grad_tts = overwrite_for_grad_tts
        self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
        self.overwrite_for_score_vde = overwrite_for_score_vde
        if self.overwrite_for_grad_tts:
            dim = in_channels
            dim_out = out_channels
            time_emb_dim = temb_channels
            self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))

            self.block1 = Block(dim, dim_out, groups=groups)
            self.block2 = Block(dim_out, dim_out, groups=groups)
            if dim != dim_out:
                self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
            else:
                self.res_conv = torch.nn.Identity()
        elif self.overwrite_for_ldm:
            channels = in_channels
            emb_channels = temb_channels
            use_scale_shift_norm = False

            self.in_layers = nn.Sequential(
                normalization(channels, swish=1.0),
                nn.Identity(),
                nn.Conv2d(channels, self.out_channels, 3, padding=1),
            )
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(
                    emb_channels,
                    2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
                ),
            )
            self.out_layers = nn.Sequential(
                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
                nn.Dropout(p=dropout),
                zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
            )
            if self.out_channels == in_channels:
                self.skip_connection = nn.Identity()
            else:
                self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
        elif self.overwrite_for_score_vde:
            in_ch = in_channels
            out_ch = out_channels

            eps = 1e-6
            num_groups = min(in_ch // 4, 32)
            num_groups_out = min(out_ch // 4, 32)
            temb_dim = temb_channels

            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
            self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
            if temb_dim is not None:
                self.Dense_0 = nn.Linear(temb_dim, out_ch)
                self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
                nn.init.zeros_(self.Dense_0.bias)

            self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
            self.Dropout_0 = nn.Dropout(dropout)
            self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1)
            if in_ch != out_ch or up or down:
                # 1x1 convolution with DDPM initialization.
                self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)

            self.in_ch = in_ch
            self.out_ch = out_ch

    def set_weights_grad_tts(self):
        """Copy the Grad-TTS legacy layers (``block1``, ``block2``, ``mlp``, ``res_conv``) into the unified ones."""
        self.conv1.weight.data = self.block1.block[0].weight.data
        self.conv1.bias.data = self.block1.block[0].bias.data
        self.norm1.weight.data = self.block1.block[1].weight.data
        self.norm1.bias.data = self.block1.block[1].bias.data

        self.conv2.weight.data = self.block2.block[0].weight.data
        self.conv2.bias.data = self.block2.block[0].bias.data
        self.norm2.weight.data = self.block2.block[1].weight.data
        self.norm2.bias.data = self.block2.block[1].bias.data

        self.temb_proj.weight.data = self.mlp[1].weight.data
        self.temb_proj.bias.data = self.mlp[1].bias.data

        if self.in_channels != self.out_channels:
            self.nin_shortcut.weight.data = self.res_conv.weight.data
            self.nin_shortcut.bias.data = self.res_conv.bias.data

    def set_weights_ldm(self):
        """Copy the LDM/GLIDE legacy layers (``in_layers``, ``emb_layers``, ``out_layers``, ``skip_connection``)."""
        self.norm1.weight.data = self.in_layers[0].weight.data
        self.norm1.bias.data = self.in_layers[0].bias.data

        self.conv1.weight.data = self.in_layers[-1].weight.data
        self.conv1.bias.data = self.in_layers[-1].bias.data

        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
        self.temb_proj.bias.data = self.emb_layers[-1].bias.data

        self.norm2.weight.data = self.out_layers[0].weight.data
        self.norm2.bias.data = self.out_layers[0].bias.data

        self.conv2.weight.data = self.out_layers[-1].weight.data
        self.conv2.bias.data = self.out_layers[-1].bias.data

        if self.in_channels != self.out_channels:
            self.nin_shortcut.weight.data = self.skip_connection.weight.data
            self.nin_shortcut.bias.data = self.skip_connection.bias.data

    def set_weights_score_vde(self):
        """Copy the score-SDE legacy layers (``Conv_*``, ``GroupNorm_*``, ``Dense_0``) into the unified ones."""
        self.conv1.weight.data = self.Conv_0.weight.data
        self.conv1.bias.data = self.Conv_0.bias.data
        self.norm1.weight.data = self.GroupNorm_0.weight.data
        self.norm1.bias.data = self.GroupNorm_0.bias.data

        self.conv2.weight.data = self.Conv_1.weight.data
        self.conv2.bias.data = self.Conv_1.bias.data
        self.norm2.weight.data = self.GroupNorm_1.weight.data
        self.norm2.bias.data = self.GroupNorm_1.bias.data

        self.temb_proj.weight.data = self.Dense_0.weight.data
        self.temb_proj.bias.data = self.Dense_0.bias.data

        if self.in_channels != self.out_channels or self.up or self.down:
            self.nin_shortcut.weight.data = self.Conv_2.weight.data
            self.nin_shortcut.bias.data = self.Conv_2.bias.data

    def forward(self, x, temb, mask=1.0):
        """
        Apply the block.

        :param x: input feature map whose dimension 1 has ``in_channels`` entries.
        :param temb: time embedding or None. It goes through the nonlinearity and ``temb_proj`` and is then indexed
            ``[:, :, None, None]``, so presumably ``(batch, temb_channels)`` — TODO confirm with callers.
        :param mask: multiplied into the activations at several points; the default 1.0 disables masking.
        :return: ``(shortcut(x) + h) / output_scale_factor``.
        """
        # TODO(Patrick) eventually this class should be split into multiple classes
        # too many if else statements
        # Legacy layouts: copy their weights into the unified layers once, on the first call.
        if self.overwrite_for_grad_tts and not self.is_overwritten:
            self.set_weights_grad_tts()
            self.is_overwritten = True
        elif self.overwrite_for_ldm and not self.is_overwritten:
            self.set_weights_ldm()
            self.is_overwritten = True
        elif self.overwrite_for_score_vde and not self.is_overwritten:
            self.set_weights_score_vde()
            self.is_overwritten = True

        h = x * mask
        if self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)

        # Resample the main path and the shortcut identically so they can be summed at the end.
        if self.upsample is not None:
            x = self.upsample(x)
            h = self.upsample(h)
        elif self.downsample is not None:
            x = self.downsample(x)
            h = self.downsample(h)

        h = self.conv1(h)

        if not self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)
        h = h * mask

        if temb is not None:
            temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
        else:
            temb = 0

        if self.time_embedding_norm == "scale_shift":
            # The projection holds a (scale, shift) pair along the channel dimension: h * (1 + scale) + shift.
            scale, shift = torch.chunk(temb, 2, dim=1)

            h = self.norm2(h)
            h = h + h * scale + shift
            h = self.nonlinearity(h)
        elif self.time_embedding_norm == "default":
            h = h + temb
            h = h * mask
            if self.pre_norm:
                h = self.norm2(h)
                h = self.nonlinearity(h)

        h = self.dropout(h)
        h = self.conv2(h)

        if not self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)
        h = h * mask

        x = x * mask
        if self.nin_shortcut is not None:
            x = self.nin_shortcut(x)

        return (x + h) / self.output_scale_factor
591
592


Patrick von Platen's avatar
finish  
Patrick von Platen committed
593
# TODO(Patrick) - just there to convert the weights; can delete afterward
594
595
596
597
598
class Block(torch.nn.Module):
    """
    Legacy Grad-TTS layer group (3x3 Conv2d -> GroupNorm -> Mish) kept only as a weight container.

    It defines no ``forward``; ``ResnetBlock2D.set_weights_grad_tts`` reads ``block[0]`` (conv) and ``block[1]``
    (norm) to copy their weights.

    :param dim: input channels of the convolution.
    :param dim_out: output channels of the convolution and of the GroupNorm.
    :param groups: number of GroupNorm groups.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        conv = torch.nn.Conv2d(dim, dim_out, 3, padding=1)
        norm = torch.nn.GroupNorm(groups, dim_out)
        self.block = torch.nn.Sequential(conv, norm, Mish())


602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
    """
    Residual 1D block: two ``Conv1dBlock`` layers with a projected time embedding added between them.

    :param inp_channels: channels of the input sequence.
    :param out_channels: channels of the output sequence.
    :param embed_dim: size of the time embedding ``t``.
    :param horizon: accepted for interface compatibility; not used by this block.
    :param kernel_size: kernel size of both convolutions.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        first = Conv1dBlock(inp_channels, out_channels, kernel_size)
        second = Conv1dBlock(out_channels, out_channels, kernel_size)
        self.blocks = nn.ModuleList([first, second])

        # Mish -> Linear -> add a trailing singleton axis, so the embedding broadcasts over the sequence axis
        # (replaces einops `Rearrange("batch t -> batch t 1")`).
        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            RearrangeDim(),
        )

        if inp_channels != out_channels:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, t):
        """
        :param x: ``[batch_size x inp_channels x horizon]`` input sequence.
        :param t: ``[batch_size x embed_dim]`` time embedding.
        :return: ``[batch_size x out_channels x horizon]``.
        """
        hidden = self.blocks[0](x)
        hidden = hidden + self.time_mlp(t)
        hidden = self.blocks[1](hidden)
        return hidden + self.residual_conv(x)


Patrick von Platen's avatar
Patrick von Platen committed
635
636
637
638
# HELPER Modules


def normalization(channels, swish=0.0):
    """
    Build the standard 32-group normalization layer, optionally followed by a swish activation.

    :param channels: number of input channels.
    :param swish: activation selector forwarded to ``GroupNorm32`` (0.0: none, 1.0: SiLU, other: scaled swish).
    :return: an nn.Module for normalization.
    """
    layer = GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
    return layer


class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm evaluated on a float32 copy of the input (result cast back to the input dtype), with an optional
    swish-style activation.

    :param num_groups: number of groups.
    :param num_channels: number of channels.
    :param swish: activation selector: ``1.0`` -> SiLU, any other truthy ``s`` -> ``y * sigmoid(s * y)``,
        falsy -> no activation.
    :param eps: GroupNorm epsilon.
    """

    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
        self.swish = swish

    def forward(self, x):
        normed = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            return F.silu(normed)
        if not self.swish:
            return normed
        return normed * torch.sigmoid(normed * float(self.swish))


def linear(*args, **kwargs):
    """
    Create a linear module; every argument is forwarded unchanged to ``nn.Linear``.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer


def zero_module(module):
    """
    Zero out the parameters of a module in place and return the same module.

    The parameters keep their ``requires_grad`` flags; only their values are set to zero.
    """
    for param in module.parameters():
        # nn.init.zeros_ fills in place without recording the op in autograd.
        nn.init.zeros_(param)
    return module


class Mish(torch.nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``, applied element-wise."""

    def forward(self, x):
        gate = F.softplus(x).tanh()
        return x * gate


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    The convolution uses ``padding=kernel_size // 2``, which keeps the sequence length unchanged for odd
    ``kernel_size``. The GroupNorm sits between two ``RearrangeDim`` layers that add and then remove a
    singleton axis around it.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            RearrangeDim(),  # batch channels horizon -> batch channels 1 horizon
            nn.GroupNorm(n_groups, out_channels),
            RearrangeDim(),  # batch channels 1 horizon -> batch channels horizon
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> norm -> activation to ``x``."""
        return self.block(x)


class RearrangeDim(nn.Module):
    """
    Insert or remove a singleton axis depending on the input rank (stand-in for einops ``Rearrange``).

    * 2-D ``(b, c)``       -> ``(b, c, 1)``
    * 3-D ``(b, c, h)``    -> ``(b, c, 1, h)``
    * 4-D ``(b, c, x, h)`` -> ``(b, c, h)`` by taking index 0 of axis 2. This is only lossless when that
      axis has size 1; other entries are discarded.

    Raises:
        ValueError: if the input is not 2-, 3- or 4-dimensional.
    """

    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        ndim = tensor.ndim
        if ndim == 2:
            return tensor[:, :, None]
        if ndim == 3:
            return tensor[:, :, None, :]
        if ndim == 4:
            return tensor[:, :, 0, :]
        # Report the rank. `len(tensor)` would give the size of the first axis instead,
        # and it raises TypeError for 0-d tensors.
        raise ValueError(f"`tensor.ndim`: {ndim} has to be 2, 3 or 4.")


patil-suraj's avatar
patil-suraj committed
719
720
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
    """
    nXn convolution with DDPM initialization.

    The weight is replaced by a sample from ``variance_scaling(init_scale)`` of the weight's shape.
    The bias, when the layer has one, is zero-initialized.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
    # With bias=False, conv.bias is None and nn.init.zeros_(None) would raise AttributeError.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv


patil-suraj's avatar
patil-suraj committed
727
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX.

    Return an initializer ``init(shape, dtype=dtype, device=device)``. It samples uniformly from
    ``[-limit, limit]``, where ``limit = sqrt(3 * scale / ((fan_in + fan_out) / 2))``, using the
    fan-average variant. A ``scale`` of 0 is replaced by 1e-10.
    """
    if scale == 0:
        scale = 1e-10

    def _compute_fans(shape, in_axis=1, out_axis=0):
        # Elements per (in, out) pair, e.g. kernel_h * kernel_w for an (out, in, kh, kw) conv weight.
        receptive_field = np.prod(shape) / shape[in_axis] / shape[out_axis]
        return shape[in_axis] * receptive_field, shape[out_axis] * receptive_field

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        variance = scale / ((fan_in + fan_out) / 2)
        limit = np.sqrt(3 * variance)
        # Uniform in [-1, 1), then stretched to [-limit, limit).
        unit = torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0
        return unit * limit

    return init
744
745


Patrick von Platen's avatar
Patrick von Platen committed
746
747
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """
    Upsample, FIR-filter and downsample ``input`` with the same settings on both spatial axes.

    ``up`` and ``down`` are used for both x and y. ``pad[0]`` / ``pad[1]`` become the leading / trailing
    padding of both axes.
    """
    lead, trail = pad[0], pad[1]
    return upfirdn2d_native(
        input,
        kernel,
        up_x=up,
        up_y=up,
        down_x=down,
        down_y=down,
        pad_x0=lead,
        pad_x1=trail,
        pad_y0=lead,
        pad_y1=trail,
    )
748
749


Patrick von Platen's avatar
Patrick von Platen committed
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    """
    Upsample, pad, FIR-filter and downsample a batch of images using plain PyTorch ops.

    Steps: zero-insertion upsampling by ``(up_y, up_x)``, zero padding (negative amounts crop), convolution
    with ``kernel``, then keeping every ``down_y``-th row and ``down_x``-th column.

    Args:
        input: 4-D tensor unpacked as ``[N, C, H, W]``.
        kernel: 2-D filter ``[kernel_h, kernel_w]``. It is passed to ``F.conv2d``, so its dtype/device should
            match ``input``'s.
        up_x, up_y: upsampling factors along width / height.
        down_x, down_y: downsampling factors along width / height.
        pad_x0, pad_x1, pad_y0, pad_y1: leading / trailing padding along width / height; negative values crop.

    Returns:
        Tensor ``[N, C, out_h, out_w]`` with ``out_h = (H * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1``
        and ``out_w`` defined analogously.
    """
    # Fold channels into the batch axis so every channel is filtered independently with the same kernel.
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    # `minor` is the trailing axis added by the reshape above, so it is always 1 here.
    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Zero-insertion upsampling: give each row (up_y - 1) and each column (up_x - 1) trailing zeros via
    # singleton axes, then merge those axes back into height / width.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding is applied as zeros here; negative padding is applied as a crop by the slice below.
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    # Lay out as [N*C, 1, H', W'] so a single-channel F.conv2d applies the filter.
    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    # F.conv2d computes cross-correlation; flipping the kernel makes it a true convolution.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    # No padding in the conv, so each spatial size shrinks by (kernel size - 1).
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Downsampling: keep every down_y-th row and every down_x-th column.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    # Split the folded batch back into [N, C, out_h, out_w].
    return out.view(-1, channel, out_h, out_w)


def upsample_2d(x, k=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` and upsamples each image with the given filter. The
    filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero.

    Args:
        x: Input tensor of the shape `[N, C, H, W]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
          corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    # Multiply by factor**2 to compensate for the zeros inserted by the upsampling step.
    k = _setup_kernel(k) * (gain * (factor**2))
    p = k.shape[0] - factor
    # Build the kernel in x's dtype. F.conv2d rejects mismatched input/weight dtypes, so the previous
    # float32-only kernel failed for float16/float64 inputs.
    kernel = torch.tensor(k, device=x.device, dtype=x.dtype)
    return upfirdn2d(x, kernel, up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))


def downsample_2d(x, k=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` and downsamples each image with the given filter. The
    filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero.

    Args:
        x: Input tensor of the shape `[N, C, H, W]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
          corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = k.shape[0] - factor
    # Build the kernel in x's dtype. F.conv2d rejects mismatched input/weight dtypes, so the previous
    # float32-only kernel failed for float16/float64 inputs.
    kernel = torch.tensor(k, device=x.device, dtype=x.dtype)
    return upfirdn2d(x, kernel, down=factor, pad=((p + 1) // 2, p // 2))


def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k