resnet.py 34 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

patil-suraj's avatar
patil-suraj committed
16
from functools import partial
17
from typing import Optional
Patrick von Platen's avatar
Patrick von Platen committed
18

19
20
21
22
import torch
import torch.nn as nn
import torch.nn.functional as F

23
24
from .attention import AdaGroupNorm

25

26
class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution after nearest-neighbour interpolation.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose (takes precedence over `use_conv`).
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `"conv"`):
            stored as `self.name`; not otherwise used by this layer.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # A transposed conv (kernel 4, stride 2, padding 1) doubles the length by itself; otherwise an
        # optional length-preserving 3-tap conv runs after nearest-neighbour interpolation.
        if use_conv_transpose:
            layer = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            layer = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            layer = None
        self.conv = layer

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)

        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled


class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    The length is reduced with stride 2, either by a 3-tap convolution or by average pooling.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`. Must equal `channels` when `use_conv` is `False`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `"conv"`):
            stored as `self.name`; not otherwise used by this layer.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        factor = 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=factor, stride=factor)
        else:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=factor, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)


101
class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution after nearest-neighbour interpolation.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose (takes precedence over `use_conv`).
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `"conv"`):
            attribute under which the convolution is registered: `"conv"` stores it as `self.conv`, any other
            value stores it as `self.Conv2d_0` (kept for compatibility with older weight names).
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def _get_conv(self):
        """Return the convolution layer, whichever attribute name `__init__` registered it under."""
        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        return self.conv if self.name == "conv" else self.Conv2d_0

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            # Resolve by name: when `name != "conv"` the layer lives in `self.Conv2d_0`, not `self.conv`.
            return self._get_conv()(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        if self.use_conv:
            hidden_states = self._get_conv()(hidden_states)

        return hidden_states
171
172


173
class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Spatial size is reduced with stride 2, either by a 3x3 convolution or by 2x2 average pooling.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`. Must equal `channels` when `use_conv` is `False`.
        padding (`int`, default `1`):
            padding for the convolution. With `0`, the input is zero-padded by one pixel on the right and
            bottom before the convolution.
        name (`str`, default `"conv"`):
            controls attribute registration (see the note in `__init__`).
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        factor = 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            layer = nn.AvgPool2d(kernel_size=factor, stride=factor)
        else:
            layer = nn.Conv2d(self.channels, self.out_channels, 3, stride=factor, padding=padding)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # With the default name the same layer is registered twice (first as `Conv2d_0`, then as `conv`), so
        # its weights appear in the state dict under both prefixes. Every other name registers only `conv`.
        if name == "conv":
            self.Conv2d_0 = layer
        self.conv = layer

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            # Asymmetric zero padding (right and bottom only) before the unpadded strided conv.
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        return self.conv(hidden_states)
221
222
223


class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor. The convolution path reads channels from dim 1 and spatial size from
                dims 2 and 3, i.e. it expects `[N, C, H, W]`.
            weight: Convolution weight in `nn.Conv2d.weight` layout, `[outChannels, inChannels, convH, convW]`.
                The input is split into `hidden_states.shape[1] // inChannels` groups. Only used when
                `self.use_conv` is `True`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel: a 1D kernel is treated as separable and expanded to 2D, then normalized to sum 1.
        if kernel is None:
            kernel = [1] * factor

        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        # Same `gain * factor**2` scaling as the module-level `upsample_2d`.
        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            # With stride == factor both entries are 0 by construction; the assert is a sanity check.
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights: flip spatially and swap in/out channels per group so that a transposed
            # convolution with them matches the forward convolution.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            # `kernel` is already a tensor; move it instead of re-wrapping with `torch.tensor`, which warns.
            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            upsampled = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            upsampled = upsampled + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            upsampled = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return upsampled
334
335
336


class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor, presumably `[N, C, H, W]` (it is passed to `F.conv2d` in the conv path).
            weight: Convolution weight in `nn.Conv2d.weight` layout, `[outChannels, inChannels, convH, convW]`.
                Only used when `self.use_conv` is `True`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
                factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel: a 1D kernel is treated as separable and expanded to 2D, then normalized to sum 1
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            # Filter only (no `down=`) here; the stride-`factor` convolution below does the decimation.
            # `kernel` is already a tensor; move it instead of re-wrapping with `torch.tensor`, which warns.
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsampled = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsampled + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
421
422


423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    """Stride-2 per-channel downsampling with the fixed separable filter `[1, 3, 3, 1] / 8` (k-upscaler).

    Args:
        pad_mode (`str`, default `"reflect"`): padding mode passed to `F.pad` before filtering.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = taps.shape[1] // 2 - 1
        # 2D filter = outer product of the 1D taps; non-persistent, so it is not saved in the state dict.
        self.register_buffer("kernel", taps.T @ taps, persistent=False)

    def forward(self, x):
        padded = F.pad(x, (self.pad,) * 4, self.pad_mode)
        channels = padded.shape[1]
        # Only the diagonal (i, i) entries are non-zero, so each channel is filtered independently.
        weight = padded.new_zeros([channels, channels, self.kernel.shape[0], self.kernel.shape[1]])
        diag = torch.arange(channels, device=padded.device)
        weight[diag, diag] = self.kernel.to(weight)
        return F.conv2d(padded, weight, stride=2)


class KUpsample2D(nn.Module):
    """Stride-2 per-channel upsampling with the fixed separable filter `[1, 3, 3, 1] / 4` (k-upscaler).

    Args:
        pad_mode (`str`, default `"reflect"`): padding mode passed to `F.pad` before the transposed conv.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = taps.shape[1] // 2 - 1
        # 2D filter = outer product of the 1D taps; non-persistent, so it is not saved in the state dict.
        self.register_buffer("kernel", taps.T @ taps, persistent=False)

    def forward(self, x):
        edge = (self.pad + 1) // 2
        padded = F.pad(x, (edge,) * 4, self.pad_mode)
        channels = padded.shape[1]
        # Only the diagonal (i, i) entries are non-zero, so each channel is filtered independently.
        weight = padded.new_zeros([channels, channels, *self.kernel.shape])
        diag = torch.arange(channels, device=padded.device)
        weight[diag, diag] = self.kernel.to(weight)
        return F.conv_transpose2d(padded, weight, stride=2, padding=self.pad * 2 + 1)


456
class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, default to `False`): Stored as `self.use_conv_shortcut`. Whether a 1x1
            shortcut convolution is created is controlled by `use_in_shortcut`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
            If `None`, no timestep projection layer is created.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        pre_norm (`bool`, *optional*, default to `True`): Accepted for backward compatibility only; normalization is
            always applied first and `self.pre_norm` is always `True`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. One of `"swish"`,
            `"silu"` (both SiLU), `"mish"` or `"gelu"`.
        skip_time_act (`bool`, *optional*, default to `False`): If `True`, the timestep embedding is projected without
            first applying the activation function.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`str`, *optional*, default to None): resampling method used when `up` or `down` is set. `"fir"` uses a
            `(1, 3, 3, 1)` FIR filter (see [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]),
            `"sde_vp"` uses nearest-neighbour interpolation / 2x2 average pooling, and any other value uses
            [`~models.resnet.Upsample2D`] / [`~models.resnet.Downsample2D`] without a convolution.
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `None`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection. If `None`, the layer is added only when
            `in_channels` differs from the block's output channel count (`conv_2d_out_channels`).
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.

    Raises:
        ValueError: if `time_embedding_norm` (with `temb_channels` set) or `non_linearity` is not recognised.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        skip_time_act=False,
        time_embedding_norm="default",  # default, scale_shift, ada_group
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        # `pre_norm` is accepted for backward compatibility only: the block always normalizes first.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                # Twice the channels: `forward` splits the projection into a scale half and a shift half.
                self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group":
                # The AdaGroupNorm layers receive the raw embedding in `forward`, so no projection is created.
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        # Modules/partials instead of lambdas so the block can be pickled (e.g. `torch.save(model)`).
        # An unknown name previously left `self.nonlinearity` unset and only failed later inside `forward`.
        if non_linearity in ("swish", "silu"):
            self.nonlinearity = nn.SiLU()
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "gelu":
            self.nonlinearity = nn.GELU()
        else:
            raise ValueError(f"unknown non_linearity : {non_linearity} ")

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = partial(upsample_2d, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = partial(downsample_2d, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        # In "ada_group" mode the norm layers take the (unprojected) timestep embedding as a second input.
        if self.time_embedding_norm == "ada_group":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            # Trailing singleton dims let the projected embedding broadcast over the two spatial dims.
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
640

Patrick von Platen's avatar
Patrick von Platen committed
641
642

class Mish(torch.nn.Module):
    """Mish activation, computed elementwise as `x * tanh(softplus(x))`."""

    def forward(self, hidden_states):
        gate = torch.tanh(F.softplus(hidden_states))
        return hidden_states * gate
Patrick von Platen's avatar
Patrick von Platen committed
645
646


647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
# unet_rl.py
def rearrange_dims(tensor):
    """Toggle a tensor between a 1D-sequence layout and a 4D layout by inserting/removing an axis at dim 2.

    - 2D `[B, D]` -> `[B, D, 1]`
    - 3D `[B, C, L]` -> `[B, C, 1, L]`
    - 4D `[B, C, X, L]` -> `[B, C, L]`, keeping index 0 of dim 2 (normally of size 1)

    Raises:
        ValueError: if the tensor is not 2, 3 or 4 dimensional (including 0-d tensors).
    """
    # Use the rank, not `len(tensor)`: the latter is the size of dim 0 and raises TypeError for 0-d tensors.
    ndim = len(tensor.shape)
    if ndim == 2:
        return tensor[:, :, None]
    if ndim == 3:
        return tensor[:, :, None, :]
    if ndim == 4:
        return tensor[:, :, 0, :]
    raise ValueError(f"`len(tensor.shape)`: {ndim} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Parameters:
        inp_channels (`int`): number of input channels.
        out_channels (`int`): number of output channels.
        kernel_size (`int`): convolution kernel size; padded by `kernel_size // 2`.
        n_groups (`int`, default `8`): number of groups for the group normalization.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        # `kernel_size // 2` padding keeps the sequence length unchanged for odd kernel sizes.
        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        # The group norm runs on a temporarily 4D view; `rearrange_dims` adds and then removes the extra axis.
        features = rearrange_dims(self.conv1d(x))
        features = rearrange_dims(self.group_norm(features))
        return self.mish(features)


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """Two `Conv1dBlock`s with an additive embedding after the first, plus a residual connection.

    Parameters:
        inp_channels (`int`): number of input channels.
        out_channels (`int`): number of output channels.
        embed_dim (`int`): size of the conditioning embedding `t`.
        kernel_size (`int`, default `5`): kernel size of both convolution blocks.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        # A 1x1 conv only when the channel count changes; otherwise the residual path is a no-op.
        if inp_channels != out_channels:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t_proj = self.time_emb(self.time_emb_act(t))
        # `rearrange_dims` turns the [batch, out_channels] projection into [batch, out_channels, 1] for broadcasting.
        hidden = self.conv_in(x) + rearrange_dims(t_proj)
        hidden = self.conv_out(hidden)
        return hidden + self.residual_conv(x)


710
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` and upsamples each image with the given filter. The
    filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
          Only `kernel.shape[0]` is used to compute the padding, so 2-D kernels are expected to be square.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`, with the dtype and device of `hidden_states`.
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        # A 1-D kernel is treated as separable: the 2-D filter is its outer product.
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    # Zero-stuffing by `factor` divides the signal energy by factor**2; compensate in the kernel.
    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        # Match dtype as well as device: F.conv2d rejects a float32 kernel with e.g. fp16 input.
        kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output
745
746


747
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` and downsamples each image with the given filter. The
    filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
          Only `kernel.shape[0]` is used to compute the padding, so 2-D kernels are expected to be square.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`, with the dtype and device of
        `hidden_states`.
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        # A 1-D kernel is treated as separable: the 2-D filter is its outer product.
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        # Match dtype as well as device: F.conv2d rejects a float32 kernel with e.g. fp16 input.
        kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
        down=factor,
        pad=((pad_value + 1) // 2, pad_value // 2),
    )
    return output
780
781


782
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, FIR-filter and downsample a batch of images (pure PyTorch implementation).

    Every channel is processed independently:
    1. zero-stuff by `up` along height and width (insert `up - 1` zeros after each pixel),
    2. zero-pad by `pad` on both spatial axes (negative values crop instead),
    3. convolve with `kernel` (the kernel is flipped, so this is a true convolution),
    4. keep every `down`-th row and column.

    Args:
        tensor: Input tensor of the shape `[N, C, H, W]`.
        kernel: 2-D filter of the shape `[kernel_h, kernel_w]`. It is cast to the dtype and device of `tensor`.
        up: Integer upsampling factor applied to both spatial axes.
        down: Integer downsampling factor applied to both spatial axes.
        pad: `(pad_before, pad_after)`, applied identically to height and width after upsampling.

    Returns:
        Tensor of the shape `[N, C, out_h, out_w]` where
        `out_h = (H * up + pad[0] + pad[1] - kernel_h) // down + 1` (and likewise for `out_w`).
    """
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    # Fold channels into the batch axis so each channel is filtered independently.
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    # Zero-stuffing: pad the singleton axes after each row/column with `up - 1` zeros, then merge.
    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding adds zeros; negative padding is applied as a crop.
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    # F.conv2d computes a cross-correlation, so flip the kernel to get a true convolution. The kernel must
    # share the input's dtype/device: F.conv2d rejects mixed-dtype operands (e.g. float32 kernel, fp16 input).
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w).to(device=out.device, dtype=out.dtype)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.reshape(-1, channel, out_h, out_w)
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    Four GroupNorm -> SiLU -> (Dropout) -> Conv3d stages whose kernels `(3, 1, 1)` only mix information along the
    frame axis. Their result is added back to the input. The last convolution is zero-initialized, so a freshly
    constructed layer returns its input unchanged.

    Parameters:
        in_dim (`int`): number of input (and output) channels.
        out_dim (`int`, *optional*): number of hidden channels between the first and the last convolution.
            Defaults to `in_dim`.
        dropout (`float`, default `0.0`): dropout probability used in the last three stages.
        norm_num_groups (`int`, default `32`): number of groups of every `nn.GroupNorm`; must divide both
            `in_dim` and `out_dim`.
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0, norm_num_groups=32):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers: in_dim -> out_dim -> out_dim -> out_dim -> in_dim, so the output can be added to the input.
        # (Previously conv2/conv3 mapped back to in_dim, which crashed for out_dim != in_dim; shapes are
        # unchanged when out_dim == in_dim.)
        self.conv1 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, in_dim),
            nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the residual branch outputs zero and the layer is an identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states, num_frames=1):
        """
        Args:
            hidden_states (`torch.Tensor`): 4-D tensor whose first axis stacks `num_frames` consecutive frames of
                each sample, i.e. `[batch * num_frames, in_dim, height, width]`.
            num_frames (`int`, default `1`): number of frames per sample.

        Returns:
            `torch.Tensor` of the same shape as `hidden_states`.
        """
        # [batch * frames, C, H, W] -> [batch, C, frames, H, W] so Conv3d convolves over the frame axis.
        hidden_states = hidden_states.reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        # [batch, C, frames, H, W] -> [batch * frames, C, H, W]
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states