resnet.py 17.9 KB
Newer Older
patil-suraj's avatar
patil-suraj committed
1
from functools import partial
Patrick von Platen's avatar
Patrick von Platen committed
2
3

import numpy as np
4
5
6
7
8
import torch
import torch.nn as nn
import torch.nn.functional as F


9
class Upsample2D(nn.Module):
    """
    A 2x upsampling layer with an optional convolution.

    :param channels: number of channels of the input.
    :param use_conv: if True, apply a 3x3 convolution after nearest-neighbour upsampling.
    :param use_conv_transpose: if True, upsample with a learned 4x4, stride-2 transposed convolution instead of
        interpolation. Takes precedence over ``use_conv``.
    :param out_channels: number of output channels of the convolution (only used when a convolution is enabled);
        defaults to ``channels``.
    :param name: attribute name for the convolution: ``"conv"`` registers it as ``self.conv``, anything else as
        ``self.Conv2d_0`` (kept for compatibility with older weight dicts).
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def _get_conv(self):
        """Return the convolution module, whichever attribute name it was registered under."""
        return self.conv if self.name == "conv" else self.Conv2d_0

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            # Must honour `name`: with name != "conv" the module lives in `self.Conv2d_0`, not `self.conv`.
            return self._get_conv()(x)

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            x = self._get_conv()(x)

        return x


55
class Downsample2D(nn.Module):
    """
    A 2x downsampling layer with an optional convolution.

    :param channels: number of channels of the input.
    :param use_conv: if True, downsample with a 3x3 stride-2 convolution; otherwise use 2x2 average pooling, which
        requires ``out_channels == channels``.
    :param out_channels: number of output channels; defaults to ``channels``.
    :param padding: padding of the strided convolution. With ``padding=0`` the input is first zero-padded by one
        pixel on the right and bottom.
    :param name: legacy attribute name for the downsampling op (kept for compatibility with older weight dicts).
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            # The same module is registered under both names, so state_dict exposes both
            # "Conv2d_0.*" and "conv.*" keys.
            self.Conv2d_0 = conv
        self.conv = conv

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            # Asymmetric right/bottom padding so the unpadded stride-2 conv still halves even spatial sizes.
            x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(x)
98
99
100
101
102
103
104
105
106
107
108
109


class FirUpsample2D(nn.Module):
    """
    2x upsampling with a FIR filter, optionally fused with a learned 3x3 convolution.

    :param channels: number of input channels.
    :param out_channels: number of output channels of the optional convolution; defaults to ``channels``.
    :param use_conv: if True, a 3x3 convolution (registered as ``Conv2d_0``) is fused into the upsampling.
    :param fir_kernel: 1D (separable) or 2D FIR filter taps.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. When ``self.use_conv`` is True,
        the upsampling and the convolution with ``weight`` are computed as one transposed convolution followed by a
        single FIR pass; otherwise only the FIR upsampling is applied.

        Args:
            x: Input tensor of shape `[N, C, H, W]`.
            weight: Convolution weight in PyTorch layout `[outChannels, inChannels, convH, convW]`; required when
                ``self.use_conv`` is True.
            kernel: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1).

        Returns:
            Tensor of shape `[N, C_out, H * factor, W * factor]` (for the fused-conv path this holds for the default
            ``factor=2`` with a square convolution).
        """
        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        kernel = np.asarray(kernel, dtype=np.float32)
        if kernel.ndim == 1:
            kernel = np.outer(kernel, kernel)
        kernel /= np.sum(kernel)

        # Zero insertion divides the mean by factor**2; compensate so constant inputs are scaled by `gain`.
        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            p = (kernel.shape[0] - factor) - (convW - 1)

            # A 2D transposed conv takes a 2-element stride (the TF-style [1, 1, f, f] form is invalid here).
            stride = (factor, factor)
            output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
            output_padding = (
                output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            # NOTE(review): conv_transpose2d below is called without `groups=`, so only num_groups == 1 is expected.
            num_groups = x.shape[1] // inC

            # Transpose weights: swap in/out channel axes and flip spatially. torch tensors do not support
            # negative-step slicing, hence torch.flip.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)

            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
        else:
            p = kernel.shape[0] - factor
            x = upfirdn2d_native(
                x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
            )

        return x

    def forward(self, x):
        if self.use_conv:
            height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)

        return height
187
188
189
190
191
192
193


class FirDownsample2D(nn.Module):
    """
    2x downsampling with a FIR filter, optionally fused with a learned 3x3 convolution.

    :param channels: number of input channels.
    :param out_channels: number of output channels of the optional convolution; falls back to ``channels``.
    :param use_conv: when True, a 3x3 convolution (registered as ``Conv2d_0``) is fused into the downsampling.
    :param fir_kernel: 1D (separable) or 2D FIR filter taps.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels or channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        The FIR filter is applied once with all required padding. When ``self.use_conv`` is True, a strided
        convolution with ``weight`` then performs the downsampling; otherwise the FIR pass itself keeps every
        ``factor``-th pixel.

        Args:
            x: Input tensor of shape `[N, C, H, W]`.
            weight: Convolution weight `[outChannels, inChannels, convH, convW]`; required when ``self.use_conv``.
            kernel: FIR filter of shape `[firH, firW]` or `[firN]` (separable). Defaults to `[1] * factor`, i.e.
                average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1).

        Returns:
            Tensor of shape `[N, C_out, H // factor, W // factor]`.
        """
        assert isinstance(factor, int) and factor >= 1

        taps = np.asarray([1] * factor if kernel is None else kernel, dtype=np.float32)
        if taps.ndim == 1:
            taps = np.outer(taps, taps)
        taps /= np.sum(taps)
        taps = taps * gain

        if not self.use_conv:
            total_pad = taps.shape[0] - factor
            return upfirdn2d_native(
                x, torch.tensor(taps, device=x.device), down=factor, pad=((total_pad + 1) // 2, total_pad // 2)
            )

        _, _, _, conv_w = weight.shape
        total_pad = (taps.shape[0] - factor) + (conv_w - 1)
        filtered = upfirdn2d_native(
            x, torch.tensor(taps, device=x.device), pad=((total_pad + 1) // 2, total_pad // 2)
        )
        return F.conv2d(filtered, weight, stride=[factor, factor], padding=0)

    def forward(self, x):
        if not self.use_conv:
            return self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
        reduced = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
        return reduced + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
249
250


251
class ResnetBlock2D(nn.Module):
    """
    Residual block with GroupNorm, an optional time-embedding projection and optional 2x up/down-sampling.

    Computes ``(shortcut(resample(x)) + h) / output_scale_factor`` where ``h`` is
    norm1 -> nonlinearity -> resample -> conv1 -> (+ projected temb) -> norm2 -> nonlinearity -> dropout -> conv2.

    All arguments are keyword-only:
        in_channels: number of input channels.
        out_channels: number of output channels (defaults to ``in_channels``).
        conv_shortcut: stored as ``use_conv_shortcut``; not otherwise used by this block.
        dropout: dropout probability applied before ``conv2``.
        temb_channels: size of the time embedding; ``None`` disables the time-embedding projection.
        groups / groups_out: GroupNorm groups for ``norm1`` / ``norm2`` (``groups_out`` defaults to ``groups``).
        pre_norm: currently ignored; ``self.pre_norm`` is always True.
        eps: GroupNorm epsilon.
        non_linearity: one of ``"swish"``, ``"mish"`` or ``"silu"``; anything else raises ``ValueError``.
        time_embedding_norm: stored but not used by this block.
        kernel: resampling kernel used when ``up``/``down`` is set: ``"fir"``, ``"sde_vp"``, or anything else for
            nearest-neighbour upsampling / average-pool downsampling.
        output_scale_factor: the residual sum is divided by this value.
        use_nin_shortcut: force or disable the 1x1 shortcut conv; by default it is used iff
            ``in_channels != out_channels``.
        up / down: upsample or downsample by 2 (``up`` takes precedence).
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_nin_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        # NOTE: the `pre_norm` argument is ignored; normalization is always applied first.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            # Previously an unknown value left `self.nonlinearity` unset and failed later in forward().
            raise ValueError(f"Unsupported non_linearity {non_linearity!r}; expected 'swish', 'mish' or 'silu'.")

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut

        self.conv_shortcut = None
        if self.use_nin_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        hidden_states = x

        # make sure hidden states is in float32
        # when running in half-precision
        hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
        hidden_states = self.nonlinearity(hidden_states)

        # Resample the skip path and the residual path identically so their shapes match.
        if self.upsample is not None:
            x = self.upsample(x)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            x = self.downsample(x)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # Project the time embedding to out_channels and broadcast over H and W.
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        # make sure hidden states is in float32
        # when running in half-precision
        hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            x = self.conv_shortcut(x)

        out = (x + hidden_states) / self.output_scale_factor

        return out

Patrick von Platen's avatar
Patrick von Platen committed
367
368
369
370
371
372

class Mish(nn.Module):
    """Mish activation, computed element-wise as ``x * tanh(softplus(x))``."""

    def forward(self, x):
        smooth = F.softplus(x)
        return x * smooth.tanh()


373
def upsample_2d(x, kernel=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given FIR filter.

    The filter is normalized to sum to one and then multiplied by ``gain * factor**2``, which compensates for the
    zeros inserted during upsampling so that constant input pixels are scaled by ``gain``. Pixels outside the image
    are treated as zero.

    Args:
        x: Input tensor of shape `[N, C, H, W]`.
        kernel: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1).

    Returns:
        Tensor of shape `[N, C, H * factor, W * factor]`.
    """
    assert isinstance(factor, int) and factor >= 1

    taps = np.asarray([1] * factor if kernel is None else kernel, dtype=np.float32)
    if taps.ndim == 1:
        taps = np.outer(taps, taps)
    taps /= np.sum(taps)
    taps = taps * (gain * (factor**2))

    total_pad = taps.shape[0] - factor
    pad = ((total_pad + 1) // 2 + factor - 1, total_pad // 2)
    fir = torch.tensor(taps, device=x.device)
    return upfirdn2d_native(x, fir, up=factor, pad=pad)
Patrick von Platen's avatar
Patrick von Platen committed
404
405


406
def downsample_2d(x, kernel=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given FIR filter.

    The filter is normalized to sum to ``gain``, so constant input pixels are scaled by ``gain``. Pixels outside
    the image are treated as zero.

    Args:
        x: Input tensor of shape `[N, C, H, W]`.
        kernel: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1).

    Returns:
        Tensor of shape `[N, C, H // factor, W // factor]`.
    """
    assert isinstance(factor, int) and factor >= 1

    taps = np.asarray([1] * factor if kernel is None else kernel, dtype=np.float32)
    if taps.ndim == 1:
        taps = np.outer(taps, taps)
    taps /= np.sum(taps)
    taps = taps * gain

    total_pad = taps.shape[0] - factor
    fir = torch.tensor(taps, device=x.device)
    return upfirdn2d_native(x, fir, down=factor, pad=((total_pad + 1) // 2, total_pad // 2))
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478


def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, FIR-filter and downsample a batch of 2D images (pure PyTorch implementation).

    Every channel of ``input`` is processed independently:
      1. upsample by inserting ``up - 1`` zeros after every pixel along H and W,
      2. zero-pad by ``pad[0]`` before and ``pad[1]`` after along both H and W (negative values crop instead),
      3. convolve with ``kernel`` (it is flipped before ``F.conv2d``, which itself computes a cross-correlation),
      4. keep every ``down``-th output pixel along H and W.

    Args:
        input: Tensor of shape `[N, C, H, W]`.
        kernel: 2D tensor of shape `[kernel_h, kernel_w]`. It is cast to ``input``'s dtype and device before use,
            so a float32 kernel also works with half- or double-precision inputs.
        up: Integer upsampling factor, applied to both H and W.
        down: Integer downsampling factor, applied to both H and W.
        pad: `(pad_before, pad_after)`, applied to both spatial axes.

    Returns:
        Tensor of shape `[N, C, (H * up + pad[0] + pad[1] - kernel_h) // down + 1,
        (W * up + pad[0] + pad[1] - kernel_w) // down + 1]`.
    """
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = input.shape
    # Fold channels into the batch axis: [N * C, H, W, 1].
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Zero insertion: pad a singleton axis after each row/column with (up - 1) zeros.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding is applied with F.pad; negative padding is applied by cropping.
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    # Match the input's dtype/device; F.conv2d rejects mixed dtypes (e.g. half input with a float32 kernel).
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w).to(device=out.device, dtype=out.dtype)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)