resnet.py 40.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

patil-suraj's avatar
patil-suraj committed
16
from functools import partial
17
from typing import Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
18

19
20
21
22
import torch
import torch.nn as nn
import torch.nn.functional as F

23
from ..utils import USE_PEFT_BACKEND
24
from .activations import get_activation
YiYi Xu's avatar
YiYi Xu committed
25
from .attention_processor import SpatialNorm
26
from .lora import LoRACompatibleConv, LoRACompatibleLinear
27
from .normalization import AdaGroupNorm
28

29

30
class Upsample1D(nn.Module):
    """Double the temporal length of a 1D signal, with an optional learned convolution.

    When `use_conv_transpose` is set, a strided transposed convolution does the upsampling on its own.
    Otherwise the input is repeated with nearest-neighbour interpolation (factor 2). If `use_conv` is
    set, a length-preserving 3-tap convolution is then applied.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose (takes precedence over `use_conv`).
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 1D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            # kernel 4 / stride 2 / padding 1 maps a length L to exactly 2L.
            layer = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            # kernel 3 / padding 1 keeps the (already doubled) length unchanged.
            layer = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            layer = None
        self.conv = layer

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Upsample `inputs` (channels on dim 1) by a factor of 2 along the last dimension."""
        assert inputs.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(inputs)

        doubled = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
        return self.conv(doubled) if self.use_conv else doubled
78
79
80


class Downsample1D(nn.Module):
    """Halve the temporal length of a 1D signal.

    With `use_conv` the reduction is a learned stride-2 convolution (kernel 3). Without it, a
    stride-2 average pool is used, which requires `out_channels == channels`.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 1D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name
        factor = 2

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=factor, stride=factor)
        else:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=factor, padding=padding)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Downsample `inputs` (channels on dim 1) by a factor of 2 along the last dimension."""
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
121
122


123
class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose (takes precedence over `use_conv`).
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer. The convolution is stored as `self.conv` when `name == "conv"`
            and as `self.Conv2d_0` otherwise (legacy checkpoint naming).
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        conv = None
        if use_conv_transpose:
            # kernel 4 / stride 2 / padding 1 exactly doubles H and W.
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = conv_cls(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_size: Optional[Union[int, Tuple[int, int]]] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        """Upsample `hidden_states` (channels on dim 1).

        Args:
            hidden_states: input tensor whose dim 1 must equal `channels`.
            output_size: if given, the exact spatial size passed to `F.interpolate` instead of `scale_factor=2`.
            scale: LoRA scale, forwarded only to a `LoRACompatibleConv` when the PEFT backend is not in use.
        """
        assert hidden_states.shape[1] == self.channels

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # `__init__` stores the conv under a `name`-dependent attribute; resolve it once so every path
        # (including the transposed-conv path) finds it.
        conv = self.conv if self.name == "conv" else self.Conv2d_0

        if self.use_conv_transpose:
            return conv(hidden_states)

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        if self.use_conv:
            # Only LoRA-compatible convs (non-PEFT backend) take the extra `scale` argument.
            if isinstance(conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                hidden_states = conv(hidden_states, scale)
            else:
                hidden_states = conv(hidden_states)

        return hidden_states
209
210


211
class Downsample2D(nn.Module):
    """Halve the spatial resolution of a 2D feature map.

    With `use_conv` the reduction is a learned stride-2 3x3 convolution. Otherwise a stride-2 2x2
    average pool is used, which requires `out_channels == channels`.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution. With `0`, the input is zero-padded by one row at the bottom and
            one column at the right before the convolution.
        name (`str`, default `conv`):
            name of the downsampling 2D layer. With `"conv"`, the same module is registered under both
            `Conv2d_0` and `conv` (legacy checkpoint keys); otherwise only under `conv`.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        factor = 2
        self.name = name
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            layer = nn.AvgPool2d(kernel_size=factor, stride=factor)
        else:
            layer = conv_cls(self.channels, self.out_channels, 3, stride=factor, padding=padding)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # Registration order matters for state_dict key order: `Conv2d_0` (when present) comes first.
        if name == "conv":
            self.Conv2d_0 = layer
        self.conv = layer

    def forward(self, hidden_states, scale: float = 1.0):
        """Downsample `hidden_states` (channels on dim 1); `scale` is the LoRA scale for LoRA-compatible convs."""
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            # Asymmetric zero pad: (left, right, top, bottom) = (0, 1, 0, 1).
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)

        # Channel check repeated after the optional pad, as in the original implementation.
        assert hidden_states.shape[1] == self.channels

        # Only LoRA-compatible convs (non-PEFT backend) accept the extra `scale` argument.
        if not USE_PEFT_BACKEND and isinstance(self.conv, LoRACompatibleConv):
            return self.conv(hidden_states, scale)
        return self.conv(hidden_states)
277
278
279


class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`, optional):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(
        self,
        hidden_states: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        kernel: Optional[torch.FloatTensor] = None,
        factor: int = 2,
        gain: float = 1,
    ) -> torch.Tensor:
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            hidden_states: Input tensor of shape `[N, C, H, W]` (the conv path indexes dims 2 and 3 as H and W).
            weight: Convolution weight in PyTorch layout `[outChannels, inChannels, kH, kW]`; required when
                `use_conv` is set. NOTE(review): `num_groups = C // inChannels` is used to reshape the weight,
                but `F.conv_transpose2d` is called without `groups`, so `C == inChannels` is effectively assumed.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]`, with the same datatype as
            `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel: separable 1D taps become a 2D outer product, normalized to sum to 1.
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        # Presumably compensates for the factor**2 growth in sample count during upsampling.
        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights: flip spatially and swap in/out channels so the transposed conv matches Conv2d.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            # `kernel` is already a tensor: move it instead of re-wrapping it with `torch.tensor`,
            # which copies and emits PyTorch's copy-construct UserWarning.
            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Upsample `hidden_states` by 2 with the FIR kernel; with `use_conv`, fuse in `Conv2d_0` and add its bias."""
        if self.use_conv:
            upsampled = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            return upsampled + self.Conv2d_0.bias.reshape(1, -1, 1, 1)

        return self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
403
404
405


class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`, optional):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(
        self,
        hidden_states: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        kernel: Optional[torch.FloatTensor] = None,
        factor: int = 2,
        gain: float = 1,
    ) -> torch.Tensor:
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            hidden_states: Input tensor of shape `[N, C, H, W]` (it is passed directly to `F.conv2d` on the conv
                path).
            weight:
                Convolution weight in PyTorch layout `[outChannels, inChannels, kH, kW]`; required when
                `use_conv` is set.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
            factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]`, with the same datatype as
            `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel: separable 1D taps become a 2D outer product, normalized to sum to 1.
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            # Pad once for both the FIR filter and the conv footprint, filter, then stride the conv.
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            # `kernel` is already a tensor: move it instead of re-wrapping it with `torch.tensor`,
            # which copies and emits PyTorch's copy-construct UserWarning.
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Downsample `hidden_states` by 2 with the FIR kernel; with `use_conv`, fuse in `Conv2d_0` and add its bias."""
        if self.use_conv:
            downsampled = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsampled + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
503
504


505
506
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    r"""A 2D K-downsampling layer.

    Pads the input, then applies the fixed separable filter `[1, 3, 3, 1] / 8` (outer product) to every
    channel independently with stride 2.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        num_channels = inputs.shape[1]
        # Depthwise weight of shape (C, 1, kH, kW) with groups=C: each channel is filtered only by itself.
        # This equals a dense (C, C, kH, kW) weight that is zero off the diagonal, but uses O(C) memory
        # instead of O(C^2).
        weight = self.kernel.to(inputs)[None, None].repeat(num_channels, 1, 1, 1)
        return F.conv2d(inputs, weight, stride=2, groups=num_channels)
527
528
529


class KUpsample2D(nn.Module):
    r"""A 2D K-upsampling layer.

    Pads the input, then applies a stride-2 transposed convolution with the fixed separable filter
    `2 * [1, 3, 3, 1] / 8` (outer product) to every channel independently.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        num_channels = inputs.shape[1]
        # Depthwise weight of shape (C, 1, kH, kW) with groups=C: each channel is filtered only by itself.
        # This equals a dense (C, C, kH, kW) weight that is zero off the diagonal, but uses O(C) memory
        # instead of O(C^2).
        weight = self.kernel.to(inputs)[None, None].repeat(num_channels, 1, 1, 1)
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1, groups=num_channels)
550
551


552
class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, default to `False`):
            Stored as `use_conv_shortcut`; it does not affect which layers this block builds (see `use_in_shortcut`).
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
            If `None`, no timestep-embedding projection is created.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        pre_norm (`bool`, *optional*, default to `True`):
            Accepted for compatibility but ignored: normalization is always applied before each convolution.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        skip_time_act (`bool`, *optional*, default to `False`):
            If `True`, the timestep embedding is projected without first applying the activation.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            `"default"` adds the projected embedding after `conv1`. `"scale_shift"` projects to `2 * out_channels`
            and applies `h * (1 + scale) + shift` after `norm2`. `"ada_group"` / `"spatial"` pass `temb` to the
            normalization layers instead of projecting it.
        kernel (optional, default to None): Resampling kind used when `up`/`down` is set. `"fir"` uses the
            `upsample_2d`/`downsample_2d` helpers with the FIR kernel `(1, 3, 3, 1)` (see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]). `"sde_vp"` uses nearest
            2x interpolation / 2x2 average pooling. Any other value uses `Upsample2D` / `Downsample2D` without
            convolution.
        output_scale_factor (`float`, *optional*, default to be `1.0`): the output is divided by this factor.
        use_in_shortcut (`bool`, *optional*, default to `None`):
            If `True`, add a 1x1 conv2d layer for the skip-connection. If `None`, one is added when `in_channels`
            differs from `conv_2d_out_channels`.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        kernel: Optional[torch.FloatTensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        # `pre_norm` is intentionally ignored: this block always normalizes before each conv.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = linear_cls(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                # Twice the channels: the projection is split into (scale, shift) in `forward`.
                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = conv_cls(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb, scale: float = 1.0):
        """Apply the residual block.

        Args:
            input_tensor: feature map; it is also the skip-connection input.
            temb: timestep embedding (may be `None` when no projection is configured).
            scale: LoRA scale, forwarded to LoRA-compatible layers when the PEFT backend is not in use.
        """
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = (
                self.upsample(input_tensor, scale=scale)
                if isinstance(self.upsample, Upsample2D)
                else self.upsample(input_tensor)
            )
            hidden_states = (
                self.upsample(hidden_states, scale=scale)
                if isinstance(self.upsample, Upsample2D)
                else self.upsample(hidden_states)
            )
        elif self.downsample is not None:
            input_tensor = (
                self.downsample(input_tensor, scale=scale)
                if isinstance(self.downsample, Downsample2D)
                else self.downsample(input_tensor)
            )
            hidden_states = (
                self.downsample(hidden_states, scale=scale)
                if isinstance(self.downsample, Downsample2D)
                else self.downsample(hidden_states)
            )

        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = (
                self.time_emb_proj(temb, scale)[:, :, None, None]
                if not USE_PEFT_BACKEND
                else self.time_emb_proj(temb)[:, :, None, None]
            )

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            # Distinct names on purpose: `scale` is the LoRA scale argument and is still passed to `conv2`
            # and `conv_shortcut` below; rebinding it to a tensor would corrupt those calls.
            time_scale, time_shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + time_scale) + time_shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = (
                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
            )

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
758

Patrick von Platen's avatar
Patrick von Platen committed
759

760
# unet_rl.py
761
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
    """Insert or remove a singleton axis at position 2.

    - 2-D `(d0, d1)` -> `(d0, d1, 1)`
    - 3-D `(d0, d1, d2)` -> `(d0, d1, 1, d2)`
    - 4-D `(d0, d1, d2, d3)` -> `(d0, d1, d3)` by taking index 0 of axis 2 (other entries of axis 2, if any, are
      dropped).

    Raises:
        ValueError: if `tensor` does not have 2, 3 or 4 dimensions.
    """
    # Use `ndim` rather than `len(tensor)`: `len` returns the size of dim 0 and raises TypeError on 0-d tensors.
    ndim = tensor.ndim
    if ndim == 2:
        return tensor[:, :, None]
    if ndim == 3:
        return tensor[:, :, None, :]
    if ndim == 4:
        return tensor[:, :, 0, :]
    raise ValueError(f"`len(tensor.shape)`: {ndim} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> activation (Mish by default)

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        kernel_size (`int` or `tuple`): Size of the convolving kernel. A padding of `kernel_size // 2` (element-wise
            for a tuple) is applied, so odd kernel sizes preserve the sequence length.
        n_groups (`int`, default `8`): Number of groups to separate the channels into.
        activation (`str`, defaults `mish`): Name of the activation function.
    """

    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        n_groups: int = 8,
        activation: str = "mish",
    ):
        super().__init__()

        # `//` is not defined for tuples, so halve each element when a tuple kernel size is given.
        if isinstance(kernel_size, tuple):
            padding = tuple(k // 2 for k in kernel_size)
        else:
            padding = kernel_size // 2

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=padding)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = get_activation(activation)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply convolution, group normalization and the activation.

        NOTE(review): presumably `inputs` is `[batch, inp_channels, length]` (as in ResidualTemporalBlock1D's
        docstring) — confirm for other callers.
        """
        intermediate_repr = self.conv1d(inputs)
        # Temporarily add a singleton axis at position 2 around the GroupNorm, then remove it again.
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output
805
806
807
808


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """
    Residual 1D block with temporal convolutions.

    Two `Conv1dBlock`s process the input; a projected time embedding is added between them, and the input (projected
    with a 1x1 convolution when the channel count changes) is added back at the end.

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        embed_dim (`int`): Embedding dimension.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
        activation (`str`, defaults `mish`): Activation applied to the time embedding before its linear
            projection. The two `Conv1dBlock`s are built with their own default activation.
    """

    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        embed_dim: int,
        kernel_size: Union[int, Tuple[int, int]] = 5,
        activation: str = "mish",
    ):
        super().__init__()
        # Registration order is kept stable so parameter / state_dict ordering does not change.
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = get_activation(activation)
        self.time_emb = nn.Linear(embed_dim, out_channels)

        # Only a channel change requires a learned projection on the skip path.
        if inp_channels != out_channels:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs: `[batch_size, inp_channels, horizon]`
            t: `[batch_size, embed_dim]`

        Returns:
            `[batch_size, out_channels, horizon]`
        """
        time_features = self.time_emb(self.time_emb_act(t))
        # `[batch, out_channels]` -> `[batch, out_channels, 1]` so it broadcasts along the horizon axis.
        hidden = self.conv_in(inputs) + rearrange_dims(time_features)
        hidden = self.conv_out(hidden)
        skip = self.residual_conv(inputs)
        return hidden + skip
853
854


855
856
857
def upsample_2d(
    hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.Tensor:
    r"""Upsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` and upsamples each image with the given filter. The
    filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero; the image is zero-padded based on the kernel size and `factor`
    so that the output is exactly `factor` times larger along each spatial axis.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]`. Channels-last input is not supported:
          `upfirdn2d_native` unpacks the shape as `N, C, H, W`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. May be a
          sequence or a tensor; it is copied, so the caller's object is never modified.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    if isinstance(kernel, torch.Tensor):
        # Same result as `torch.tensor(kernel, dtype=torch.float32)` (a detached copy on the kernel's device), but
        # without the "To copy construct from a tensor" UserWarning that `torch.tensor` emits for tensor inputs.
        kernel = kernel.detach().clone().to(dtype=torch.float32)
    else:
        kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        # A 1-D kernel is treated as separable: the 2-D filter is its outer product with itself.
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    # Zero insertion leaves only 1 in factor**2 samples non-zero, so scale the kernel up to preserve `gain`.
    kernel = kernel * (gain * (factor**2))
    # NOTE(review): the padding is derived from kernel.shape[0] only and applied to both axes — assumes a square kernel.
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output
Patrick von Platen's avatar
Patrick von Platen committed
892
893


894
895
896
def downsample_2d(
    hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.Tensor:
    r"""Downsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` and downsamples each image with the given filter. The
    filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero; the image is zero-padded based on the kernel size and `factor`
    before filtering and decimation.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]`. Channels-last input is not supported:
          `upfirdn2d_native` unpacks the shape as `N, C, H, W`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling. May be a sequence or a
          tensor; it is copied, so the caller's object is never modified.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    if isinstance(kernel, torch.Tensor):
        # Same result as `torch.tensor(kernel, dtype=torch.float32)` (a detached copy on the kernel's device), but
        # without the "To copy construct from a tensor" UserWarning that `torch.tensor` emits for tensor inputs.
        kernel = kernel.detach().clone().to(dtype=torch.float32)
    else:
        kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        # A 1-D kernel is treated as separable: the 2-D filter is its outer product with itself.
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    # NOTE(review): the padding is derived from kernel.shape[0] only and applied to both axes — assumes a square kernel.
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
929
930


931
932
933
def upfirdn2d_native(
    tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
) -> torch.Tensor:
    """Upsample, FIR-filter, then downsample a batch of 2D feature maps in pure PyTorch.

    The same `up` / `down` factors and the same `pad` pair are applied to both spatial axes.

    Args:
        tensor: Input of shape `[N, C, H, W]`.
        kernel: 2-D filter of shape `[kernel_h, kernel_w]`. It is flipped before `F.conv2d` (a cross-correlation),
            so it is applied as a true convolution.
        up: Upsampling factor; `up - 1` zeros are inserted after every pixel along each spatial axis.
        down: Decimation factor applied after filtering.
        pad: `(before, after)` padding added to each spatial axis after upsampling. Negative values crop instead.

    Returns:
        Tensor of shape `[N, C, out_h, out_w]` with
        `out_h = (H * up + pad[0] + pad[1] - kernel_h) // down + 1` (and likewise for `out_w`).
    """
    pad_lo, pad_hi = pad[0], pad[1]
    _, num_channels, height, width = tensor.shape
    kernel_h, kernel_w = kernel.shape

    # Process every (sample, channel) pair as an independent single-channel image: [N*C, H, W, 1].
    x = tensor.reshape(-1, height, width, 1)

    # Zero-insertion upsampling along both spatial axes.
    x = x.view(-1, height, 1, width, 1, 1)
    x = F.pad(x, [0, 0, 0, up - 1, 0, 0, 0, up - 1])
    up_h, up_w = height * up, width * up
    x = x.view(-1, up_h, up_w, 1)

    # Positive padding is added with F.pad; negative padding is applied as a crop.
    grow_lo, grow_hi = max(pad_lo, 0), max(pad_hi, 0)
    x = F.pad(x, [0, 0, grow_lo, grow_hi, grow_lo, grow_hi])
    crop_lo, crop_hi = max(-pad_lo, 0), max(-pad_hi, 0)
    x = x[:, crop_lo : x.shape[1] - crop_hi, crop_lo : x.shape[2] - crop_hi, :]

    padded_h = up_h + pad_lo + pad_hi
    padded_w = up_w + pad_lo + pad_hi

    # Filter. Flipping the kernel turns F.conv2d's cross-correlation into a convolution.
    x = x.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])
    flipped_kernel = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    x = F.conv2d(x, flipped_kernel)

    filtered_h = padded_h - kernel_h + 1
    filtered_w = padded_w - kernel_w + 1
    x = x.reshape(-1, 1, filtered_h, filtered_w).permute(0, 2, 3, 1)

    # Decimate.
    x = x[:, ::down, ::down, :]

    out_h = (padded_h - kernel_h) // down + 1
    out_w = (padded_w - kernel_w) // down + 1
    return x.view(-1, num_channels, out_h, out_w)
975
976
977
978
979
980


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    Parameters:
        in_dim (`int`): Number of input channels.
        out_dim (`int`, *optional*): Number of channels used inside the block. Defaults to `in_dim`. The block
            always projects back to `in_dim` channels so the residual connection can be applied.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for each `nn.GroupNorm`.
    """

    def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, norm_num_groups: int = 32):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers: kernel (3, 1, 1) with padding (1, 0, 0) mixes only neighbouring frames and keeps all sizes.
        # Channel flow is in_dim -> out_dim -> out_dim -> out_dim -> in_dim, so every GroupNorm sees `out_dim`
        # channels and the final output matches the residual. When in_dim == out_dim, all layer shapes are
        # identical to a plain in_dim -> in_dim stack.
        self.conv1 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
        """
        Args:
            hidden_states: Tensor of shape `[batch * num_frames, in_dim, height, width]`, with the frames of each
                sample stored consecutively along the first axis.
            num_frames: Number of frames per sample.

        Returns:
            Tensor with the same shape as `hidden_states`.
        """
        # [B*F, C, H, W] -> [B, F, C, H, W] -> [B, C, F, H, W] so Conv3d's depth axis is the frame axis.
        hidden_states = hidden_states.reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        # [B, C, F, H, W] -> [B, F, C, H, W] -> [B*F, C, H, W]. The target shape is read from the pre-permute
        # tensor, so shape[0] * shape[2] is B * F.
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states