from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F


class Upsample1D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: a bool determining if a transposed convolution is used instead of interpolation.
        out_channels: number of output channels; defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            x = self.conv(x)

        return x
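

# Example usage for `Upsample1D` (a quick sanity sketch; the shapes are illustrative):
#     up = Upsample1D(channels=32, use_conv=True)
#     x = torch.randn(4, 32, 64)          # (batch, channels, length)
#     assert up(x).shape == (4, 32, 128)  # the length dimension is doubled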


class Downsample1D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: number of output channels; defaults to `channels`.
        padding: padding used by the strided convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)
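

# Example usage for `Downsample1D` (a quick sanity sketch; the shapes are illustrative):
#     down = Downsample1D(channels=32, use_conv=True)
#     x = torch.randn(4, 32, 64)
#     assert down(x).shape == (4, 32, 32)  # the stride-2 convolution halves the length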


class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: a bool determining if a transposed convolution is used instead of interpolation.
        out_channels: number of output channels; defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input was bfloat16, cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
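

# Example usage for `Upsample2D` (a quick sanity sketch; the shapes are illustrative):
#     up = Upsample2D(channels=64, use_conv=True)
#     x = torch.randn(2, 64, 16, 16)
#     assert up(x).shape == (2, 64, 32, 32)
#     assert up(x, output_size=(33, 33)).shape == (2, 64, 33, 33)  # explicit size overrides the 2x factor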


class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: number of output channels; defaults to `channels`.
        padding: padding used by the strided convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states
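

# Example usage for `Downsample2D` (a quick sanity sketch; the shapes are illustrative):
#     down = Downsample2D(channels=64, use_conv=True, padding=0)
#     x = torch.randn(2, 64, 16, 16)
#     assert down(x).shape == (2, 64, 8, 8)  # asymmetric (0, 1) padding, then a stride-2 conv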


class FirUpsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
        order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[outChannels, inChannels, kernelH,
                kernelW]`. Grouped convolution can be performed by `inChannels = x.shape[1] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # Normalize the kernel (a 1D kernel is expanded to 2D via its outer product).
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height
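

# Example usage for `FirUpsample2D` (a quick sanity sketch; the shapes are illustrative):
#     up = FirUpsample2D(channels=32, out_channels=32, use_conv=True)
#     x = torch.randn(2, 32, 16, 16)
#     assert up(x).shape == (2, 32, 32, 32)  # fused transposed conv + FIR filtering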


class FirDownsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
        order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[outChannels, inChannels, kernelH, kernelW]`. Grouped convolution can be
                performed by `inChannels = x.shape[1] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
                factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # Normalize the kernel (a 1D kernel is expanded to 2D via its outer product).
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
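

# Example usage for `FirDownsample2D` (a quick sanity sketch; the shapes are illustrative):
#     down = FirDownsample2D(channels=32, out_channels=32, use_conv=True)
#     x = torch.randn(2, 32, 16, 16)
#     assert down(x).shape == (2, 32, 8, 8)  # FIR filtering, then a fused stride-2 conv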


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        self.pre_norm = True  # only the pre-norm variant is implemented, so `pre_norm` is forced to True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            # fail fast instead of hitting an AttributeError in forward()
            raise ValueError(f"Unsupported `non_linearity`: {non_linearity}")

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
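

# Example usage for `ResnetBlock2D` (a quick sanity sketch; the shapes are illustrative):
#     block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
#     x = torch.randn(2, 64, 16, 16)
#     temb = torch.randn(2, 512)
#     assert block(x, temb).shape == (2, 128, 16, 16)  # the 1x1 shortcut matches the channel change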


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
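

# `Mish` computes x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))); newer PyTorch
# versions also ship this activation as `nn.Mish` (used by `Conv1dBlock` below).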


# unet_rl.py
def rearrange_dims(tensor):
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    elif len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
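

# `rearrange_dims` shape behavior (a quick sanity sketch):
#     rearrange_dims(torch.zeros(4, 8)).shape        == (4, 8, 1)
#     rearrange_dims(torch.zeros(4, 8, 16)).shape    == (4, 8, 1, 16)
#     rearrange_dims(torch.zeros(4, 8, 1, 16)).shape == (4, 8, 16)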


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        x = self.conv1d(x)
        x = rearrange_dims(x)
        x = self.group_norm(x)
        x = rearrange_dims(x)
        x = self.mish(x)
        return x
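

# Example usage for `Conv1dBlock` (a quick sanity sketch; the shapes are illustrative):
#     block = Conv1dBlock(inp_channels=14, out_channels=32, kernel_size=5)
#     x = torch.randn(4, 14, 16)            # (batch, channels, horizon)
#     assert block(x).shape == (4, 32, 16)  # `padding=kernel_size // 2` keeps the length for odd kernels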


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        Returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(x) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(x)
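

# Example usage for `ResidualTemporalBlock1D` (a quick sanity sketch; the shapes are illustrative):
#     block = ResidualTemporalBlock1D(inp_channels=14, out_channels=32, embed_dim=128)
#     x = torch.randn(4, 14, 16)   # (batch, inp_channels, horizon)
#     t = torch.randn(4, 128)      # (batch, embed_dim)
#     assert block(x, t).shape == (4, 32, 16)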


def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output
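

# Example usage for `upsample_2d` (a quick sanity sketch; the FIR kernel matches the
# `fir_kernel` default used above):
#     x = torch.randn(1, 3, 8, 8)
#     assert upsample_2d(x, kernel=(1, 3, 3, 1), factor=2).shape == (1, 3, 16, 16)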


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
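

# Example usage for `downsample_2d` (a quick sanity sketch):
#     x = torch.randn(1, 3, 16, 16)
#     assert downsample_2d(x, kernel=(1, 3, 3, 1), factor=2).shape == (1, 3, 8, 8)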


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, apply the FIR `kernel`, then downsample a batch of `[N, C, H, W]` images."""
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
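

# Identity check for `upfirdn2d_native` (a quick sanity sketch): with a 1x1 unit
# kernel, no padding, and up=down=1, the op leaves the input unchanged.
#     x = torch.randn(1, 3, 8, 8)
#     assert torch.allclose(upfirdn2d_native(x, torch.ones(1, 1)), x)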