from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F


class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: a bool determining if a transposed convolution is used to upsample.
        out_channels: number of output channels; defaults to `channels`.
        name: attribute name used to register the convolution (kept for weight-dict compatibility).
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
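
# Illustrative usage sketch (not part of the module; shapes and values chosen for the example):
# Upsample2D doubles the spatial resolution of an NCHW feature map via nearest-neighbor
# interpolation (or a transposed convolution), optionally followed by a 3x3 convolution.
#
#     upsample = Upsample2D(channels=64, use_conv=True)
#     x = torch.randn(2, 64, 32, 32)
#     assert upsample(x).shape == (2, 64, 64, 64)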


class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: number of output channels; defaults to `channels`.
        padding: padding used by the strided convolution when `use_conv` is True.
        name: attribute name used to register the convolution (kept for weight-dict compatibility).
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states
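
# Illustrative usage sketch (not part of the module): Downsample2D halves the spatial
# resolution, either with a strided 3x3 convolution (use_conv=True) or with average pooling.
#
#     downsample = Downsample2D(channels=64, use_conv=True)
#     x = torch.randn(2, 64, 32, 32)
#     assert downsample(x).shape == (2, 64, 16, 16)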


class FirUpsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
        order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                torch.tensor(kernel, device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height
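
# Illustrative usage sketch (not part of the module): FirUpsample2D upsamples by a factor of 2
# while low-pass filtering with the FIR kernel (default (1, 3, 3, 1)), optionally fused with a
# 3x3 convolution.
#
#     fir_upsample = FirUpsample2D(channels=64, out_channels=64, use_conv=True)
#     x = torch.randn(2, 64, 32, 32)
#     assert fir_upsample(x).shape == (2, 64, 64, 64)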


class FirDownsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
        order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
            factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
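
# Illustrative usage sketch (not part of the module): FirDownsample2D low-pass filters with the
# FIR kernel and downsamples by a factor of 2, optionally fused with a 3x3 convolution.
#
#     fir_downsample = FirDownsample2D(channels=64, out_channels=64, use_conv=True)
#     x = torch.randn(2, 64, 32, 32)
#     assert fir_downsample(x).shape == (2, 64, 16, 16)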


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
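
# Illustrative usage sketch (not part of the module): a ResnetBlock2D maps an NCHW feature map,
# plus an optional time embedding, to a residual output with `out_channels` channels.
#
#     block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
#     x = torch.randn(2, 64, 32, 32)
#     temb = torch.randn(2, 512)
#     assert block(x, temb).shape == (2, 128, 32, 32)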


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output
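
# Illustrative sketch (not part of the module): with the default kernel, upsample_2d reduces to
# nearest-neighbor upsampling by `factor`.
#
#     x = torch.randn(2, 64, 32, 32)
#     assert upsample_2d(x, factor=2).shape == (2, 64, 64, 64)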


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
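
# Illustrative sketch (not part of the module): with the default kernel and factor=2,
# downsample_2d reduces to 2x2 average pooling.
#
#     x = torch.randn(2, 64, 32, 32)
#     assert downsample_2d(x, factor=2).shape == (2, 64, 16, 16)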


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
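
# Shape note (illustrative, derived from the reshapes above): for an input of shape [N, C, H, W],
# upfirdn2d_native upsamples by `up`, pads by `pad`, filters with `kernel`, and downsamples by
# `down`, returning a tensor of shape
#     [N, C, (H * up + pad[0] + pad[1] - kernel_h) // down + 1,
#            (W * up + pad[0] + pad[1] - kernel_w) // down + 1]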