from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F


class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two
                 dimensions.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(hidden_states)

        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
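
# Usage sketch (illustrative only, not exercised anywhere in this file): with the
# default settings, Upsample2D doubles the spatial resolution via nearest-neighbor
# interpolation and, when `use_conv=True`, refines the result with a 3x3 convolution.
#
#     up = Upsample2D(channels=8, use_conv=True)
#     x = torch.randn(1, 8, 16, 16)
#     assert up(x).shape == (1, 8, 32, 32)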


class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two
                 dimensions.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states
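
# Usage sketch (illustrative only): with `use_conv=True` the layer halves the
# spatial resolution with a stride-2 3x3 convolution; without it, it falls back
# to 2x2 average pooling (which requires matching in/out channels).
#
#     down = Downsample2D(channels=8, use_conv=True)
#     x = torch.randn(1, 8, 32, 32)
#     assert down(x).shape == (1, 8, 16, 16)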


class FirUpsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Args:
        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
        order.
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
            C]`.
118
        weight: Weight tensor of the shape `[filterH, filterW, inChannels,
119
            outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
120
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
121
122
123
124
125
126
127
128
129
130
131
            (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
        `x`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            p = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
            output_padding = (
                output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = x.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)

            x = upfirdn2d_native(x, kernel.to(device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
        else:
            p = kernel.shape[0] - factor
            x = upfirdn2d_native(
                x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
            )

        return x

    def forward(self, hidden_states):
        if self.use_conv:
            hidden_states = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
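
# Usage sketch (illustrative only): the default FIR kernel (1, 3, 3, 1) acts as a
# separable low-pass filter applied after zero-stuffing, giving a smoothed 2x
# upsampling; the fused conv path keeps the same output shape.
#
#     fir_up = FirUpsample2D(channels=8, out_channels=8, use_conv=True)
#     x = torch.randn(1, 8, 16, 16)
#     assert fir_up(x).shape == (1, 8, 32, 32)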


class FirDownsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Args:
        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
        order.
            x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
            filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
            numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
            factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
            Scaling factor for signal magnitude (default: 1.0).

        Returns:
            Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
            datatype as `x`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            p = (kernel.shape[0] - factor) + (convW - 1)
            s = [factor, factor]
            x = upfirdn2d_native(x, kernel.to(device=x.device), pad=((p + 1) // 2, p // 2))
            x = F.conv2d(x, weight, stride=s, padding=0)
        else:
            p = kernel.shape[0] - factor
            x = upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))

        return x

    def forward(self, hidden_states):
        if self.use_conv:
            hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
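
# Usage sketch (illustrative only): the FIR kernel low-pass filters the input
# before the stride-2 reduction, the anti-aliased counterpart of FirUpsample2D.
#
#     fir_down = FirDownsample2D(channels=8, out_channels=8, use_conv=True)
#     x = torch.randn(1, 8, 32, 32)
#     assert fir_down(x).shape == (1, 8, 16, 16)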


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        self.pre_norm = True  # note: `pre_norm` is currently always forced to True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        hidden_states = x

        # make sure hidden states is in float32
        # when running in half-precision
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            x = self.upsample(x)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            x = self.downsample(x)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        # make sure hidden states is in float32
        # when running in half-precision
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            x = self.conv_shortcut(x)

        out = (x + hidden_states) / self.output_scale_factor

        return out
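
# Usage sketch (illustrative only): a pre-norm residual block. When
# `in_channels != out_channels`, a 1x1 shortcut convolution projects the residual
# branch; `temb` is a time embedding of shape `[N, temb_channels]`.
#
#     block = ResnetBlock2D(in_channels=32, out_channels=64, temb_channels=128)
#     x = torch.randn(1, 32, 16, 16)
#     temb = torch.randn(1, 128)
#     assert block(x, temb).shape == (1, 64, 16, 16)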


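# Smooth, non-monotonic activation: Mish(x) = x * tanh(softplus(x)).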
class Mish(torch.nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


def upsample_2d(x, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.
Patrick von Platen's avatar
Patrick von Platen committed
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388

    Args:
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
    multiple of the upsampling factor.
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
          C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    p = kernel.shape[0] - factor
    return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
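
# Illustrative shape check, using the same (1, 3, 3, 1) kernel the Fir layers
# above default to:
#
#     x = torch.randn(1, 3, 8, 8)
#     assert upsample_2d(x, kernel=(1, 3, 3, 1)).shape == (1, 3, 16, 16)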


def downsample_2d(x, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.
Patrick von Platen's avatar
Patrick von Platen committed
404
405
406
407
408
409
410
411

    Args:
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
          C]`.
412
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
Patrick von Platen's avatar
Patrick von Platen committed
413
414
415
416
417
418
419
420
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    p = kernel.shape[0] - factor
    return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
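
# Illustrative shape check, mirroring the `upsample_2d` example:
#
#     x = torch.randn(1, 3, 16, 16)
#     assert downsample_2d(x, kernel=(1, 3, 3, 1)).shape == (1, 3, 8, 8)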


def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)

    # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
    if input.device.type == "mps":
        out = out.to("cpu")
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(input.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
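
# Illustrative shape check: `upfirdn2d_native` zero-stuffs by `up`, pads,
# convolves with the flipped `kernel`, then keeps every `down`-th sample, so
# out_h = (in_h * up + pad[0] + pad[1] - kernel_h) // down + 1.
#
#     x = torch.ones(1, 1, 4, 4)
#     kernel = torch.ones(2, 2) / 4.0
#     assert upfirdn2d_native(x, kernel, up=2, pad=(1, 0)).shape == (1, 1, 8, 8)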