# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .attention_processor import SpatialNorm
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .normalization import AdaGroupNorm


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 1D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs
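

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_upsample1d() -> torch.Tensor:
    layer = Upsample1D(channels=32, use_conv=True)
    x = torch.randn(2, 32, 16)  # (batch, channels, length)
    return layer(x)  # nearest-neighbor 2x upsampling followed by a conv -> (2, 32, 32)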


class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 1D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
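

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_downsample1d() -> torch.Tensor:
    layer = Downsample1D(channels=32, use_conv=True)
    x = torch.randn(2, 32, 16)  # (batch, channels, length)
    return layer(x)  # stride-2 conv halves the length -> (2, 32, 8)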


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = conv_cls(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_size: Optional[int] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 since the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                    hidden_states = self.conv(hidden_states, scale)
                else:
                    hidden_states = self.conv(hidden_states)
            else:
                if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                    hidden_states = self.Conv2d_0(hidden_states, scale)
                else:
                    hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
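

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_upsample2d() -> torch.Tensor:
    layer = Upsample2D(channels=64, use_conv=True)
    x = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
    return layer(x)  # nearest-neighbor 2x upsampling followed by a 3x3 conv -> (2, 64, 32, 32)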


class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if use_conv:
            conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels

        if not USE_PEFT_BACKEND:
            if isinstance(self.conv, LoRACompatibleConv):
                hidden_states = self.conv(hidden_states, scale)
            else:
                hidden_states = self.conv(hidden_states)
        else:
            hidden_states = self.conv(hidden_states)

        return hidden_states
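

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_downsample2d() -> torch.Tensor:
    layer = Downsample2D(channels=64, use_conv=True, padding=1)
    x = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
    return layer(x)  # 3x3 conv with stride 2 -> (2, 64, 8, 8)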


class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`, optional):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(
        self,
        hidden_states: torch.FloatTensor,
        weight: Optional[torch.FloatTensor] = None,
        kernel: Optional[torch.FloatTensor] = None,
        factor: int = 2,
        gain: float = 1,
    ) -> torch.FloatTensor:
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
        order.

        Args:
            hidden_states (`torch.FloatTensor`):
                Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight (`torch.FloatTensor`, *optional*):
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel (`torch.FloatTensor`, *optional*):
                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                corresponds to nearest-neighbor upsampling.
            factor (`int`, *optional*): Integer upsampling factor (default: 2).
            gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output (`torch.FloatTensor`):
                Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
                datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states,
                weight,
                stride=stride,
                output_padding=output_padding,
                padding=0,
            )

            output = upfirdn2d_native(
                inverse_conv,
                torch.tensor(kernel, device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height
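

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_fir_upsample2d() -> torch.Tensor:
    layer = FirUpsample2D(channels=64, out_channels=64, use_conv=True)
    x = torch.randn(1, 64, 16, 16)
    return layer(x)  # fused FIR upsampling + conv -> (1, 64, 32, 32)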


class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(
        self,
        hidden_states: torch.FloatTensor,
        weight: Optional[torch.FloatTensor] = None,
        kernel: Optional[torch.FloatTensor] = None,
        factor: int = 2,
        gain: float = 1,
    ) -> torch.FloatTensor:
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary
        order.

        Args:
            hidden_states (`torch.FloatTensor`):
                Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight (`torch.FloatTensor`, *optional*):
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel (`torch.FloatTensor`, *optional*):
                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                corresponds to average pooling.
            factor (`int`, *optional*, default to `2`):
                Integer downsampling factor.
            gain (`float`, *optional*, default to `1.0`):
                Scaling factor for signal magnitude.

        Returns:
            output (`torch.FloatTensor`):
                Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
                datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
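

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_fir_downsample2d() -> torch.Tensor:
    layer = FirDownsample2D(channels=64, out_channels=64, use_conv=True)
    x = torch.randn(1, 64, 16, 16)
    return layer(x)  # fused conv + FIR downsampling -> (1, 64, 8, 8)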


# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    r"""A 2D K-downsampling layer.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        weight = inputs.new_zeros(
            [
                inputs.shape[1],
                inputs.shape[1],
                self.kernel.shape[0],
                self.kernel.shape[1],
            ]
        )
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(inputs, weight, stride=2)


class KUpsample2D(nn.Module):
    r"""A 2D K-upsampling layer.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = inputs.new_zeros(
            [
                inputs.shape[1],
                inputs.shape[1],
                self.kernel.shape[0],
                self.kernel.shape[1],
            ]
        )
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
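

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_k_resampling() -> torch.Tensor:
    down = KDownsample2D()
    up = KUpsample2D()
    x = torch.randn(1, 3, 32, 32)
    x = down(x)  # fixed blur kernel applied as a stride-2 conv -> (1, 3, 16, 16)
    return up(x)  # transposed conv with the matching kernel -> (1, 3, 32, 32)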


class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"`): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        kernel: Optional[torch.FloatTensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = linear_cls(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = conv_cls(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = (
                self.upsample(input_tensor, scale=scale)
                if isinstance(self.upsample, Upsample2D)
                else self.upsample(input_tensor)
            )
            hidden_states = (
                self.upsample(hidden_states, scale=scale)
                if isinstance(self.upsample, Upsample2D)
                else self.upsample(hidden_states)
            )
        elif self.downsample is not None:
            input_tensor = (
                self.downsample(input_tensor, scale=scale)
                if isinstance(self.downsample, Downsample2D)
                else self.downsample(input_tensor)
            )
            hidden_states = (
                self.downsample(hidden_states, scale=scale)
                if isinstance(self.downsample, Downsample2D)
                else self.downsample(hidden_states)
            )

        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = (
                self.time_emb_proj(temb, scale)[:, :, None, None]
                if not USE_PEFT_BACKEND
                else self.time_emb_proj(temb)[:, :, None, None]
            )

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = (
                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
            )

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
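

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_resnet_block_2d() -> torch.Tensor:
    block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
    x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    temb = torch.randn(2, 512)  # timestep embedding
    # norm -> conv -> add projected temb -> norm -> conv, with a 1x1 shortcut since channels change -> (2, 128, 32, 32)
    return block(x, temb)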


# unet_rl.py
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
        n_groups (`int`, default `8`): Number of groups to separate the channels into.
        activation (`str`, defaults to `mish`): Name of the activation function.
    """

    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        n_groups: int = 8,
        activation: str = "mish",
    ):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = get_activation(activation)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        intermediate_repr = self.conv1d(inputs)
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """
    Residual 1D block with temporal convolutions.

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        embed_dim (`int`): Embedding dimension.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
        activation (`str`, defaults to `mish`): Name of the activation function.
    """

    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        embed_dim: int,
        kernel_size: Union[int, Tuple[int, int]] = 5,
        activation: str = "mish",
    ):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = get_activation(activation)
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(inputs) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)
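

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_residual_temporal_block_1d() -> torch.Tensor:
    block = ResidualTemporalBlock1D(inp_channels=32, out_channels=64, embed_dim=128)
    x = torch.randn(2, 32, 16)  # (batch, inp_channels, horizon)
    t = torch.randn(2, 128)  # (batch, embed_dim)
    return block(x, t)  # two Conv1dBlocks conditioned on t, plus a 1x1 residual conv -> (2, 64, 16)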


def upsample_2d(
    hidden_states: torch.FloatTensor,
    kernel: Optional[torch.FloatTensor] = None,
    factor: int = 2,
    gain: float = 1,
) -> torch.FloatTensor:
    r"""Upsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states (`torch.FloatTensor`):
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel (`torch.FloatTensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to nearest-neighbor upsampling.
        factor (`int`, *optional*, default to `2`):
            Integer upsampling factor.
        gain (`float`, *optional*, default to `1.0`):
            Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output (`torch.FloatTensor`):
            Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output
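

# Minimal usage sketch (illustrative only, not part of the public API); the helper name, shapes, and kernel are
# assumptions. `downsample_2d` is defined right below and mirrors `upsample_2d`.
def _example_fir_resampling_functions() -> torch.Tensor:
    x = torch.randn(1, 3, 16, 16)
    up = upsample_2d(x, kernel=(1, 3, 3, 1), factor=2)  # -> (1, 3, 32, 32)
    return downsample_2d(up, kernel=(1, 3, 3, 1), factor=2)  # -> (1, 3, 16, 16)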


def downsample_2d(
    hidden_states: torch.FloatTensor,
    kernel: Optional[torch.FloatTensor] = None,
    factor: int = 2,
    gain: float = 1,
) -> torch.FloatTensor:
    r"""Downsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states (`torch.FloatTensor`):
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel (`torch.FloatTensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
        factor (`int`, *optional*, default to `2`):
            Integer downsampling factor.
        gain (`float`, *optional*, default to `1.0`):
            Scaling factor for signal magnitude.

    Returns:
        output (`torch.FloatTensor`):
            Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        down=factor,
        pad=((pad_value + 1) // 2, pad_value // 2),
    )
    return output


def upfirdn2d_native(
    tensor: torch.Tensor,
    kernel: torch.Tensor,
    up: int = 1,
    down: int = 1,
    pad: Tuple[int, int] = (0, 0),
) -> torch.Tensor:
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
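

# Minimal usage sketch (illustrative only, not part of the public API); the helper name, shapes, and kernel are
# assumptions. With up=down=1, a 2x2 kernel, and pad=(1, 0) the spatial size is preserved.
def _example_upfirdn2d_native() -> torch.Tensor:
    x = torch.randn(1, 3, 16, 16)
    blur = torch.ones(2, 2) / 4.0  # simple 2x2 box filter
    return upfirdn2d_native(x, blur, up=1, down=1, pad=(1, 0))  # -> (1, 3, 16, 16)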


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    Parameters:
        in_dim (`int`): Number of input channels.
        out_dim (`int`): Number of output channels.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
    ):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, in_dim),
            nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states
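

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_temporal_conv_layer() -> torch.Tensor:
    layer = TemporalConvLayer(in_dim=32, out_dim=32, norm_num_groups=8)
    # Video input is flattened to (batch * num_frames, channels, height, width).
    x = torch.randn(2 * 4, 32, 16, 16)
    return layer(x, num_frames=4)  # temporal 3x1x1 convs with an identity-initialized residual -> (8, 32, 16, 16)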


class TemporalResnetBlock(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        temb_channels: int = 512,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        kernel_size = (3, 1, 1)
        padding = [k // 2 for k in kernel_size]

        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

        if temb_channels is not None:
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(0.0)
        self.conv2 = nn.Conv3d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

        self.nonlinearity = get_activation("silu")

        self.use_in_shortcut = self.in_channels != out_channels

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, :, None, None]
            temb = temb.permute(0, 2, 1, 3, 4)
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor
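

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_temporal_resnet_block() -> torch.Tensor:
    block = TemporalResnetBlock(in_channels=64, out_channels=64, temb_channels=128)
    x = torch.randn(1, 64, 4, 8, 8)  # (batch, channels, frames, height, width)
    temb = torch.randn(1, 4, 128)  # one embedding per frame
    return block(x, temb)  # 3x1x1 temporal convs with a residual connection -> (1, 64, 4, 8, 8)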


# VideoResBlock
class SpatioTemporalResBlock(nn.Module):
    r"""
    A SpatioTemporal Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resnet.
        temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
        merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
            The merge strategy to use for the temporal mixing.
        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
            If `True`, switch the spatial and temporal mixing.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        temb_channels: int = 512,
        eps: float = 1e-6,
        temporal_eps: Optional[float] = None,
        merge_factor: float = 0.5,
        merge_strategy="learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()

        self.spatial_res_block = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=eps,
        )

        self.temporal_res_block = TemporalResnetBlock(
            in_channels=out_channels if out_channels is not None else in_channels,
            out_channels=out_channels if out_channels is not None else in_channels,
            temb_channels=temb_channels,
            eps=temporal_eps if temporal_eps is not None else eps,
        )

        self.time_mixer = AlphaBlender(
            alpha=merge_factor,
            merge_strategy=merge_strategy,
            switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
    ):
        num_frames = image_only_indicator.shape[-1]
        hidden_states = self.spatial_res_block(hidden_states, temb)

        batch_frames, channels, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        hidden_states_mix = (
            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        )
        hidden_states = (
            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        )

        if temb is not None:
            temb = temb.reshape(batch_size, num_frames, -1)

        hidden_states = self.temporal_res_block(hidden_states, temb)
        hidden_states = self.time_mixer(
            x_spatial=hidden_states_mix,
            x_temporal=hidden_states,
            image_only_indicator=image_only_indicator,
        )

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
        return hidden_states
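

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_spatio_temporal_res_block() -> torch.Tensor:
    block = SpatioTemporalResBlock(in_channels=64, out_channels=64, temb_channels=128)
    batch_size, num_frames = 1, 4
    x = torch.randn(batch_size * num_frames, 64, 8, 8)  # frames flattened into the batch dimension
    temb = torch.randn(batch_size * num_frames, 128)
    image_only_indicator = torch.zeros(batch_size, num_frames)  # 0 -> treat every entry as video
    return block(x, temb, image_only_indicator)  # spatial resnet, temporal resnet, then blend -> (4, 64, 8, 8)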


class AlphaBlender(nn.Module):
    r"""
    A module to blend spatial and temporal features.

    Parameters:
        alpha (`float`): The initial value of the blending factor.
        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
            The merge strategy to use for the temporal mixing.
        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
            If `True`, switch the spatial and temporal mixing.
    """

    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE

        if merge_strategy not in self.strategies:
            raise ValueError(f"merge_strategy needs to be in {self.strategies}")

        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
        else:
            raise ValueError(f"Unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
        if self.merge_strategy == "fixed":
            alpha = self.mix_factor

        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor)

        elif self.merge_strategy == "learned_with_images":
            if image_only_indicator is None:
                raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")

            alpha = torch.where(
                image_only_indicator.bool(),
                torch.ones(1, 1, device=image_only_indicator.device),
                torch.sigmoid(self.mix_factor)[..., None],
            )

            # (batch, channel, frames, height, width)
            if ndims == 5:
                alpha = alpha[:, None, :, None, None]
            # (batch*frames, height*width, channels)
            elif ndims == 3:
                alpha = alpha.reshape(-1)[:, None, None]
            else:
                raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")

        else:
            raise NotImplementedError

        return alpha

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
        alpha = alpha.to(x_spatial.dtype)

        if self.switch_spatial_to_temporal_mix:
            alpha = 1.0 - alpha

        x = alpha * x_spatial + (1.0 - alpha) * x_temporal
        return x
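

# Minimal usage sketch (illustrative only, not part of the public API); the helper name and shapes are assumptions.
def _example_alpha_blender() -> torch.Tensor:
    blender = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
    x_spatial = torch.randn(1, 64, 4, 8, 8)  # (batch, channels, frames, height, width)
    x_temporal = torch.randn(1, 64, 4, 8, 8)
    image_only_indicator = torch.zeros(1, 4)  # 0 -> blend with the learned (sigmoid) mix factor
    return blender(x_spatial, x_temporal, image_only_indicator)  # -> (1, 64, 4, 8, 8)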