# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs

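# Illustrative usage sketch (toy shapes, an assumption for the example): `Upsample1D`
# doubles the length dimension of a `(batch, channels, length)` tensor.
#
#     up = Upsample1D(channels=32, use_conv=True)
#     up(torch.randn(1, 32, 64)).shape  # torch.Size([1, 32, 128])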

class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)

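# Illustrative usage sketch: `Downsample1D` halves the length dimension, either with a
# stride-2 convolution (`use_conv=True`) or with average pooling.
#
#     down = Downsample1D(channels=32, use_conv=True)
#     down(torch.randn(1, 32, 64)).shape  # torch.Size([1, 32, 32])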

class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states

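# Illustrative usage sketch: `Upsample2D` doubles the spatial size unless an explicit
# `output_size` is requested, in which case the interpolation is forced to that size.
#
#     up = Upsample2D(channels=64, use_conv=True)
#     up(torch.randn(1, 64, 32, 32)).shape                        # torch.Size([1, 64, 64, 64])
#     up(torch.randn(1, 64, 32, 32), output_size=(57, 57)).shape  # torch.Size([1, 64, 57, 57])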

class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states

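# Illustrative usage sketch: `Downsample2D` halves the spatial size with a stride-2
# convolution (or average pooling when `use_conv=False`).
#
#     down = Downsample2D(channels=64, use_conv=True)
#     down(torch.randn(1, 64, 32, 32)).shape  # torch.Size([1, 64, 16, 16])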

class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                torch.tensor(kernel, device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height

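# Illustrative usage sketch: `FirUpsample2D` upsamples by a factor of 2 with an FIR
# smoothing filter, optionally fused with a 3x3 convolution.
#
#     fir_up = FirUpsample2D(channels=64, out_channels=64, use_conv=True)
#     fir_up(torch.randn(1, 64, 16, 16)).shape  # torch.Size([1, 64, 32, 32])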

class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
            factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states

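# Illustrative usage sketch: `FirDownsample2D` is the mirror image of `FirUpsample2D`,
# low-pass filtering with the FIR kernel and then downsampling by a factor of 2.
#
#     fir_down = FirDownsample2D(channels=64, out_channels=64, use_conv=True)
#     fir_down(torch.randn(1, 64, 32, 32)).shape  # torch.Size([1, 64, 16, 16])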

# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, x):
        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(x, weight, stride=2)


class KUpsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, x):
        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)

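# Illustrative usage sketch: the k-diffusion style resamplers use a fixed, non-learned
# [1, 3, 3, 1]/8 separable kernel and work with any channel count.
#
#     KDownsample2D()(torch.randn(1, 8, 32, 32)).shape  # torch.Size([1, 8, 16, 16])
#     KUpsample2D()(torch.randn(1, 8, 16, 16)).shape    # torch.Size([1, 8, 32, 32])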

class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, defaults to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, defaults to `None`):
            The number of groups to use for the second normalization layer. If set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, defaults to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, defaults to `"default"`): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, *optional*, defaults to `None`): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, defaults to `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, defaults to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, defaults to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, defaults to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, defaults to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, defaults to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        skip_time_act=False,
        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor

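# Illustrative usage sketch: a ResnetBlock2D keeps the spatial size and maps `in_channels`
# to `out_channels`, conditioned on a timestep embedding `temb`.
#
#     block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
#     block(torch.randn(2, 64, 32, 32), torch.randn(2, 512)).shape  # torch.Size([2, 128, 32, 32])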

# unet_rl.py
def rearrange_dims(tensor):
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        x = self.conv1d(x)
        x = rearrange_dims(x)
        x = self.group_norm(x)
        x = rearrange_dims(x)
        x = self.mish(x)
        return x


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(x) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(x)

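# Illustrative usage sketch: `ResidualTemporalBlock1D` maps a `(batch, inp_channels, horizon)`
# trajectory tensor to `(batch, out_channels, horizon)` conditioned on a time embedding.
#
#     block = ResidualTemporalBlock1D(inp_channels=14, out_channels=32, embed_dim=128)
#     block(torch.randn(4, 14, 16), torch.randn(4, 128)).shape  # torch.Size([4, 32, 16])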

def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output

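# Illustrative usage sketch: with the default box kernel, `upsample_2d` behaves like
# nearest-neighbor upsampling and `downsample_2d` like average pooling.
#
#     upsample_2d(torch.randn(1, 3, 16, 16)).shape    # torch.Size([1, 3, 32, 32])
#     downsample_2d(torch.randn(1, 3, 16, 16)).shape  # torch.Size([1, 3, 8, 8])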

def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)

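# Illustrative usage sketch for `upfirdn2d_native` (upsample, FIR filter, downsample in one
# pass): with `up=2` and a normalized 2x2 kernel it doubles the spatial size.
#
#     k = torch.ones(2, 2) / 4
#     upfirdn2d_native(torch.randn(1, 3, 8, 8), k, up=2, pad=(1, 0)).shape  # torch.Size([1, 3, 16, 16])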

class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states, num_frames=1):
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states
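

# Illustrative usage sketch: `TemporalConvLayer` treats the batch as `batch * num_frames`
# stacked frames, mixes information across the frame axis, and returns the same shape.
#
#     layer = TemporalConvLayer(in_dim=32)
#     layer(torch.randn(4, 32, 8, 8), num_frames=2).shape  # torch.Size([4, 32, 8, 8])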