# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .attention_processor import SpatialNorm
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .normalization import AdaGroupNorm


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 1D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs
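
# Usage sketch (editor's note, not part of the upstream module; names and shapes are illustrative):
#     up = Upsample1D(channels=32, use_conv=True)
#     x = torch.randn(2, 32, 16)   # (batch, channels, length)
#     y = up(x)                    # -> (2, 32, 32): nearest-neighbor x2 followed by a 3-wide conv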


class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 1D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
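
# Usage sketch (editor's note, illustrative only):
#     down = Downsample1D(channels=32, use_conv=True)
#     x = torch.randn(2, 32, 16)   # (batch, channels, length)
#     y = down(x)                  # -> (2, 32, 8): strided conv halves the length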


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = conv_cls(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(
        self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                    hidden_states = self.conv(hidden_states, scale)
                else:
                    hidden_states = self.conv(hidden_states)
            else:
                if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                    hidden_states = self.Conv2d_0(hidden_states, scale)
                else:
                    hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
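
# Usage sketch (editor's note, illustrative only; `scale` is only consumed by the LoRA-compatible conv path):
#     up = Upsample2D(channels=64, use_conv=True)
#     x = torch.randn(1, 64, 32, 32)
#     y = up(x)                        # -> (1, 64, 64, 64)
#     y = up(x, output_size=(65, 65))  # explicit output size instead of scale_factor=2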


class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if use_conv:
            conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels

        if not USE_PEFT_BACKEND:
            if isinstance(self.conv, LoRACompatibleConv):
                hidden_states = self.conv(hidden_states, scale)
            else:
                hidden_states = self.conv(hidden_states)
        else:
            hidden_states = self.conv(hidden_states)

        return hidden_states
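
# Usage sketch (editor's note, illustrative only):
#     down = Downsample2D(channels=64, use_conv=True)
#     x = torch.randn(1, 64, 32, 32)
#     y = down(x)                  # -> (1, 64, 16, 16): 3x3 conv with stride 2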


class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`, optional):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(
        self,
        hidden_states: torch.FloatTensor,
        weight: Optional[torch.FloatTensor] = None,
        kernel: Optional[torch.FloatTensor] = None,
        factor: int = 2,
        gain: float = 1,
    ) -> torch.FloatTensor:
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states (`torch.FloatTensor`):
                Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight (`torch.FloatTensor`, *optional*):
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel (`torch.FloatTensor`, *optional*):
                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                corresponds to nearest-neighbor upsampling.
            factor (`int`, *optional*): Integer upsampling factor (default: 2).
            gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output (`torch.FloatTensor`):
                Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
                datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                torch.tensor(kernel, device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height
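
# Usage sketch (editor's note, illustrative only): the FIR kernel (1, 3, 3, 1) acts as a smoothing
# filter applied while upsampling, instead of plain nearest-neighbor interpolation.
#     fir_up = FirUpsample2D(channels=64, out_channels=64, use_conv=True)
#     x = torch.randn(1, 64, 32, 32)
#     y = fir_up(x)                # -> (1, 64, 64, 64)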


class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(
        self,
        hidden_states: torch.FloatTensor,
        weight: Optional[torch.FloatTensor] = None,
        kernel: Optional[torch.FloatTensor] = None,
        factor: int = 2,
        gain: float = 1,
    ) -> torch.FloatTensor:
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states (`torch.FloatTensor`):
                Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight (`torch.FloatTensor`, *optional*):
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel (`torch.FloatTensor`, *optional*):
                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                corresponds to average pooling.
            factor (`int`, *optional*, default to `2`):
                Integer downsampling factor.
            gain (`float`, *optional*, default to `1.0`):
                Scaling factor for signal magnitude.

        Returns:
            output (`torch.FloatTensor`):
                Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
                datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
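
# Usage sketch (editor's note, illustrative only): mirror image of FirUpsample2D; low-pass filtering
# with the FIR kernel before the strided convolution reduces aliasing.
#     fir_down = FirDownsample2D(channels=64, out_channels=64, use_conv=True)
#     x = torch.randn(1, 64, 32, 32)
#     y = fir_down(x)              # -> (1, 64, 16, 16)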


# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    r"""A 2D K-downsampling layer.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(inputs, weight, stride=2)


class KUpsample2D(nn.Module):
    r"""A 2D K-upsampling layer.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
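
# Usage sketch (editor's note, illustrative only): both layers expand a fixed separable blur kernel
# into a per-channel conv weight, so the channel count is unchanged.
#     x = torch.randn(1, 64, 32, 32)
#     KDownsample2D()(x)           # -> (1, 64, 16, 16)
#     KUpsample2D()(x)             # -> (1, 64, 64, 64)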


class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        kernel: Optional[torch.FloatTensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = linear_cls(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = conv_cls(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(
        self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = (
                self.upsample(input_tensor, scale=scale)
                if isinstance(self.upsample, Upsample2D)
                else self.upsample(input_tensor)
            )
            hidden_states = (
                self.upsample(hidden_states, scale=scale)
                if isinstance(self.upsample, Upsample2D)
                else self.upsample(hidden_states)
            )
        elif self.downsample is not None:
            input_tensor = (
                self.downsample(input_tensor, scale=scale)
                if isinstance(self.downsample, Downsample2D)
                else self.downsample(input_tensor)
            )
            hidden_states = (
                self.downsample(hidden_states, scale=scale)
                if isinstance(self.downsample, Downsample2D)
                else self.downsample(hidden_states)
            )

        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = (
                self.time_emb_proj(temb, scale)[:, :, None, None]
                if not USE_PEFT_BACKEND
                else self.time_emb_proj(temb)[:, :, None, None]
            )

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = (
                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
            )

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
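
# Usage sketch (editor's note, illustrative only; channel counts and temb width are assumptions):
#     block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
#     x = torch.randn(2, 64, 32, 32)
#     temb = torch.randn(2, 512)       # timestep embedding
#     y = block(x, temb)               # -> (2, 128, 32, 32); the 1x1 conv_shortcut matches the channels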


# unet_rl.py
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
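
# Shape behaviour (editor's note, illustrative only):
#     rearrange_dims(torch.zeros(4, 8))        # -> (4, 8, 1)
#     rearrange_dims(torch.zeros(4, 8, 16))    # -> (4, 8, 1, 16)
#     rearrange_dims(torch.zeros(4, 8, 1, 16)) # -> (4, 8, 16)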


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
        n_groups (`int`, default `8`): Number of groups to separate the channels into.
        activation (`str`, defaults to `mish`): Name of the activation function.
    """

    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        n_groups: int = 8,
        activation: str = "mish",
    ):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = get_activation(activation)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        intermediate_repr = self.conv1d(inputs)
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """
    Residual 1D block with temporal convolutions.

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        embed_dim (`int`): Embedding dimension.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
        activation (`str`, defaults to `mish`): Name of the activation function to use.
    """

    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        embed_dim: int,
        kernel_size: Union[int, Tuple[int, int]] = 5,
        activation: str = "mish",
    ):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = get_activation(activation)
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(inputs) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)
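
# Usage sketch (editor's note, illustrative only):
#     block = ResidualTemporalBlock1D(inp_channels=14, out_channels=32, embed_dim=128)
#     x = torch.randn(8, 14, 64)       # [ batch_size x inp_channels x horizon ]
#     t = torch.randn(8, 128)          # [ batch_size x embed_dim ]
#     out = block(x, t)                # -> (8, 32, 64)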


def upsample_2d(
    hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.FloatTensor:
    r"""Upsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states (`torch.FloatTensor`):
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel (`torch.FloatTensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to nearest-neighbor upsampling.
        factor (`int`, *optional*, default to `2`):
            Integer upsampling factor.
        gain (`float`, *optional*, default to `1.0`):
            Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output (`torch.FloatTensor`):
            Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output


def downsample_2d(
    hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.FloatTensor:
    r"""Downsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states (`torch.FloatTensor`):
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel (`torch.FloatTensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
        factor (`int`, *optional*, default to `2`):
            Integer downsampling factor.
        gain (`float`, *optional*, default to `1.0`):
            Scaling factor for signal magnitude.

    Returns:
        output (`torch.FloatTensor`):
            Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
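
# Usage sketch (editor's note, illustrative only): the functional counterparts of the Fir* layers above.
#     x = torch.randn(1, 3, 32, 32)
#     upsample_2d(x, kernel=(1, 3, 3, 1), factor=2).shape    # -> (1, 3, 64, 64)
#     downsample_2d(x, kernel=(1, 3, 3, 1), factor=2).shape  # -> (1, 3, 16, 16)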


def upfirdn2d_native(
    tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
) -> torch.Tensor:
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    Parameters:
        in_dim (`int`): Number of input channels.
        out_dim (`int`): Number of output channels.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
    """

    def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, norm_num_groups: int = 32):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states