"docs/vscode:/vscode.git/clone" did not exist on "a0127e1712cd58b22736d20ed4e5531e1d277f5e"
resnet.py 34.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)
        return outputs
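# Usage sketch (editor's note, not part of the original module): with `use_conv=True`,
# nearest-neighbor interpolation doubles the length and the Conv1d preserves it:
#     >>> layer = Upsample1D(channels=8, use_conv=True)
#     >>> layer(torch.randn(2, 8, 16)).shape
#     torch.Size([2, 8, 32])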


class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)
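# Usage sketch (editor's note): the stride-2 Conv1d (or AvgPool1d) halves the length:
#     >>> layer = Downsample1D(channels=8, use_conv=True)
#     >>> layer(torch.randn(2, 8, 16)).shape
#     torch.Size([2, 8, 8])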


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the `upsample_nearest2d_out_frame` op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)
        return hidden_states
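# Usage sketch (editor's note): the spatial size doubles unless `output_size` overrides it:
#     >>> layer = Upsample2D(channels=4, use_conv=True)
#     >>> layer(torch.randn(1, 4, 8, 8)).shape
#     torch.Size([1, 4, 16, 16])
#     >>> layer(torch.randn(1, 4, 8, 8), output_size=(24, 24)).shape
#     torch.Size([1, 4, 24, 24])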


class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)
        return hidden_states
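# Usage sketch (editor's note): the stride-2 convolution (or AvgPool2d) halves H and W:
#     >>> layer = Downsample2D(channels=4, use_conv=True)
#     >>> layer(torch.randn(1, 4, 16, 16)).shape
#     torch.Size([1, 4, 8, 8])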


class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                torch.tensor(kernel, device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
        return height
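# Usage sketch (editor's note): FIR filtering with the default (1, 3, 3, 1) kernel still
# yields a 2x spatial upsample, fused with the 3x3 convolution when `use_conv=True`:
#     >>> layer = FirUpsample2D(channels=4, out_channels=4, use_conv=True)
#     >>> layer(torch.randn(1, 4, 8, 8)).shape
#     torch.Size([1, 4, 16, 16])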


class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
            factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
        return hidden_states
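# Usage sketch (editor's note): the FIR filter is applied before the stride-2 convolution,
# so the output has half the spatial resolution:
#     >>> layer = FirDownsample2D(channels=4, out_channels=4, use_conv=True)
#     >>> layer(torch.randn(1, 4, 8, 8)).shape
#     torch.Size([1, 4, 4, 4])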


# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, x):
        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(x, weight, stride=2)


class KUpsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, x):
        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
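# Usage sketch (editor's note): both blocks build a weight that applies the fixed
# [1, 3, 3, 1]/8 smoothing kernel to each channel independently, so the channel count is
# preserved while H and W are halved or doubled:
#     >>> KDownsample2D()(torch.randn(1, 4, 8, 8)).shape
#     torch.Size([1, 4, 4, 4])
#     >>> KUpsample2D()(torch.randn(1, 4, 8, 8)).shape
#     torch.Size([1, 4, 16, 16])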


class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        skip_time_act=False,
        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        elif non_linearity == "gelu":
            self.nonlinearity = nn.GELU()

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
        return output_tensor
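# Usage sketch (editor's note): with the default `time_embedding_norm="default"`, the block
# projects `temb` to `out_channels` and adds it after `conv1`; the 1x1 shortcut is created
# automatically whenever the channel count changes:
#     >>> block = ResnetBlock2D(in_channels=32, out_channels=64, temb_channels=128)
#     >>> block(torch.randn(2, 32, 16, 16), torch.randn(2, 128)).shape
#     torch.Size([2, 64, 16, 16])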


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


# unet_rl.py
def rearrange_dims(tensor):
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        x = self.conv1d(x)
        x = rearrange_dims(x)
        x = self.group_norm(x)
        x = rearrange_dims(x)
        x = self.mish(x)
        return x


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(x) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(x)
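# Usage sketch (editor's note): the time embedding is projected to `out_channels` and
# broadcast over the horizon; a 1x1 residual conv matches channels when they differ:
#     >>> block = ResidualTemporalBlock1D(inp_channels=8, out_channels=16, embed_dim=32)
#     >>> block(torch.randn(2, 8, 24), torch.randn(2, 32)).shape
#     torch.Size([2, 16, 24])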


def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output
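# Usage sketch (editor's note): with the separable (1, 3, 3, 1) FIR kernel used elsewhere in
# this file, the output is exactly `factor` times larger in each spatial dimension:
#     >>> upsample_2d(torch.randn(1, 3, 8, 8), kernel=(1, 3, 3, 1)).shape
#     torch.Size([1, 3, 16, 16])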


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
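# Usage sketch (editor's note): the counterpart of `upsample_2d`, shrinking each spatial
# dimension by `factor`:
#     >>> downsample_2d(torch.randn(1, 3, 8, 8), kernel=(1, 3, 3, 1)).shape
#     torch.Size([1, 3, 4, 4])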


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
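# Editor's note: `upfirdn2d_native` implements the classic upsample -> FIR filter ->
# downsample primitive (cf. `scipy.signal.upfirdn`). Per the code above, the output size is
#     out_h = (in_h * up + pad[0] + pad[1] - kernel_h) // down + 1
# which is how `upsample_2d` and `downsample_2d` pick their `pad` values so the result lands
# on an exact multiple (or divisor) of `factor`.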


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states, num_frames=1):
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states
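

# Usage sketch (editor's note): the layer expects frames flattened into the batch dimension
# and returns the same layout; with `num_frames=4`, a (batch*frames, C, H, W) input passes
# through the temporal 3D convolutions and keeps its shape:
#     >>> layer = TemporalConvLayer(in_dim=32)
#     >>> layer(torch.randn(8, 32, 8, 8), num_frames=4).shape
#     torch.Size([8, 32, 8, 8])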