from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F


class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: a bool determining if a transposed convolution is used instead of interpolation.
        out_channels: number of output channels; defaults to `channels`.
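
    Example (a minimal sketch; the 64-channel shapes are illustrative):
        >>> import torch
        >>> upsampler = Upsample2D(channels=64, use_conv=True)
        >>> upsampler(torch.randn(1, 64, 16, 16)).shape
        torch.Size([1, 64, 32, 32])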
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states


class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: number of output channels; defaults to `channels`.
        padding: padding for the strided convolution; with `padding=0`, asymmetric padding is applied instead.
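
    Example (a minimal sketch; the 64-channel shapes are illustrative):
        >>> import torch
        >>> downsampler = Downsample2D(channels=64, use_conv=True)
        >>> downsampler(torch.randn(1, 64, 16, 16)).shape
        torch.Size([1, 64, 8, 8])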
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
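            # with padding=0, pad only the right and bottom edges so the 3x3 stride-2 conv still halves even spatial sizes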
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states


class FirUpsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[outChannels, inChannels, kernelH, kernelW]` (the standard PyTorch
                convolution layout). Grouped convolution can be performed with `inChannels = hidden_states.shape[1] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
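
        Example (a minimal sketch via the module's `forward`; shapes are illustrative):
            >>> import torch
            >>> fir_upsampler = FirUpsample2D(channels=64, out_channels=64, use_conv=True)
            >>> fir_upsampler(torch.randn(1, 64, 16, 16)).shape
            torch.Size([1, 64, 32, 32])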
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # build the 2D FIR kernel: outer product if separable, then normalize
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            hidden_states = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states


class FirDownsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.
        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[outChannels, inChannels, kernelH, kernelW]` (the standard PyTorch
                convolution layout). Grouped convolution can be performed with `inChannels = hidden_states.shape[1] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
                `[1] * factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
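
        Example (a minimal sketch via the module's `forward`; shapes are illustrative):
            >>> import torch
            >>> fir_downsampler = FirDownsample2D(channels=64, out_channels=64, use_conv=True)
            >>> fir_downsampler(torch.randn(1, 64, 16, 16)).shape
            torch.Size([1, 64, 8, 8])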
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # build the 2D FIR kernel: outer product if separable, then normalize
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states


class ResnetBlock2D(nn.Module):
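    """
    A residual block with GroupNorm, a nonlinearity, and two 3x3 convolutions, with the projected
    time embedding added between them. Optional FIR or nearest-neighbor resampling is applied
    before the first convolution, and a 1x1 shortcut convolution is used when the input and
    output channel counts differ.

    Example (a minimal sketch; all shapes are illustrative):
        >>> import torch
        >>> block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
        >>> x = torch.randn(1, 64, 16, 16)
        >>> temb = torch.randn(1, 512)
        >>> block(x, temb).shape
        torch.Size([1, 128, 16, 16])
    """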
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        self.pre_norm = True  # NOTE: currently forced on; the `pre_norm` argument is ignored
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class Mish(torch.nn.Module):
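    """Mish activation function: `x * tanh(softplus(x))`."""
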
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
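
    Example (a minimal sketch; shapes are illustrative):
        >>> import torch
        >>> upsample_2d(torch.randn(1, 3, 8, 8), kernel=(1, 3, 3, 1)).shape
        torch.Size([1, 3, 16, 16])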
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
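
    Example (a minimal sketch; shapes are illustrative):
        >>> import torch
        >>> downsample_2d(torch.randn(1, 3, 16, 16), kernel=(1, 3, 3, 1)).shape
        torch.Size([1, 3, 8, 8])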
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
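    """Native PyTorch implementation of `upfirdn2d`: zero-insert upsampling by `up`,
    padding by `pad`, 2D FIR filtering with `kernel`, then subsampling by `down`.

    Example (a minimal sketch; a normalized 2x2 box kernel is assumed for illustration):
        >>> import torch
        >>> x = torch.randn(1, 3, 8, 8)
        >>> box = torch.ones(2, 2) / 4.0
        >>> upfirdn2d_native(x, box, up=2, pad=(1, 0)).shape
        torch.Size([1, 3, 16, 16])
    """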
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)