resnet.py 19.6 KB
Newer Older
patil-suraj's avatar
patil-suraj committed
1
from functools import partial
Patrick von Platen's avatar
Patrick von Platen committed
2

3
4
5
6
7
import torch
import torch.nn as nn
import torch.nn.functional as F


8
class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs.
        use_conv: a bool determining if a 3x3 convolution is applied after the nearest-neighbour interpolation.
        use_conv_transpose: a bool; when set, the whole layer is a single stride-2 transposed convolution and
            `use_conv` is ignored.
        out_channels: channels in the outputs. Defaults to `channels`.
        name: attribute under which the convolution is registered: `conv` when `name == "conv"`, otherwise
            `Conv2d_0`. This only affects the parameter names of the layer.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        # The layer lives under `conv` or `Conv2d_0` depending on `name`; resolve it once so
        # that both the transposed-conv and the plain-conv path use the attribute that exists.
        conv = self.conv if self.name == "conv" else self.Conv2d_0

        if self.use_conv_transpose:
            return conv(hidden_states)

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        if self.use_conv:
            hidden_states = conv(hidden_states)

        return hidden_states
59
60


61
class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs.
        use_conv: a bool determining if a stride-2 3x3 convolution is used; otherwise 2x2 average pooling is
            applied (which requires `out_channels == channels`).
        out_channels: channels in the outputs. Defaults to `channels`.
        padding: padding of the convolution. With `padding == 0` the input is zero-padded by one pixel on the
            right and bottom before the convolution.
        name: with `"conv"` the layer is registered under both `Conv2d_0` and `conv`; with any other value it
            is registered under `conv` only.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            # Registered first so the parameter ordering matches the previous behaviour.
            self.Conv2d_0 = conv
        self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            # Pad one pixel on the right and bottom only (F.pad order: left, right, top, bottom).
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        return self.conv(hidden_states)
105
106
107
108
109
110
111
112
113
114
115
116


class FirUpsample2D(nn.Module):
    """
    An upsampling layer based on an FIR filter, optionally fused with a 3x3 convolution.

    Parameters:
        channels: channels in the inputs (needed when `use_conv` is set).
        out_channels: channels in the outputs. Defaults to `channels`.
        use_conv: a bool determining if a learned 3x3 convolution is fused into the upsampling.
        fir_kernel: FIR filter taps handed to `_upsample_2d` as `kernel`.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]`.
            weight: Convolution weight, indexed as `[outChannels, inChannels, filterH, filterW]`. Only used
                when `self.use_conv` is set. The number of groups is derived as `C // inChannels`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]`.
        """

        assert isinstance(factor, int) and factor >= 1

        if kernel is None:
            kernel = [1] * factor

        # Build a normalized 2D filter; a 1D kernel is treated as separable.
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        # Scale by factor**2 together with the requested gain.
        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)
            stride = (factor, factor)
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights: flip spatially and swap the in/out channel axes per group so
            # that `conv_transpose2d` can be used for the upsampling convolution.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            # No `output_padding` is needed: the requested output size is exactly
            # `(size - 1) * factor + conv_size`, which is what `conv_transpose2d` produces.
            inverse_conv = F.conv_transpose2d(hidden_states, weight, stride=stride, padding=0)

            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            hidden_states = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            # The fused op only applies the weight; add the convolution bias separately.
            hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
205
206
207
208
209
210
211


class FirDownsample2D(nn.Module):
    """
    A downsampling layer based on an FIR filter, optionally fused with a 3x3 convolution.

    Parameters:
        channels: channels in the inputs (needed when `use_conv` is set).
        out_channels: channels in the outputs. Defaults to `channels`.
        use_conv: a bool determining if a learned 3x3 convolution is fused into the downsampling.
        fir_kernel: FIR filter taps handed to `_downsample_2d` as `kernel`.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]`.
            weight:
                Convolution weight whose last two axes are `filterH, filterW` (it is passed to `F.conv2d`
                unchanged). Only used when `self.use_conv` is set.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
            factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]`.
        """

        assert isinstance(factor, int) and factor >= 1

        if kernel is None:
            kernel = [1] * factor

        # Build a normalized 2D filter; a 1D kernel is treated as separable.
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]

            # Filter at full resolution first; the strided convolution then does the downsampling.
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            # The fused op only applies the weight; add the convolution bias separately.
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
279
280


281
class ResnetBlock2D(nn.Module):
    """
    A residual block: GroupNorm -> non-linearity -> (optional up/down-sampling) -> 3x3 conv -> added time
    embedding -> GroupNorm -> non-linearity -> dropout -> 3x3 conv, summed with the (optionally projected)
    input and divided by `output_scale_factor`.

    Parameters:
        in_channels: channels in the inputs.
        out_channels: channels in the outputs. Defaults to `in_channels`.
        conv_shortcut: stored as `use_conv_shortcut`; not read anywhere else in this class.
        dropout: dropout probability applied before the second convolution.
        temb_channels: size of the time embedding; `None` disables the time-embedding projection.
        groups: number of groups of the first GroupNorm.
        groups_out: number of groups of the second GroupNorm. Defaults to `groups`.
        pre_norm: accepted for compatibility; `self.pre_norm` is always set to `True`.
        eps: epsilon of both GroupNorm layers.
        non_linearity: one of `"swish"`, `"mish"` or `"silu"`.
        time_embedding_norm: stored on the module; not read anywhere else in this class.
        kernel: resampling flavour used when `up`/`down` is set: `"fir"`, `"sde_vp"` or anything else for the
            default `Upsample2D`/`Downsample2D` layers.
        output_scale_factor: divisor applied to the sum of shortcut and residual branch.
        use_in_shortcut: whether a 1x1 convolution is applied to the shortcut. Defaults to
            `in_channels != out_channels`.
        up: upsample both branches by 2 (takes precedence over `down`).
        down: downsample both branches by 2.

    Raises:
        ValueError: if `non_linearity` is not one of the supported names.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        # `pre_norm` is deliberately ignored: the block always normalizes first.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = F.silu
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            # Fail at construction time instead of with an AttributeError in `forward`.
            raise ValueError(f"unsupported non_linearity: {non_linearity!r}")

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = partial(upsample_2d, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = partial(downsample_2d, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = self.norm1(input_tensor)
        hidden_states = self.nonlinearity(hidden_states)

        # Resample the shortcut and the residual branch identically so they can be summed.
        if self.upsample is not None:
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # Project the embedding to `out_channels` and broadcast it over the spatial axes.
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
392

Patrick von Platen's avatar
Patrick von Platen committed
393
394

class Mish(torch.nn.Module):
    """Mish activation, computed as `x * tanh(softplus(x))`."""

    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
Patrick von Platen's avatar
Patrick von Platen committed
397
398


399
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` and upsamples each image with the given filter. The
    filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a
    multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    # Build a normalized 2D filter; a 1D kernel is treated as separable.
    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    # Scale by factor**2 together with the requested gain.
    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output
Patrick von Platen's avatar
Patrick von Platen committed
434
435


436
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` and downsamples each image with the given filter. The
    filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a
    multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    # Build a normalized 2D filter; a 1D kernel is treated as separable.
    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        down=factor,
        pad=((pad_value + 1) // 2, pad_value // 2),
    )
    return output
469
470


471
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample by zero insertion, pad or crop, filter with `kernel`, then downsample by striding.

    Every channel of every batch element is filtered independently with the same 2D `kernel`.

    Args:
        tensor: Input tensor of the shape `[N, C, H, W]`.
        kernel: 2D FIR filter of the shape `[kernelH, kernelW]`. It is used as-is, so it is expected to be
            compatible with `tensor` (device, dtype) for `F.conv2d`.
        up: Integer upsampling factor; `up - 1` zeros are inserted after every pixel along both axes.
        down: Integer downsampling factor; every `down`-th pixel of the filtered result is kept.
        pad: `(before, after)` padding applied to both spatial axes after upsampling. Negative values crop.

    Returns:
        Tensor of the shape `[N, C, out_h, out_w]` with
        `out = (size * up + pad[0] + pad[1] - kernel_size) // down + 1` along each axis.
    """
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    kernel_h, kernel_w = kernel.shape
    device = tensor.device

    # Fold batch and channels together and add a unit axis after each spatial axis; padding
    # those unit axes up to `up` and merging them back inserts the zeros between pixels.
    out = tensor.reshape(-1, in_h, 1, in_w, 1)

    # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
    if device.type == "mps":
        out = out.to("cpu")
    out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.reshape(-1, in_h * up_y, in_w * up_x)

    # Positive padding is applied here, negative padding is realized by the crop below.
    out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
    ]

    # `F.conv2d` computes a cross-correlation, so flip the kernel to get a true convolution.
    out = out.reshape(-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1)
    w = torch.flip(kernel, [0, 1]).reshape(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out[:, :, ::down_y, ::down_x]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.reshape(-1, channel, out_h, out_w)