resnet.py 19.5 KB
Newer Older
patil-suraj's avatar
patil-suraj committed
1
from functools import partial
Patrick von Platen's avatar
Patrick von Platen committed
2

3
4
5
6
7
import torch
import torch.nn as nn
import torch.nn.functional as F


8
class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: number of channels expected in dimension 1 of the input (checked in `forward`).
    :param use_conv: if True, a 3x3 convolution (padding 1) is applied after the nearest-neighbour interpolation.
    :param use_conv_transpose: if True, a transposed convolution (kernel 4, stride 2, padding 1) replaces the
                 interpolation entirely; it takes precedence over `use_conv`.
    :param out_channels: number of output channels of the convolution; falls back to `channels` when falsy.
    :param name: attribute name under which the layer is stored: `"conv"` stores it as `self.conv`, any other
                 value stores it as `self.Conv2d_0`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def _layer(self):
        """Return the layer from whichever attribute `__init__` stored it under."""
        return self.conv if self.name == "conv" else self.Conv2d_0

    def forward(self, hidden_states, output_size=None):
        """
        Upsample `hidden_states`.

        :param hidden_states: input whose dimension 1 must equal `channels`.
        :param output_size: optional explicit target size for the interpolation; when None a scale factor of 2
                 is used. Ignored on the transposed-convolution path.
        :return: the upsampled tensor.
        """
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            # Look the layer up by its registered name; using `self.conv` unconditionally
            # fails when the layer was stored as `Conv2d_0`.
            return self._layer()(hidden_states)

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            hidden_states = self._layer()(hidden_states)

        return hidden_states
58
59


60
class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: number of channels expected in dimension 1 of the input (checked in `forward`).
    :param use_conv: if True, downsample with a 3x3 convolution of stride 2; otherwise use 2x2 average pooling
                 (which requires `out_channels` to equal `channels`).
    :param out_channels: number of output channels of the convolution; falls back to `channels` when falsy.
    :param padding: padding of the convolution. With `padding=0` the input is zero-padded by one column on the
                 right and one row at the bottom before the convolution.
    :param name: when equal to `"conv"` the layer is additionally stored as `self.Conv2d_0`; it is always
                 stored as `self.conv`.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            # Same module under a second name, registered first so the state-dict
            # key order is unchanged.
            self.Conv2d_0 = conv
        self.conv = conv

    def forward(self, hidden_states):
        """
        Downsample `hidden_states` by a factor of 2.

        :param hidden_states: input whose dimension 1 must equal `channels`.
        :return: the downsampled tensor.
        """
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            # F.pad order is (left, right, top, bottom) over the last two dimensions.
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        return self.conv(hidden_states)
103
104
105
106
107
108
109
110
111
112
113
114


class FirUpsample2D(nn.Module):
    """
    Upsampling layer that filters with a FIR kernel, optionally fused with a 3x3 convolution.

    :param channels: input channels of the optional convolution.
    :param out_channels: output channels of the optional convolution; falls back to `channels` when falsy.
    :param use_conv: if True, a convolution (`Conv2d_0`) is created and its weights are applied as a
                 transposed convolution before the FIR filter.
    :param fir_kernel: FIR filter taps passed to `_upsample_2d` as `kernel`.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            hidden_states: Input tensor; indexed here as `[N, C, H, W]`.
            weight: Convolution weight, required when `self.use_conv` is True. It is indexed here as
                `[outChannels, inChannels, convH, convW]`; the number of groups is derived as
                `hidden_states.shape[1] // inChannels`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
                `[1] * factor`.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            The tensor returned by `upfirdn2d_native` for the upsampled (and, with `use_conv`, convolved)
            input.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            # `kernel` is already a tensor: move it with `.to` instead of re-wrapping it in
            # `torch.tensor`, which copy-constructs and emits a UserWarning on every call.
            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        """Upsample `hidden_states` with `self.fir_kernel`; with `use_conv` the bias of `Conv2d_0` is added."""
        if self.use_conv:
            output = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            # The fused op only uses the weight, so the bias is added separately.
            output = output + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            output = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return output
203
204
205
206
207
208
209


class FirDownsample2D(nn.Module):
    """
    Downsampling layer that filters with a FIR kernel, optionally fused with a 3x3 convolution.

    :param channels: input channels of the optional convolution.
    :param out_channels: output channels of the optional convolution; falls back to `channels` when falsy.
    :param use_conv: if True, a convolution (`Conv2d_0`) is created and its weights are applied with stride
                 `factor` after the FIR filter.
    :param fir_kernel: FIR filter taps passed to `_downsample_2d` as `kernel`.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            hidden_states: Input tensor passed to `upfirdn2d_native`.
            weight: Convolution weight, required when `self.use_conv` is True. It must be 4-dimensional; its
                last dimension is used as the convolution width.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
                `[1] * factor`.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            The filtered and downsampled tensor (convolved with `weight` when `use_conv` is set).
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, _, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            # `kernel` is already a tensor: move it with `.to` instead of re-wrapping it in
            # `torch.tensor`, which copy-constructs and emits a UserWarning on every call.
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            # The strided convolution performs the actual downsampling on this path.
            hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            hidden_states = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return hidden_states

    def forward(self, hidden_states):
        """Downsample `hidden_states` with `self.fir_kernel`; with `use_conv` the bias of `Conv2d_0` is added."""
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            # The fused op only uses the weight, so the bias is added separately.
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
274
275


276
class ResnetBlock2D(nn.Module):
    """
    Residual block: GroupNorm -> activation -> (optional up/down-sampling) -> 3x3 conv -> optional time-embedding
    injection -> GroupNorm -> activation -> dropout -> 3x3 conv, added to a (possibly 1x1-projected) skip path.

    All constructor arguments are keyword-only.

    :param in_channels: channels of the input tensor.
    :param out_channels: channels of the output tensor; defaults to `in_channels` when None.
    :param conv_shortcut: stored as `use_conv_shortcut`; not otherwise used by this block.
    :param dropout: dropout probability applied before the second convolution.
    :param temb_channels: size of the time embedding; when None no projection layer is created.
    :param groups: number of groups of the first GroupNorm.
    :param groups_out: number of groups of the second GroupNorm; defaults to `groups`.
    :param pre_norm: accepted for interface compatibility; `self.pre_norm` is always set to True.
    :param eps: epsilon of both GroupNorm layers.
    :param non_linearity: one of `"swish"`, `"mish"` or `"silu"`.
    :param time_embedding_norm: stored on the module; not otherwise used by this block.
    :param kernel: resampling flavour when `up`/`down` is set: `"fir"`, `"sde_vp"`, or anything else for the
                 default `Upsample2D`/`Downsample2D` layers.
    :param output_scale_factor: the sum of skip path and residual path is divided by this value.
    :param use_in_shortcut: force (True) or suppress (False) the 1x1 shortcut convolution; when None it is
                 created only if `in_channels != out_channels`.
    :param up: upsample both paths by 2 (takes precedence over `down`).
    :param down: downsample both paths by 2.
    :raises ValueError: if `non_linearity` is not one of the supported names.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        # The `pre_norm` argument was always overwritten with True; keep that effective behaviour.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = F.silu
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            # Fail at construction instead of leaving `self.nonlinearity` unset until `forward`.
            raise ValueError(
                f"Unsupported non_linearity {non_linearity!r}; expected 'swish', 'mish' or 'silu'."
            )

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        """
        Apply the residual block.

        :param input_tensor: input feature map with `in_channels` channels.
        :param temb: time embedding projected and added after the first convolution, or None to skip that step.
        :return: `(skip + residual) / output_scale_factor`.
        """
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        # Resample the skip path and the residual path identically so they can be summed.
        if self.upsample is not None:
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # Broadcast the projected embedding over the two spatial dimensions.
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
387

Patrick von Platen's avatar
Patrick von Platen committed
388
389

class Mish(torch.nn.Module):
    """Elementwise Mish activation: `x * tanh(softplus(x))`."""

    def forward(self, hidden_states):
        gate = F.softplus(hidden_states).tanh()
        return hidden_states * gate
Patrick von Platen's avatar
Patrick von Platen committed
392
393


394
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given FIR filter.

    The filter is normalized to unit sum and then scaled by `gain * factor**2` before being handed, together
    with the zero-insertion factor and the padding, to `upfirdn2d_native`.

    Args:
        hidden_states: Input tensor passed on to `upfirdn2d_native`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        The tensor produced by `upfirdn2d_native`.
    """
    assert isinstance(factor, int) and factor >= 1

    taps = [1] * factor if kernel is None else kernel
    fir = torch.tensor(taps, dtype=torch.float32)
    if fir.ndim == 1:
        # A 1D filter is separable: build the 2D filter as its outer product.
        fir = torch.outer(fir, fir)
    fir = fir / fir.sum()
    fir = fir * (gain * (factor**2))

    excess = fir.shape[0] - factor
    pad_before = (excess + 1) // 2 + factor - 1
    pad_after = excess // 2
    return upfirdn2d_native(
        hidden_states,
        fir.to(device=hidden_states.device),
        up=factor,
        pad=(pad_before, pad_after),
    )
Patrick von Platen's avatar
Patrick von Platen committed
428
429


430
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given FIR filter.

    The filter is normalized to unit sum and then scaled by `gain` before being handed, together with the
    decimation factor and the padding, to `upfirdn2d_native`.

    Args:
        hidden_states: Input tensor passed on to `upfirdn2d_native`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        The tensor produced by `upfirdn2d_native`.
    """
    assert isinstance(factor, int) and factor >= 1

    taps = [1] * factor if kernel is None else kernel
    fir = torch.tensor(taps, dtype=torch.float32)
    if fir.ndim == 1:
        # A 1D filter is separable: build the 2D filter as its outer product.
        fir = torch.outer(fir, fir)
    fir = fir / fir.sum()
    fir = fir * gain

    excess = fir.shape[0] - factor
    pad_before = (excess + 1) // 2
    pad_after = excess // 2
    return upfirdn2d_native(
        hidden_states,
        fir.to(device=hidden_states.device),
        down=factor,
        pad=(pad_before, pad_after),
    )
462
463
464
465
466
467
468
469
470
471


def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample by zero insertion, pad or crop, filter with a 2D FIR kernel, then decimate.

    The same `up`, `down` and `pad` values are used for both spatial axes.

    Args:
        input: 4D tensor unpacked as `[N, C, H, W]`. (The name shadows the builtin but is kept so that
            keyword callers continue to work.)
        kernel: 2D filter tensor of shape `[kernel_h, kernel_w]`. It is converted to the dtype and device
            of the data before filtering.
        up: Integer factor; `up - 1` zeros are inserted after every sample along each spatial axis.
        down: Integer step with which the filtered result is subsampled.
        pad: `(before, after)` padding applied to each spatial axis. Positive values pad with zeros,
            negative values crop.

    Returns:
        Tensor of shape `[N, C, out_h, out_w]` with
        `out_h = (H * up + pad[0] + pad[1] - kernel_h) // down + 1` and `out_w` computed analogously.
    """
    tensor = input
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    # Fold batch and channels together so that every channel is filtered independently.
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)

    # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
    if tensor.device.type == "mps":
        out = out.to("cpu")
    # Zero insertion: pad the two singleton axes, then merge them into the spatial axes.
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding is applied here; negative padding is handled by the crop below.
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    # `F.conv2d` computes a cross-correlation, so the kernel is flipped to obtain a true convolution.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    # `F.conv2d` needs weight and data to share dtype and device; callers build the kernel as float32.
    w = w.to(device=out.device, dtype=out.dtype)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    # `reshape` instead of `view`: the strided slice above is not guaranteed to be view-compatible.
    return out.reshape(-1, channel, out_h, out_w)