# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
from .lora import LoRACompatibleConv, LoRACompatibleLinear


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 1D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs
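
# Illustrative usage sketch (helper added for documentation purposes, not part of the public API):
# `Upsample1D` doubles the length dimension with nearest-neighbor interpolation plus an optional
# 3-tap convolution, or with a stride-2 transposed convolution. Shapes below are arbitrary.
def _example_upsample1d() -> torch.Size:
    upsample = Upsample1D(channels=32, use_conv=True)
    sample = torch.randn(2, 32, 16)  # (batch, channels, length)
    return upsample(sample).shape  # torch.Size([2, 32, 32])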


class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 1D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
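
# Illustrative usage sketch (not part of the public API): `Downsample1D` halves the length
# dimension with either a stride-2 convolution or average pooling.
def _example_downsample1d() -> torch.Size:
    downsample = Downsample1D(channels=32, use_conv=True)
    sample = torch.randn(2, 32, 16)  # (batch, channels, length)
    return downsample(sample).shape  # torch.Size([2, 32, 8])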


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                if isinstance(self.conv, LoRACompatibleConv):
                    hidden_states = self.conv(hidden_states, scale)
                else:
                    hidden_states = self.conv(hidden_states)
            else:
                if isinstance(self.Conv2d_0, LoRACompatibleConv):
                    hidden_states = self.Conv2d_0(hidden_states, scale)
                else:
                    hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
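
# Illustrative usage sketch (not part of the public API): `Upsample2D` doubles the spatial
# resolution by default; passing `output_size` forces an explicit target size instead of the
# fixed factor of 2. Shapes below are arbitrary.
def _example_upsample2d() -> Tuple[torch.Size, torch.Size]:
    upsample = Upsample2D(channels=64, use_conv=True)
    sample = torch.randn(1, 64, 16, 16)
    doubled = upsample(sample)  # (1, 64, 32, 32)
    forced = upsample(sample, output_size=(24, 24))  # (1, 64, 24, 24)
    return doubled.shape, forced.shape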


class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels

        if isinstance(self.conv, LoRACompatibleConv):
            hidden_states = self.conv(hidden_states, scale)
        else:
            hidden_states = self.conv(hidden_states)

        return hidden_states
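
# Illustrative usage sketch (not part of the public API): `Downsample2D` halves the spatial
# resolution; with `padding=0` the forward pass pads the right/bottom edge by one pixel before
# the stride-2 convolution, so the output is still exactly half the input size.
def _example_downsample2d() -> Tuple[torch.Size, torch.Size]:
    sample = torch.randn(1, 64, 32, 32)
    padded = Downsample2D(channels=64, use_conv=True, padding=1)(sample)  # (1, 64, 16, 16)
    asymmetric = Downsample2D(channels=64, use_conv=True, padding=0)(sample)  # (1, 64, 16, 16)
    return padded.shape, asymmetric.shape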


class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(
        self,
        hidden_states: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        kernel: Optional[torch.FloatTensor] = None,
        factor: int = 2,
        gain: float = 1,
    ) -> torch.Tensor:
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height
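
# Illustrative usage sketch (not part of the public API): `FirUpsample2D` doubles the spatial
# resolution while low-pass filtering with the FIR kernel; with `use_conv=True` the upsampling is
# fused with the 3x3 convolution via a transposed convolution. Shapes below are arbitrary.
def _example_fir_upsample2d() -> Tuple[torch.Size, torch.Size]:
    sample = torch.randn(1, 32, 8, 8)
    filtered = FirUpsample2D(channels=32, use_conv=False)(sample)  # (1, 32, 16, 16)
    fused = FirUpsample2D(channels=32, out_channels=32, use_conv=True)(sample)  # (1, 32, 16, 16)
    return filtered.shape, fused.shape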


class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(
        self,
        hidden_states: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        kernel: Optional[torch.FloatTensor] = None,
        factor: int = 2,
        gain: float = 1,
    ) -> torch.Tensor:
        """Fused `Conv2d()` followed by `downsample_2d()`.
        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
            factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states


# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    r"""A 2D K-downsampling layer.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(inputs, weight, stride=2)


class KUpsample2D(nn.Module):
    r"""A 2D K-upsampling layer.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
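
# Illustrative usage sketch (not part of the public API): the k-diffusion style layers blur with a
# fixed 4x4 kernel and change the resolution by a factor of 2 without any learned weights.
def _example_k_resample2d() -> Tuple[torch.Size, torch.Size]:
    sample = torch.randn(1, 8, 16, 16)
    down = KDownsample2D()(sample)  # (1, 8, 8, 8)
    up = KUpsample2D()(down)  # (1, 8, 16, 16)
    return down.shape, up.shape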


class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        kernel: Optional[torch.FloatTensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm: {self.time_embedding_norm}")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = LoRACompatibleConv(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb, scale: float = 1.0):
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = (
                self.upsample(input_tensor, scale=scale)
                if isinstance(self.upsample, Upsample2D)
                else self.upsample(input_tensor)
            )
            hidden_states = (
                self.upsample(hidden_states, scale=scale)
                if isinstance(self.upsample, Upsample2D)
                else self.upsample(hidden_states)
            )
        elif self.downsample is not None:
            input_tensor = (
                self.downsample(input_tensor, scale=scale)
                if isinstance(self.downsample, Downsample2D)
                else self.downsample(input_tensor)
            )
            hidden_states = (
                self.downsample(hidden_states, scale=scale)
                if isinstance(self.downsample, Downsample2D)
                else self.downsample(hidden_states)
            )

        hidden_states = self.conv1(hidden_states, scale)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb, scale)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            # use distinct names so the LoRA `scale` argument is not shadowed before `conv2`/`conv_shortcut` use it
            time_scale, time_shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + time_scale) + time_shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, scale)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, scale)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
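
# Illustrative usage sketch (not part of the public API): a typical UNet block that changes the
# channel count while keeping the spatial size; the timestep embedding `temb` is projected and
# added after the first convolution, and a 1x1 shortcut matches the channel count.
def _example_resnet_block2d() -> torch.Size:
    block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
    sample = torch.randn(2, 64, 32, 32)
    temb = torch.randn(2, 512)
    return block(sample, temb).shape  # torch.Size([2, 128, 32, 32])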


# unet_rl.py
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
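
# Illustrative sketch (not part of the public API): `ResidualTemporalBlock1D` uses `rearrange_dims`
# to broadcast the projected time embedding over the horizon, and `Conv1dBlock` uses it to move
# through the 4D layout around its GroupNorm.
def _example_rearrange_dims() -> Tuple[torch.Size, torch.Size]:
    emb = torch.randn(4, 32)  # (batch, channels)
    features = torch.randn(4, 32, 64)  # (batch, channels, horizon)
    return rearrange_dims(emb).shape, rearrange_dims(features).shape  # (4, 32, 1), (4, 32, 1, 64)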


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
        n_groups (`int`, default `8`): Number of groups to separate the channels into.
    """

    def __init__(
        self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
    ):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        intermediate_repr = self.conv1d(inputs)
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """
    Residual 1D block with temporal convolutions.

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        embed_dim (`int`): Embedding dimension.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
    """

    def __init__(
        self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
    ):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(inputs) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)
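
# Illustrative usage sketch (not part of the public API): the temporal block keeps the horizon
# length, projects the time embedding onto the output channels, and adds a 1x1 residual
# convolution when the channel count changes.
def _example_residual_temporal_block1d() -> torch.Size:
    block = ResidualTemporalBlock1D(inp_channels=16, out_channels=32, embed_dim=128)
    inputs = torch.randn(4, 16, 64)  # (batch, inp_channels, horizon)
    t = torch.randn(4, 128)  # (batch, embed_dim)
    return block(inputs, t).shape  # torch.Size([4, 32, 64])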


def upsample_2d(
    hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.Tensor:
    r"""Upsample a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output


def downsample_2d(
    hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.Tensor:
    r"""Downsample a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
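
# Illustrative sketch (not part of the public API): with the default box kernel these helpers
# reduce to nearest-neighbor upsampling / average pooling, while a `(1, 3, 3, 1)` kernel gives the
# smoothed FIR resampling used by the "fir" resnet variants.
def _example_fir_resample2d() -> Tuple[torch.Size, torch.Size]:
    sample = torch.randn(1, 3, 16, 16)
    down = downsample_2d(sample, kernel=(1, 3, 3, 1), factor=2)  # (1, 3, 8, 8)
    up = upsample_2d(down, kernel=(1, 3, 3, 1), factor=2)  # (1, 3, 16, 16)
    return down.shape, up.shape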


def upfirdn2d_native(
    tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
) -> torch.Tensor:
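    """Pure-PyTorch `upfirdn2d` (upsample, FIR filter, downsample) applied to a batch of 2D images.

    The input of shape `[N, C, H, W]` is upsampled by inserting `up - 1` zeros between pixels, padded by `pad`,
    convolved with the flipped 2D `kernel`, and finally subsampled by keeping every `down`-th pixel, giving an
    output of shape `[N, C, (H * up + pad[0] + pad[1] - kernel_h) // down + 1, (W * up + pad[0] + pad[1] - kernel_w) // down + 1]`.
    """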
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    Parameters:
        in_dim (`int`): Number of input channels.
        out_dim (`int`): Number of output channels.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
    """

    def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states
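
# Illustrative usage sketch (not part of the public API): the layer receives frames flattened into
# the batch dimension, mixes information across `num_frames` with (3, 1, 1) temporal convolutions,
# and returns a tensor of the same shape as its input.
def _example_temporal_conv_layer() -> torch.Size:
    layer = TemporalConvLayer(in_dim=32)
    frames = torch.randn(2 * 8, 32, 16, 16)  # (batch * num_frames, channels, height, width)
    return layer(frames, num_frames=8).shape  # torch.Size([16, 32, 16, 16])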