resnet.py 18.1 KB
Newer Older
patil-suraj's avatar
patil-suraj committed
1
from functools import partial
Patrick von Platen's avatar
Patrick von Platen committed
2

3
4
5
6
7
import torch
import torch.nn as nn
import torch.nn.functional as F


8
class Upsample2D(nn.Module):
    """2x nearest-neighbor upsampling with an optional learned convolution.

    :param channels: channels expected in the input (asserted in ``forward``).
    :param use_conv: if True, apply a 3x3 convolution after interpolating.
    :param use_conv_transpose: if True, upsample with a strided
        ``ConvTranspose2d`` instead of interpolation (takes precedence).
    :param out_channels: channels produced by the convolution; defaults to
        ``channels``.
    :param name: attribute under which the convolution is registered
        ("conv" or legacy "Conv2d_0") — kept for checkpoint compatibility.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
        else:
            conv = None

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, x):
        assert x.shape[1] == self.channels

        # Transposed conv both upsamples and convolves in one step.
        if self.use_conv_transpose:
            return self.conv(x)

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            x = self.conv(x) if self.name == "conv" else self.Conv2d_0(x)

        return x


54
class Downsample2D(nn.Module):
    """2x downsampling with either a strided convolution or average pooling.

    :param channels: channels expected in the input (asserted in ``forward``).
    :param use_conv: if True, downsample with a stride-2 3x3 convolution;
        otherwise use 2x2 average pooling (which cannot change channels).
    :param out_channels: channels produced by the convolution; defaults to
        ``channels``.
    :param padding: padding of the convolution; ``0`` triggers an explicit
        asymmetric pad in ``forward`` instead.
    :param name: attribute under which the op is registered — kept for
        checkpoint compatibility.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name
        stride = 2

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            # Register under both names so either checkpoint layout loads.
            self.Conv2d_0 = conv
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, x):
        assert x.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            # Zero-pad only on the right/bottom in place of conv padding.
            x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

        return self.conv(x)
97
98
99
100
101
102
103
104
105
106
107
108


class FirUpsample2D(nn.Module):
    """2x upsampling with an FIR (finite impulse response) filter, optionally
    fused with a 3x3 convolution (StyleGAN2-style ``upsample_conv_2d``).

    :param channels: input channel count (used to build the optional conv).
    :param out_channels: output channels of the conv; defaults to ``channels``.
    :param use_conv: if True, fuse the stored ``Conv2d_0`` weights with the
        upsampling via a transposed convolution.
    :param fir_kernel: separable FIR filter taps used for the resampling.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the
        operations. The fused op is considerably more efficient than performing
        the same calculation using standard ops. It supports gradients of
        arbitrary order.

        Args:
            x: Input tensor of the shape `[N, C, H, W]`.
            weight: Conv weight of the shape `[outChannels, inChannels, convH, convW]`
                (torch layout). Grouped convolution can be performed by
                `inChannels = x.shape[1] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds
                to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            Tensor of the shape `[N, C, H * factor, W * factor]`, same datatype
            as `x`.
        """
        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel: default to nearest-neighbor, expand a separable
        # 1D filter to 2D, normalize, and compensate for the upsampling gain.
        if kernel is None:
            kernel = [1] * factor
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)
        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            p = (kernel.shape[0] - factor) - (convW - 1)

            # BUGFIX: conv_transpose2d takes a 2D stride; the previous TF-style
            # 4-element stride was invalid and also corrupted output_padding
            # (its [0]/[1] entries were 1 instead of `factor`).
            stride = (factor, factor)
            output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
            output_padding = (
                output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = x.shape[1] // inC

            # Transpose weights: conv_transpose2d expects [inC, outC/groups, kH, kW]
            # with a spatially flipped filter relative to conv2d.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            # BUGFIX: torch tensors do not support negative-step slicing
            # (`weight[..., ::-1, ::-1]` raises); torch.flip does the same reversal.
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)

            # `kernel` is already a tensor: `.to()` avoids the copy-construct
            # warning raised by `torch.tensor(kernel, ...)`.
            x = upfirdn2d_native(x, kernel.to(device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
        else:
            p = kernel.shape[0] - factor
            x = upfirdn2d_native(
                x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
            )

        return x

    def forward(self, x):
        if self.use_conv:
            height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
            # The fused path applies the conv bias separately, after filtering.
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)

        return height
186
187
188
189
190
191
192


class FirDownsample2D(nn.Module):
    """2x downsampling with an FIR filter, optionally fused with a 3x3
    convolution (StyleGAN2-style ``conv_downsample_2d``).

    :param channels: input channel count (used to build the optional conv).
    :param out_channels: output channels of the conv; defaults to ``channels``.
    :param use_conv: if True, fuse the stored ``Conv2d_0`` weights with the
        downsampling (filter first, then strided conv).
    :param fir_kernel: separable FIR filter taps used for the resampling.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the
        operations. The fused op is considerably more efficient than performing
        the same calculation using standard ops. It supports gradients of
        arbitrary order.

        Args:
            x: Input tensor of the shape `[N, C, H, W]`.
            weight: Conv weight of the shape `[outChannels, inChannels, convH, convW]`
                (torch layout). Grouped convolution can be performed by
                `inChannels = x.shape[1] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds
                to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            Tensor of the shape `[N, C, H // factor, W // factor]`, same
            datatype as `x`.
        """
        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel: default to average pooling, expand a separable
        # 1D filter to 2D, normalize, and apply the gain.
        if kernel is None:
            kernel = [1] * factor
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)
        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            # Pad once up front so filter + strided conv equals
            # filter-then-downsample.
            p = (kernel.shape[0] - factor) + (convW - 1)
            s = [factor, factor]
            # BUGFIX: `kernel` is already a tensor here; `.to()` avoids the
            # copy-construct warning from `torch.tensor(kernel, ...)`.
            x = upfirdn2d_native(x, kernel.to(device=x.device), pad=((p + 1) // 2, p // 2))
            x = F.conv2d(x, weight, stride=s, padding=0)
        else:
            p = kernel.shape[0] - factor
            x = upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))

        return x

    def forward(self, x):
        if self.use_conv:
            x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            # The fused path applies the conv bias separately, after filtering.
            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)

        return x
248
249


250
class ResnetBlock2D(nn.Module):
    """A 2D ResNet block: (GroupNorm → nonlinearity → conv) twice, with a
    residual connection, optional time-embedding injection after the first
    conv, and optional up/downsampling of both branches.

    Keyword Args:
        in_channels: channels of the block input.
        out_channels: channels of the output; defaults to ``in_channels``.
        conv_shortcut: stored as ``use_conv_shortcut``; not otherwise used in
            this block.
        dropout: dropout probability applied before the second conv.
        temb_channels: size of the time embedding; ``None`` disables the
            time-embedding projection.
        groups: number of groups for the first GroupNorm.
        groups_out: groups for the second GroupNorm; defaults to ``groups``.
        pre_norm: accepted for API compatibility; the block always runs
            pre-norm (see note in ``__init__``).
        eps: GroupNorm epsilon.
        non_linearity: one of "swish", "mish", "silu".
        time_embedding_norm: stored; only the additive ("default") behavior is
            implemented in ``forward``.
        kernel: resampling variant when ``up``/``down`` is set: "fir",
            "sde_vp", or None for plain Upsample2D/Downsample2D modules.
        output_scale_factor: the residual sum is divided by this factor.
        use_nin_shortcut: force-enable/disable the 1x1 shortcut conv; by
            default it is used iff the channel count changes.
        up: if True, upsample both branches before the first conv.
        down: if True, downsample both branches before the first conv.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_nin_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        # NOTE(review): the original assigned `pre_norm` and then immediately
        # overwrote it with True, so the parameter has no effect. The dead
        # store is removed; the forced value is kept for identical behavior.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        # NOTE(review): an unrecognized non_linearity leaves the attribute
        # unset and only fails at forward time — kept as-is.

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        # 1x1 conv on the residual path when the channel count changes (or when
        # explicitly requested).
        self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut

        self.conv_shortcut = None
        if self.use_nin_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        hidden_states = x

        # NOTE(review): `.type(hidden_states.dtype)` casts back to the input's
        # own dtype, so it is a no-op here; kept to preserve behavior. (The
        # original comment claimed a float32 upcast that the code never did.)
        hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
        hidden_states = self.nonlinearity(hidden_states)

        # Resample both the residual and the hidden branch so they stay
        # spatially aligned for the final sum.
        if self.upsample is not None:
            x = self.upsample(x)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            x = self.downsample(x)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        # Inject the projected time embedding as a per-channel bias.
        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            x = self.conv_shortcut(x)

        out = (x + hidden_states) / self.output_scale_factor

        return out

Patrick von Platen's avatar
Patrick von Platen committed
366
367
368
369
370
371

class Mish(torch.nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate


372
def upsample_2d(x, kernel=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given FIR filter.

    Accepts a batch of images of the shape `[N, C, H, W]` and upsamples each
    image with the given filter. The filter is normalized so that if the input
    pixels are constant, they will be scaled by the specified `gain`. Pixels
    outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the upsampling factor.

    Args:
        x: Input tensor of the shape `[N, C, H, W]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
            The default is `[1] * factor`, which corresponds to
            nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1

    fir = torch.tensor([1] * factor if kernel is None else kernel, dtype=torch.float32)
    if fir.ndim == 1:
        # Expand the separable 1D taps into a full 2D filter.
        fir = torch.outer(fir, fir)
    fir = fir / torch.sum(fir)

    # Compensate for the energy spread over factor**2 output pixels.
    fir = fir * (gain * (factor**2))
    pad_total = fir.shape[0] - factor
    return upfirdn2d_native(
        x,
        fir.to(device=x.device),
        up=factor,
        pad=((pad_total + 1) // 2 + factor - 1, pad_total // 2),
    )
Patrick von Platen's avatar
Patrick von Platen committed
401
402


403
def downsample_2d(x, kernel=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given FIR filter.

    Accepts a batch of images of the shape `[N, C, H, W]` and downsamples each
    image with the given filter. The filter is normalized so that if the input
    pixels are constant, they will be scaled by the specified `gain`. Pixels
    outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the downsampling factor.

    Args:
        x: Input tensor of the shape `[N, C, H, W]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
            The default is `[1] * factor`, which corresponds to average
            pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]`
    """
    assert isinstance(factor, int) and factor >= 1

    fir = torch.tensor([1] * factor if kernel is None else kernel, dtype=torch.float32)
    if fir.ndim == 1:
        # Expand the separable 1D taps into a full 2D filter.
        fir = torch.outer(fir, fir)
    fir = fir / torch.sum(fir)

    fir = fir * gain
    pad_total = fir.shape[0] - factor
    return upfirdn2d_native(
        x,
        fir.to(device=x.device),
        down=factor,
        pad=((pad_total + 1) // 2, pad_total // 2),
    )
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447


def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample by ``up``, apply the FIR ``kernel``, then downsample by
    ``down`` — a native-PyTorch upfirdn for NCHW tensors.

    Args:
        input: tensor of shape ``[N, C, H, W]``.
        kernel: 2D FIR filter of shape ``[kernel_h, kernel_w]``.
        up: integer upsampling factor, applied to both spatial dims.
        down: integer downsampling factor, applied to both spatial dims.
        pad: ``(pad0, pad1)`` leading/trailing padding for both spatial dims;
            negative values crop instead.

    Returns:
        Tensor of shape ``[N, C, out_h, out_w]`` with
        ``out_h = (H * up + pad0 + pad1 - kernel_h) // down + 1`` (and the
        analogous ``out_w``).
    """
    # The same factor/padding is used for both spatial axes.
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = input.shape
    # Fold batch and channels together; the trailing singleton is the "minor" axis.
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Insert a size-1 axis after each spatial dim so zeros can be interleaved
    # between samples (zero-stuffing upsampling).
    out = input.view(-1, in_h, 1, in_w, 1, minor)

    # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
    if input.device.type == "mps":
        out = out.to("cpu")
    # Pad each inserted axis with up-1 zeros, then collapse to the upsampled grid.
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Apply the non-negative part of the requested padding ...
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(input.device)  # Move back to mps if necessary
    # ... and crop wherever the padding was negative.
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    # Filter: conv2d computes cross-correlation, so the kernel is flipped to
    # obtain a true convolution.
    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Downsample by striding over the filtered grid.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    # Unfold batch/channel and return in NCHW layout.
    return out.view(-1, channel, out_h, out_w)