# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
from .lora import LoRACompatibleConv, LoRACompatibleLinear


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs


class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)

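# Illustrative usage sketch (not part of the original module): the 1D layers above
# double / halve the length dimension. Shapes below are assumptions for the example.
#
#     >>> x = torch.randn(2, 32, 16)                  # (batch, channels, length)
#     >>> Upsample1D(channels=32, use_conv=True)(x).shape
#     torch.Size([2, 32, 32])
#     >>> Downsample1D(channels=32, use_conv=True)(x).shape
#     torch.Size([2, 32, 8])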

class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states


class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states

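# Illustrative usage sketch (not part of the original module): Upsample2D / Downsample2D
# double / halve the spatial resolution, and `output_size` forces an exact target size.
# Shapes below are assumptions for the example.
#
#     >>> x = torch.randn(1, 64, 16, 16)              # (batch, channels, height, width)
#     >>> Upsample2D(channels=64, use_conv=True)(x).shape
#     torch.Size([1, 64, 32, 32])
#     >>> Upsample2D(channels=64, use_conv=True)(x, output_size=(33, 33)).shape
#     torch.Size([1, 64, 33, 33])
#     >>> Downsample2D(channels=64, use_conv=True)(x).shape
#     torch.Size([1, 64, 8, 8])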

class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                torch.tensor(kernel, device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height


class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.
        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
            factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states

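# Illustrative usage sketch (not part of the original module): the FIR variants resample
# through the (1, 3, 3, 1) smoothing filter. Shapes below are assumptions for the example.
#
#     >>> x = torch.randn(1, 64, 16, 16)
#     >>> FirUpsample2D(channels=64, out_channels=64, use_conv=True)(x).shape
#     torch.Size([1, 64, 32, 32])
#     >>> FirDownsample2D(channels=64, out_channels=64, use_conv=True)(x).shape
#     torch.Size([1, 64, 8, 8])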

# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs):
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(inputs, weight, stride=2)


class KUpsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs):
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)

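# Illustrative usage sketch (not part of the original module): the K-diffusion resamplers
# above use a fixed blur kernel and take no channel arguments. Shapes below are assumptions
# for the example.
#
#     >>> x = torch.randn(1, 3, 8, 8)
#     >>> KDownsample2D()(x).shape
#     torch.Size([1, 3, 4, 4])
#     >>> KUpsample2D()(x).shape
#     torch.Size([1, 3, 16, 16])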

class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
Alexander Pivovarov's avatar
Alexander Pivovarov committed
479
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
480
481
482
483
484
485
486
487
488
489
490
491
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        skip_time_act=False,
        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = LoRACompatibleConv(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor

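# Illustrative usage sketch (not part of the original module): a ResnetBlock2D keeps the
# spatial size and is conditioned on a timestep embedding. Channel counts and shapes below
# are assumptions for the example.
#
#     >>> block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
#     >>> x = torch.randn(2, 64, 32, 32)              # (batch, channels, height, width)
#     >>> temb = torch.randn(2, 512)                  # timestep embedding
#     >>> block(x, temb).shape
#     torch.Size([2, 128, 32, 32])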

# unet_rl.py
def rearrange_dims(tensor):
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, inputs):
        intermediate_repr = self.conv1d(inputs)
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, inputs, t):
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(inputs) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)

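# Illustrative usage sketch (not part of the original module): the RL blocks above operate
# on (batch, channels, horizon) sequences. Sizes below are assumptions for the example.
#
#     >>> block = ResidualTemporalBlock1D(inp_channels=14, out_channels=32, embed_dim=128)
#     >>> x = torch.randn(4, 14, 16)                  # (batch, inp_channels, horizon)
#     >>> t = torch.randn(4, 128)                     # time embedding
#     >>> block(x, t).shape
#     torch.Size([4, 32, 16])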

def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsamples a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsamples a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output

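# Illustrative usage sketch (not part of the original module): upsample_2d / downsample_2d
# above mirror the FIR modules in functional form. Shapes and the kernel are assumptions
# for the example.
#
#     >>> x = torch.randn(1, 3, 8, 8)
#     >>> upsample_2d(x, kernel=(1, 3, 3, 1)).shape
#     torch.Size([1, 3, 16, 16])
#     >>> downsample_2d(x, kernel=(1, 3, 3, 1)).shape
#     torch.Size([1, 3, 4, 4])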

def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)

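# Illustrative usage sketch (not part of the original module): upfirdn2d_native pads,
# filters, and resamples in a single pass. The tensors below are assumptions for the
# example.
#
#     >>> x = torch.randn(1, 3, 8, 8)
#     >>> k = torch.ones(2, 2) / 4                    # simple 2x2 box filter
#     >>> upfirdn2d_native(x, k, up=2, down=1, pad=(0, 1)).shape
#     torch.Size([1, 3, 16, 16])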

class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states, num_frames=1):
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states
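

# Illustrative usage sketch (not part of the original module): TemporalConvLayer mixes
# information across frames and preserves the input shape. Sizes below are assumptions
# for the example.
#
#     >>> layer = TemporalConvLayer(in_dim=32)
#     >>> frames = torch.randn(8, 32, 16, 16)         # (batch * num_frames, channels, h, w)
#     >>> layer(frames, num_frames=4).shape
#     torch.Size([8, 32, 16, 16])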