"src/vscode:/vscode.git/clone" did not exist on "d04cd95012e0e9dd73169c44ae16f97f72b7eac1"
resnet.py 29.4 KB
Newer Older
patil-suraj's avatar
patil-suraj committed
1
from functools import partial
2
from typing import Optional
Patrick von Platen's avatar
Patrick von Platen committed
3

4
5
6
7
import torch
import torch.nn as nn
import torch.nn.functional as F

8
9
from .attention import AdaGroupNorm


class Upsample1D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
            channels: channels in the inputs and outputs.
            use_conv: a bool determining if a convolution is applied.
            use_conv_transpose:
            out_channels:
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            x = self.conv(x)

        return x
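

# Example (illustrative sketch): doubling the length of a `[batch, channels, length]`
# tensor with `Upsample1D`.
#
#   >>> up = Upsample1D(32, use_conv=True)
#   >>> up(torch.randn(2, 32, 16)).shape
#   torch.Size([2, 32, 32])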


class Downsample1D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels:
        padding:
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)
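

# Example (illustrative sketch): halving the length of a `[batch, channels, length]`
# tensor with `Downsample1D`.
#
#   >>> down = Downsample1D(32, use_conv=True)
#   >>> down(torch.randn(2, 32, 16)).shape
#   torch.Size([2, 32, 8])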


class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: a bool determining if a transposed convolution is used instead of nearest-neighbor
            interpolation.
        out_channels: number of output channels; defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 because the `upsample_nearest2d_out_frame` op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input was bfloat16, cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
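

# Example (illustrative sketch): 2x nearest-neighbor upsampling of a feature map,
# followed by a 3x3 convolution.
#
#   >>> up = Upsample2D(64, use_conv=True)
#   >>> up(torch.randn(1, 64, 16, 16)).shape
#   torch.Size([1, 64, 32, 32])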


class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: number of output channels; defaults to `channels`.
        padding: padding for the convolution (default: 1).
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states
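

# Example (illustrative sketch): halving the spatial resolution of a feature map
# with a strided 3x3 convolution.
#
#   >>> down = Downsample2D(64, use_conv=True)
#   >>> down(torch.randn(1, 64, 32, 32)).shape
#   torch.Size([1, 64, 16, 16])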


class FirUpsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
        order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                torch.tensor(kernel, device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height
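

# Example (illustrative sketch): fused 3x3 convolution + FIR-filtered 2x upsampling.
#
#   >>> fir_up = FirUpsample2D(64, use_conv=True)
#   >>> fir_up(torch.randn(1, 64, 16, 16)).shape
#   torch.Size([1, 64, 32, 32])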


class FirDownsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
        order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
                factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
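

# Example (illustrative sketch): fused 3x3 convolution + FIR-filtered 2x downsampling.
#
#   >>> fir_down = FirDownsample2D(64, use_conv=True)
#   >>> fir_down(torch.randn(1, 64, 32, 32)).shape
#   torch.Size([1, 64, 16, 16])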


# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, x):
        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv2d(x, weight, stride=2)


class KUpsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, x):
        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
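

# Example (illustrative sketch): the k-upscaler-style blur-pool / blur-upsample pair.
#
#   >>> x = torch.randn(1, 3, 32, 32)
#   >>> KDownsample2D()(x).shape
#   torch.Size([1, 3, 16, 16])
#   >>> KUpsample2D()(x).shape
#   torch.Size([1, 3, 64, 64])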


class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, defaults to `512`): the number of channels in the timestep embedding.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, defaults to `None`):
            The number of groups to use for the second normalization layer. If set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, defaults to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, defaults to `"default"`): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, *optional*, defaults to `None`): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, defaults to `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, defaults to `True`):
            If `True`, add a 1x1 nn.Conv2d layer for the skip connection.
        up (`bool`, *optional*, defaults to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, defaults to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, defaults to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, defaults to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",  # default, scale_shift, ada_group
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        elif non_linearity == "gelu":
            self.nonlinearity = nn.GELU()

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
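

# Example (illustrative sketch): a residual block conditioned on a timestep embedding.
#
#   >>> block = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=512)
#   >>> x, temb = torch.randn(1, 64, 32, 32), torch.randn(1, 512)
#   >>> block(x, temb).shape
#   torch.Size([1, 64, 32, 32])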


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


# unet_rl.py
def rearrange_dims(tensor):
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        x = self.conv1d(x)
        x = rearrange_dims(x)
        x = self.group_norm(x)
        x = rearrange_dims(x)
        x = self.mish(x)
        return x


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(x) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(x)
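

# Example (illustrative sketch): a temporal residual block from the RL UNet (unet_rl.py),
# mapping `[batch, inp_channels, horizon]` to `[batch, out_channels, horizon]`.
#
#   >>> block = ResidualTemporalBlock1D(14, 32, embed_dim=128)
#   >>> block(torch.randn(2, 14, 64), torch.randn(2, 128)).shape
#   torch.Size([2, 32, 64])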


def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output
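

# Example (illustrative sketch): 2x upsampling via `upsample_2d` with the default box kernel.
#
#   >>> upsample_2d(torch.randn(1, 3, 16, 16)).shape
#   torch.Size([1, 3, 32, 32])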


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
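

# Example (illustrative sketch): 2x average-pool-style downsampling via `downsample_2d`
# with the default box kernel.
#
#   >>> downsample_2d(torch.randn(1, 3, 16, 16)).shape
#   torch.Size([1, 3, 8, 8])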


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
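

# Example (illustrative sketch): `upfirdn2d_native` upsamples by zero insertion, pads,
# applies the FIR filter, then downsamples. With `down=2` and a 2x2 box kernel it
# behaves like 2x2 average pooling.
#
#   >>> k = torch.ones(2, 2) / 4
#   >>> upfirdn2d_native(torch.randn(1, 3, 16, 16), k, down=2).shape
#   torch.Size([1, 3, 8, 8])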