from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F


class Upsample1D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: a bool determining if a transposed convolution is used instead of interpolation.
        out_channels: number of output channels; defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            x = self.conv(x)

        return x
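
# A minimal usage sketch for Upsample1D (illustrative; the shapes are assumptions, not
# from the original file): nearest-neighbor interpolation doubles the length, and the
# optional convolution preserves it.
#
#     >>> upsample = Upsample1D(channels=8, use_conv=True)
#     >>> upsample(torch.randn(1, 8, 16)).shape
#     torch.Size([1, 8, 32])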


class Downsample1D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: number of output channels; defaults to `channels`.
        padding: padding applied to the convolution (default: 1).
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)
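
# A minimal usage sketch for Downsample1D (illustrative; shapes are assumptions):
# the strided convolution (or average pool) halves the temporal length.
#
#     >>> downsample = Downsample1D(channels=8, use_conv=True)
#     >>> downsample(torch.randn(1, 8, 32)).shape
#     torch.Size([1, 8, 16])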


class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: a bool determining if a transposed convolution is used instead of interpolation.
        out_channels: number of output channels; defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
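
# A minimal usage sketch for Upsample2D (illustrative; shapes are assumptions):
# spatial dimensions double unless an explicit `output_size` is requested.
#
#     >>> upsample = Upsample2D(channels=64, use_conv=True)
#     >>> upsample(torch.randn(2, 64, 16, 16)).shape
#     torch.Size([2, 64, 32, 32])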


class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: number of output channels; defaults to `channels`.
        padding: padding applied to the convolution (default: 1).
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        hidden_states = self.conv(hidden_states)

        return hidden_states
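
# A minimal usage sketch for Downsample2D (illustrative; shapes are assumptions):
# the strided convolution (or average pool) halves both spatial dimensions.
#
#     >>> downsample = Downsample2D(channels=64, use_conv=True)
#     >>> downsample(torch.randn(2, 64, 32, 32)).shape
#     torch.Size([2, 64, 16, 16])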


class FirUpsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
        order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[outChannels, inChannels, kernelH,
                kernelW]`. Grouped convolution can be performed by `inChannels = x.shape[1] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Set up the FIR filter kernel.
        if kernel is None:
            kernel = [1] * factor

        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height
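
# A minimal usage sketch for FirUpsample2D (illustrative; shapes are assumptions):
# the FIR filter smooths while the fused transposed conv doubles the resolution.
#
#     >>> fir_up = FirUpsample2D(channels=8, out_channels=8, use_conv=True)
#     >>> fir_up(torch.randn(1, 8, 16, 16)).shape
#     torch.Size([1, 8, 32, 32])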


class FirDownsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
        order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[outChannels, inChannels, kernelH, kernelW]`. Grouped convolution can be
                performed by `inChannels = x.shape[1] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
                factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # Set up the FIR filter kernel.
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
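
# A minimal usage sketch for FirDownsample2D (illustrative; shapes are assumptions):
# the FIR filter anti-aliases before the fused strided conv halves the resolution.
#
#     >>> fir_down = FirDownsample2D(channels=8, out_channels=8, use_conv=True)
#     >>> fir_down(torch.randn(1, 8, 32, 32)).shape
#     torch.Size([1, 8, 16, 16])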


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True  # the block always runs in pre-norm mode; `pre_norm` is accepted for compatibility
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f"unknown time_embedding_norm: {self.time_embedding_norm}")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            # Fail fast on unknown activations instead of erroring later in forward().
            raise ValueError(f"unsupported non_linearity: {non_linearity}")

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
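
# A minimal usage sketch for ResnetBlock2D (illustrative; the shapes and channel counts
# are assumptions): when channels change, the residual is projected with a 1x1 conv.
#
#     >>> block = ResnetBlock2D(in_channels=32, out_channels=64, temb_channels=128)
#     >>> x, temb = torch.randn(1, 32, 16, 16), torch.randn(1, 128)
#     >>> block(x, temb).shape
#     torch.Size([1, 64, 16, 16])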


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


# unet_rl.py
def rearrange_dims(tensor):
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
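
# A minimal sketch of rearrange_dims (illustrative shapes): it shuttles tensors between
# the (B, C), (B, C, L) and (B, C, 1, L) layouts used by Conv1dBlock's GroupNorm below.
#
#     >>> rearrange_dims(torch.randn(2, 3)).shape
#     torch.Size([2, 3, 1])
#     >>> rearrange_dims(torch.randn(2, 3, 5)).shape
#     torch.Size([2, 3, 1, 5])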


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        x = self.conv1d(x)
        x = rearrange_dims(x)
        x = self.group_norm(x)
        x = rearrange_dims(x)
        x = self.mish(x)
        return x
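
# A minimal usage sketch for Conv1dBlock (illustrative; shapes are assumptions):
# `kernel_size // 2` padding preserves the horizon length for odd kernel sizes.
#
#     >>> conv_block = Conv1dBlock(inp_channels=4, out_channels=8, kernel_size=5)
#     >>> conv_block(torch.randn(1, 4, 16)).shape
#     torch.Size([1, 8, 16])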


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(x) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(x)
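
# A minimal usage sketch for ResidualTemporalBlock1D (illustrative; shapes are assumptions):
# the time embedding is broadcast over the horizon, and the residual path matches channels.
#
#     >>> block = ResidualTemporalBlock1D(inp_channels=4, out_channels=8, embed_dim=16)
#     >>> block(torch.randn(1, 4, 32), torch.randn(1, 16)).shape
#     torch.Size([1, 8, 32])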


def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output
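
# A minimal usage sketch for upsample_2d (illustrative; the FIR kernel below is the
# (1, 3, 3, 1) filter used elsewhere in this file):
#
#     >>> upsample_2d(torch.randn(1, 3, 8, 8), kernel=(1, 3, 3, 1)).shape
#     torch.Size([1, 3, 16, 16])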


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
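
# A minimal usage sketch for downsample_2d (illustrative; same FIR kernel as above):
#
#     >>> downsample_2d(torch.randn(1, 3, 16, 16), kernel=(1, 3, 3, 1)).shape
#     torch.Size([1, 3, 8, 8])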


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    """Native up-firdn: upsample by inserting zeros, pad, convolve with the FIR `kernel`,
    then downsample by keeping every `down`-th sample."""
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    # Insert (up - 1) zeros between samples along each spatial axis.
    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)