from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F


8
class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs.
    :param use_conv: a bool determining if a 3x3 convolution is applied after the nearest-neighbour interpolation.
    :param use_conv_transpose: a bool; if set, the whole layer is a single stride-2 transposed convolution and no
                 interpolation is performed (takes precedence over `use_conv`).
    :param out_channels: channels in the outputs, defaults to `channels`.
    :param name: attribute name the convolution is registered under. `"conv"` stores it as `self.conv`, any other
                 value stores it as `self.Conv2d_0` (kept for compatibility with existing weight dicts).
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def _get_conv(self):
        """Return the convolution under whichever attribute name `__init__` registered it."""
        return self.conv if self.name == "conv" else self.Conv2d_0

    def forward(self, hidden_states, output_size=None):
        """
        Upsample `hidden_states`.

        :param hidden_states: input tensor whose dimension 1 must equal `self.channels`.
        :param output_size: optional target spatial size forwarded to `F.interpolate`; when `None` a fixed
                 scale factor of 2 is used. Ignored on the transposed-convolution path.
        """
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            # Look the layer up by its registered name: it lives in `self.Conv2d_0` when name != "conv".
            return self._get_conv()(hidden_states)

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            hidden_states = self._get_conv()(hidden_states)

        return hidden_states
58
59


60
class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs.
    :param use_conv: a bool determining if a stride-2 3x3 convolution is used; otherwise 2x2 average pooling is
                 applied (which requires `out_channels == channels`).
    :param out_channels: channels in the outputs, defaults to `channels`.
    :param padding: padding of the strided convolution. With `padding=0` the input is instead zero-padded by one
                 pixel on the right and bottom in `forward` before the convolution.
    :param name: `"conv"` registers the layer under both `self.Conv2d_0` and `self.conv`; any other value
                 registers it under `self.conv` only.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # `Conv2d_0` is registered first so that the order of state-dict keys is unchanged.
        if name == "conv":
            self.Conv2d_0 = conv
        self.conv = conv

    def forward(self, hidden_states):
        """Downsample `hidden_states`, whose dimension 1 must equal `self.channels`."""
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            # (left, right, top, bottom): one zero row/column appended on the right and bottom only
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        hidden_states = self.conv(hidden_states)

        return hidden_states
103
104
105
106
107
108
109
110
111
112
113
114


class FirUpsample2D(nn.Module):
    """
    Upsampling layer based on an FIR filter, optionally fused with a 3x3 convolution.

    :param channels: channels in the inputs (only used to build the convolution).
    :param out_channels: channels in the outputs, defaults to `channels`.
    :param use_conv: if set, a `Conv2d_0` layer is created and its weights are applied as a strided transposed
                 convolution before FIR filtering.
    :param fir_kernel: 1D (separable) or 2D FIR filter taps.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            x: Input tensor; dimensions 2 and 3 are treated as height and width (i.e. `[N, C, H, W]`).
            weight: Convolution weight, read as `[outChannels, inChannels, convH, convW]`. Required when
                `self.use_conv` is set, ignored otherwise. The number of groups is derived as
                `x.shape[1] // inChannels`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
                which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            The tensor produced by `upfirdn2d_native`, upsampled by `factor` along height and width.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        # Compensate for the zeros inserted by the upsampling so the overall magnitude equals `gain`.
        kernel = kernel * (gain * (factor**2))
        # `kernel` is already a tensor: move it with `.to` rather than re-wrapping it in `torch.tensor`,
        # which would emit a copy-construct warning.
        fir = kernel.to(device=x.device)

        if self.use_conv:
            inC = weight.shape[1]
            convH = weight.shape[2]
            convW = weight.shape[3]

            p = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
            output_padding = (
                output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = x.shape[1] // inC

            # Transpose weights: flip spatially and swap the in/out channel axes for `conv_transpose2d`.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)

            x = upfirdn2d_native(x, fir, pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
        else:
            p = kernel.shape[0] - factor
            x = upfirdn2d_native(x, fir, up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))

        return x

    def forward(self, hidden_states):
        """Upsample `hidden_states` by 2; with `use_conv` the `Conv2d_0` weights and bias are applied as well."""
        if self.use_conv:
            hidden_states = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            # The fused op only uses the weights, so the bias is added separately.
            hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
191
192
193
194
195
196
197


class FirDownsample2D(nn.Module):
    """
    Downsampling layer based on an FIR filter, optionally fused with a 3x3 convolution.

    :param channels: channels in the inputs (only used to build the convolution).
    :param out_channels: channels in the outputs, defaults to `channels`.
    :param use_conv: if set, a `Conv2d_0` layer is created and its weights are applied as a strided convolution
                 after FIR filtering.
    :param fir_kernel: 1D (separable) or 2D FIR filter taps.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            x: Input tensor, passed to `upfirdn2d_native` and (with `use_conv`) to `F.conv2d`.
            weight: Convolution weight with four dimensions whose last two are the spatial kernel size.
                Required when `self.use_conv` is set, ignored otherwise.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
                which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            The filtered tensor, downsampled by `factor` along height and width.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain
        # `kernel` is already a tensor: move it with `.to` rather than re-wrapping it in `torch.tensor`,
        # which would emit a copy-construct warning.
        fir = kernel.to(device=x.device)

        if self.use_conv:
            _, _, _, convW = weight.shape
            p = (kernel.shape[0] - factor) + (convW - 1)
            s = [factor, factor]
            # Filter at full resolution first; the strided convolution then performs the downsampling.
            x = upfirdn2d_native(x, fir, pad=((p + 1) // 2, p // 2))
            x = F.conv2d(x, weight, stride=s, padding=0)
        else:
            p = kernel.shape[0] - factor
            x = upfirdn2d_native(x, fir, down=factor, pad=((p + 1) // 2, p // 2))

        return x

    def forward(self, hidden_states):
        """Downsample `hidden_states` by 2; with `use_conv` the `Conv2d_0` weights and bias are applied as well."""
        if self.use_conv:
            hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            # The fused op only uses the weights, so the bias is added separately.
            hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
253
254


255
class ResnetBlock2D(nn.Module):
    """
    Residual block: GroupNorm -> non-linearity -> (optional up/down-sampling) -> 3x3 conv -> time-embedding
    injection -> GroupNorm -> non-linearity -> dropout -> 3x3 conv, added to a (optionally 1x1-projected) skip
    connection and divided by `output_scale_factor`.

    :param in_channels: channels in the inputs.
    :param out_channels: channels in the outputs, defaults to `in_channels`.
    :param conv_shortcut: stored as `self.use_conv_shortcut`; not otherwise used by this block.
    :param dropout: dropout probability applied before the second convolution.
    :param temb_channels: size of the time embedding; `None` disables the time-embedding projection.
    :param groups: number of groups of the first GroupNorm.
    :param groups_out: number of groups of the second GroupNorm, defaults to `groups`.
    :param pre_norm: accepted for backward compatibility; `self.pre_norm` is always `True`.
    :param eps: epsilon of both GroupNorm layers.
    :param non_linearity: one of `"swish"`, `"mish"` or `"silu"`.
    :param time_embedding_norm: stored as `self.time_embedding_norm`; not otherwise used by this block.
    :param kernel: resampling flavour used when `up`/`down` is set: `"fir"`, `"sde_vp"`, or anything else for the
                 plain `Upsample2D` / `Downsample2D` layers.
    :param output_scale_factor: the summed output is divided by this value.
    :param use_in_shortcut: whether to project the skip connection with a 1x1 convolution; defaults to
                 `in_channels != out_channels`.
    :param up: upsample both the residual branch and the skip connection by 2.
    :param down: downsample both the residual branch and the skip connection by 2 (only if `up` is not set).
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        # The original code assigned `pre_norm` and immediately overwrote it with True; only the
        # effective value is kept.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            # Plain function reference instead of a lambda so the module stays picklable.
            self.nonlinearity = F.silu
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            # Fail at construction time instead of with an AttributeError on the first forward pass.
            raise ValueError(
                f"Unsupported non_linearity {non_linearity!r}; expected one of 'swish', 'mish', 'silu'."
            )

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                self.upsample = partial(upsample_2d, kernel=(1, 3, 3, 1))
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                self.downsample = partial(downsample_2d, kernel=(1, 3, 3, 1))
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        """
        Apply the block.

        :param x: input feature map with `in_channels` channels.
        :param temb: time embedding projected by `self.time_emb_proj` and added per channel after the first
                 convolution, or `None` to skip the injection.
        :return: `(shortcut(x) + residual(x)) / output_scale_factor`.
        """
        hidden_states = x

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        # Resample the skip path and the residual path identically so they can be summed at the end.
        if self.upsample is not None:
            x = self.upsample(x)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            x = self.downsample(x)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # Trailing singleton dims let the projected embedding broadcast over height and width.
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            x = self.conv_shortcut(x)

        out = (x + hidden_states) / self.output_scale_factor

        return out

Patrick von Platen's avatar
Patrick von Platen committed
367
368
369
370
371
372

class Mish(torch.nn.Module):
    """Mish activation, computed element-wise as `x * tanh(softplus(x))`."""

    def forward(self, x):
        # Gate each element by tanh(softplus(x)), then scale the input by that gate.
        gate = torch.nn.functional.softplus(x).tanh()
        return x * gate


373
def upsample_2d(x, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.

    Upsamples each image of the batch with the given filter. The filter is normalized (divided by the sum of its
    taps) and then scaled by `gain * factor**2`, so that constant input pixels are scaled by the specified `gain`.
    Pixels outside the image are treated as zero.

    Args:
        x: Input tensor of the shape `[N, C, H, W]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        # A 1D filter is separable: its outer product with itself gives the 2D filter.
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    # Compensate for the zeros inserted between samples by the upsampling.
    kernel = kernel * (gain * (factor**2))
    p = kernel.shape[0] - factor
    return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
Patrick von Platen's avatar
Patrick von Platen committed
402
403


404
def downsample_2d(x, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.

    Downsamples each image of the batch with the given filter. The filter is normalized (divided by the sum of its
    taps) and then scaled by `gain`, so that constant input pixels are scaled by the specified `gain`. Pixels
    outside the image are treated as zero.

    Args:
        x: Input tensor of the shape `[N, C, H, W]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        # A 1D filter is separable: its outer product with itself gives the 2D filter.
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    p = kernel.shape[0] - factor
    return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448


def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample (zero insertion), pad, FIR-filter and downsample a batch of 2D images.

    Args:
        input: Tensor of the shape `[N, C, H, W]`.
        kernel: 2D FIR filter of the shape `[kernelH, kernelW]`; it is applied to every channel independently.
        up: Integer upsampling factor used for both height and width; `up - 1` zeros are inserted after each
            sample.
        down: Integer downsampling factor used for both height and width (every `down`-th output is kept).
        pad: `(pad0, pad1)` padding applied before / after the data along both height and width. Positive values
            pad with zeros, negative values crop.

    Returns:
        Tensor of the shape `[N, C, out_h, out_w]` with
        `out_h = (H * up + pad0 + pad1 - kernelH) // down + 1` and likewise for `out_w`.
    """
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = input.shape
    # Fold batch and channel together so each channel is filtered as its own single-channel image.
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)

    # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
    if input.device.type == "mps":
        out = out.to("cpu")
    # Zero insertion: pad the singleton axes that follow height and width, then merge them back in.
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding is applied here; negative padding is handled by the crop below.
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(input.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    # `F.conv2d` computes a cross-correlation, so the kernel is flipped to obtain a true convolution.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)