# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
from .lora import LoRACompatibleConv, LoRACompatibleLinear


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs


class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
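
# Usage sketch (illustrative only, not part of the library): Downsample1D halves the length of a
# (batch, channels, length) tensor and Upsample1D doubles it; the shapes below are assumed example values.
#
#     down = Downsample1D(channels=32, use_conv=True)
#     up = Upsample1D(channels=32, use_conv=True)
#     x = torch.randn(4, 32, 16)
#     y = down(x)  # -> (4, 32, 8)
#     z = up(y)    # -> (4, 32, 16)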


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None, scale: float = 1.0):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                if isinstance(self.conv, LoRACompatibleConv):
                    hidden_states = self.conv(hidden_states, scale)
                else:
                    hidden_states = self.conv(hidden_states)
            else:
                if isinstance(self.Conv2d_0, LoRACompatibleConv):
                    hidden_states = self.Conv2d_0(hidden_states, scale)
                else:
                    hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
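
# Usage sketch (illustrative only, not part of the library): nearest-neighbor 2x upsampling, or upsampling
# to an explicit `output_size`, followed by the optional 3x3 convolution. Shapes are assumed example values.
#
#     upsample = Upsample2D(channels=64, use_conv=True)
#     x = torch.randn(2, 64, 16, 16)
#     y = upsample(x)                        # -> (2, 64, 32, 32)
#     z = upsample(x, output_size=(24, 24))  # -> (2, 64, 24, 24)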


class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states, scale: float = 1.0):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        if isinstance(self.conv, LoRACompatibleConv):
            hidden_states = self.conv(hidden_states, scale)
        else:
            hidden_states = self.conv(hidden_states)

        return hidden_states
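
# Usage sketch (illustrative only, not part of the library): with `use_conv=True` the downsampling is a
# strided 3x3 convolution; shapes are assumed example values.
#
#     downsample = Downsample2D(channels=64, use_conv=True)
#     x = torch.randn(2, 64, 32, 32)
#     y = downsample(x)  # -> (2, 64, 16, 16)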


class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using native PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                torch.tensor(kernel, device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height
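
# Usage sketch (illustrative only, not part of the library): FIR-filtered 2x upsampling with the default
# (1, 3, 3, 1) kernel and no extra convolution; shapes are assumed example values.
#
#     fir_up = FirUpsample2D(channels=32, use_conv=False)
#     x = torch.randn(2, 32, 16, 16)
#     y = fir_up(x)  # -> (2, 32, 32, 32)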


class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.
        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using native PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
            factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states


# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs):
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(inputs, weight, stride=2)


class KUpsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs):
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
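
# Usage sketch (illustrative only, not part of the library): the k-diffusion style blur-pool and its
# transposed counterpart keep the channel count and change the resolution by 2x. Shapes are assumed example values.
#
#     k_down = KDownsample2D()
#     k_up = KUpsample2D()
#     x = torch.randn(2, 8, 32, 32)
#     y = k_down(x)  # -> (2, 8, 16, 16)
#     z = k_up(y)    # -> (2, 8, 32, 32)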


class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        skip_time_act=False,
        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = LoRACompatibleConv(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb, scale: float = 1.0):
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = (
                self.upsample(input_tensor, scale=scale)
                if isinstance(self.upsample, Upsample2D)
                else self.upsample(input_tensor)
            )
            hidden_states = (
                self.upsample(hidden_states, scale=scale)
                if isinstance(self.upsample, Upsample2D)
                else self.upsample(hidden_states)
            )
        elif self.downsample is not None:
            input_tensor = (
                self.downsample(input_tensor, scale=scale)
                if isinstance(self.downsample, Downsample2D)
                else self.downsample(input_tensor)
            )
            hidden_states = (
                self.downsample(hidden_states, scale=scale)
                if isinstance(self.downsample, Downsample2D)
                else self.downsample(hidden_states)
            )

        hidden_states = self.conv1(hidden_states, scale)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb, scale)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, scale)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, scale)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
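
# Usage sketch (illustrative only, not part of the library): a residual block conditioned on a timestep
# embedding with the default shift mechanism; channel counts and shapes are assumed example values.
#
#     block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
#     x = torch.randn(2, 64, 32, 32)
#     temb = torch.randn(2, 512)
#     y = block(x, temb)  # -> (2, 128, 32, 32)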


# unet_rl.py
def rearrange_dims(tensor):
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, inputs):
        intermediate_repr = self.conv1d(inputs)
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, inputs, t):
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(inputs) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)
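
# Usage sketch (illustrative only, not part of the library): shapes for the RL-style temporal residual
# block; the sizes below are assumed example values.
#
#     block = ResidualTemporalBlock1D(inp_channels=16, out_channels=32, embed_dim=128)
#     x = torch.randn(4, 16, 24)  # [batch_size x inp_channels x horizon]
#     t = torch.randn(4, 128)     # [batch_size x embed_dim]
#     y = block(x, t)             # -> (4, 32, 24)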


def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
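
# Usage sketch (illustrative only, not part of the library): the functional FIR resamplers used by the
# "fir" up/down paths in ResnetBlock2D; shapes are assumed example values.
#
#     x = torch.randn(2, 16, 32, 32)
#     up = upsample_2d(x, kernel=(1, 3, 3, 1), factor=2)      # -> (2, 16, 64, 64)
#     down = downsample_2d(x, kernel=(1, 3, 3, 1), factor=2)  # -> (2, 16, 16, 16)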


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
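
# Shape check (illustrative only): for upsample_2d with factor=2 and a 4x4 outer-product kernel,
# upfirdn2d_native is called with up=2 and pad=(2, 1), so on a 32x32 input
# out_h = (32 * 2 + 2 + 1 - 4) // 1 + 1 = 64, matching the documented H * factor output size.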


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states, num_frames=1):
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states
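
# Usage sketch (illustrative only, not part of the library): video features arrive flattened as
# (batch * num_frames, channels, height, width) and are returned in the same layout; the sizes below
# are assumed example values.
#
#     layer = TemporalConvLayer(in_dim=64)
#     frames = torch.randn(2 * 8, 64, 16, 16)  # 2 clips of 8 frames each
#     out = layer(frames, num_frames=8)        # -> (16, 64, 16, 16)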