# Copyright 2025 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2025 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils import deprecate
from .activations import get_activation
from .attention_processor import SpatialNorm
from .downsampling import (  # noqa
    Downsample1D,
    Downsample2D,
    FirDownsample2D,
    KDownsample2D,
    downsample_2d,
)
from .normalization import AdaGroupNorm
from .upsampling import (  # noqa
    FirUpsample2D,
    KUpsample2D,
    Upsample1D,
    Upsample2D,
    upfirdn2d_native,
    upsample_2d,
)

# DCU OPT: fused custom ops. `FuseGroupNorm` is assumed to come from a DCU-patched PyTorch build;
# the remaining ops come from the local `custom_op` module.
from torch.nn import FuseGroupNorm as GroupNorm
from ..custom_op import ConvBiasAdd
from ..custom_op import ConvBias
from ..custom_op import miopenGroupNorm
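
# Assumed semantics of the DCU custom ops used below, inferred from the commented-out
# reference code they replace (illustrative, not authoritative):
#     miopenGroupNorm(..., mode=1)   ~ nn.GroupNorm(...)                     # plain group norm
#     miopenGroupNorm(..., mode=10)  ~ nn.GroupNorm(...) followed by SiLU    # fused "gn_silu"
#     ConvBias(...)                  ~ nn.Conv2d(...)                        # conv with fused bias
#     ConvBiasAdd(...)(x, y)         ~ nn.Conv2d(...)(x) + y                 # conv with fused bias + residual add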

class ResnetBlockCondNorm2D(nn.Module):
    r"""
    A Resnet block that uses a normalization layer that incorporates conditioning information.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"ada_group"` ):
            The normalization layer for time embedding `temb`. Currently only supports "ada_group" or "spatial".
        kernel (`torch.Tensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "ada_group",  # ada_group, spatial
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":  # ada_group
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if self.time_embedding_norm == "ada_group":  # ada_group
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":  # spatial
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")

        self.dropout = torch.nn.Dropout(dropout)

        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv2d(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states, temb)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)

        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states, temb)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
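
# Minimal usage sketch for ResnetBlockCondNorm2D (illustrative shapes; default "ada_group" conditioning):
#
#     block = ResnetBlockCondNorm2D(in_channels=64, temb_channels=512)
#     x = torch.randn(2, 64, 32, 32)
#     temb = torch.randn(2, 512)
#     out = block(x, temb)  # -> (2, 64, 32, 32)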


class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
            stronger conditioning with scale and shift.
        kernel (`torch.Tensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",  # default, scale_shift,
        kernel: Optional[torch.Tensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        if time_embedding_norm == "ada_group":
            raise ValueError(
                "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead",
            )
        if time_embedding_norm == "spatial":
            raise ValueError(
                "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead",
            )

        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        # self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        # self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # DCU OPT: gn_silu (mode=10 is assumed to fuse GroupNorm + SiLU, so no separate activation follows norm1)
        self.norm1 = miopenGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, mode=10)
        # DCU OPT: conv_bias (conv2d with fused bias)
        self.conv1 = ConvBias(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        # self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        # DCU OPT: gn_silu
        if self.time_embedding_norm in ("default", "scale_shift"):
            # mode=1 (assumed plain GroupNorm): the activation is applied explicitly in forward()
            self.norm2 = miopenGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, mode=1)
        else:
            # mode=10 (assumed fused GroupNorm + SiLU)
            self.norm2 = miopenGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, mode=10)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        # self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
        # DCU OPT: conv_bias_add
        self.conv2 = ConvBiasAdd(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            # self.conv_shortcut = nn.Conv2d(
            #     in_channels,
            #     conv_2d_out_channels,
            #     kernel_size=1,
            #     stride=1,
            #     padding=0,
            #     bias=conv_shortcut_bias,
            # )
            # DCU OPT: conv_bias_add
            self.conv_shortcut = ConvBiasAdd(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        hidden_states = input_tensor

        # hidden_states = self.norm1(hidden_states)
        # hidden_states = self.nonlinearity(hidden_states)
        # DCU OPT: gn_silu -- norm1 is the fused GroupNorm + SiLU, so no separate activation here
        hidden_states = self.norm1(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            # temb = self.time_emb_proj(temb)[:, :, None, None]
            # DCU OPT: TN->NN -- run the time projection as a plain (non-transposed) GEMM.
            # Assumes `time_emb_proj.weight` has been pre-transposed to (in_features, out_features) elsewhere.
            x = torch.matmul(temb, self.time_emb_proj.weight.data) + self.time_emb_proj.bias.data
            temb = x[:, :, None, None]

        if self.time_embedding_norm == "default":
            if temb is not None:
                hidden_states = hidden_states + temb
            # DCU OPT: ori (norm2 is the unfused GroupNorm here; the activation is applied separately below)
            hidden_states = self.norm2(hidden_states)
            hidden_states = self.nonlinearity(hidden_states)
        elif self.time_embedding_norm == "scale_shift":
            if temb is None:
                raise ValueError(
                    f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
                )
            time_scale, time_shift = torch.chunk(temb, 2, dim=1)
            # DCU OPT: ori
            hidden_states = self.norm2(hidden_states)
            hidden_states = hidden_states * (1 + time_scale) + time_shift
            hidden_states = self.nonlinearity(hidden_states)
        else:
            # DCU OPT: gn_silu
            hidden_states = self.norm2(hidden_states)

        # hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        
        # origin (reference implementation):
        # hidden_states = self.conv2(hidden_states)
        #
        # if self.conv_shortcut is not None:
        #     input_tensor = self.conv_shortcut(input_tensor.contiguous())
        #
        # output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        # DCU OPT: conv_bias_add -- fuse the residual addition into the final convolution.
        if self.conv_shortcut is not None:
            # conv2 still needs an "add" operand, so pass zeros; the real residual
            # addition is fused into the shortcut convolution below.
            tmp_add = torch.zeros_like(hidden_states, memory_format=torch.channels_last)
            hidden_states = self.conv2(hidden_states, tmp_add)
            input_tensor = self.conv_shortcut(input_tensor.contiguous(memory_format=torch.channels_last), hidden_states)
            output_tensor = input_tensor / self.output_scale_factor
        else:
            # no shortcut conv: fuse the addition of `input_tensor` directly into conv2
            hidden_states = self.conv2(hidden_states, input_tensor)
            output_tensor = hidden_states / self.output_scale_factor

        return output_tensor
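
# Minimal usage sketch for the DCU-optimized ResnetBlock2D (illustrative shapes; assumes the custom ops are
# available and that `time_emb_proj.weight` has been pre-transposed for the TN->NN GEMM used in forward()):
#
#     block = ResnetBlock2D(in_channels=64, temb_channels=512)
#     x = torch.randn(2, 64, 32, 32)
#     temb = torch.randn(2, 512)
#     out = block(x, temb)  # -> (2, 64, 32, 32)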


# unet_rl.py
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
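
# Shape behaviour of `rearrange_dims` (illustrative):
#     (batch, channels)             -> (batch, channels, 1)
#     (batch, channels, horizon)    -> (batch, channels, 1, horizon)
#     (batch, channels, 1, horizon) -> (batch, channels, horizon)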


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
        n_groups (`int`, default `8`): Number of groups to separate the channels into.
        activation (`str`, defaults to `mish`): Name of the activation function.
    """

    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        n_groups: int = 8,
        activation: str = "mish",
    ):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = get_activation(activation)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        intermediate_repr = self.conv1d(inputs)
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output
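
# Minimal usage sketch for Conv1dBlock (illustrative shapes):
#
#     block = Conv1dBlock(inp_channels=16, out_channels=32, kernel_size=5)
#     x = torch.randn(4, 16, 64)  # (batch, channels, horizon)
#     y = block(x)                # -> (4, 32, 64)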


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """
    Residual 1D block with temporal convolutions.

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        embed_dim (`int`): Embedding dimension.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
        activation (`str`, defaults to `mish`): Name of the activation function to use.
    """

    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        embed_dim: int,
        kernel_size: Union[int, Tuple[int, int]] = 5,
        activation: str = "mish",
    ):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = get_activation(activation)
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(inputs) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)
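
# Minimal usage sketch for ResidualTemporalBlock1D (illustrative shapes):
#
#     block = ResidualTemporalBlock1D(inp_channels=16, out_channels=32, embed_dim=128)
#     x = torch.randn(4, 16, 64)  # (batch, inp_channels, horizon)
#     t = torch.randn(4, 128)     # (batch, embed_dim)
#     y = block(x, t)             # -> (4, 32, 64)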


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    Parameters:
        in_dim (`int`): Number of input channels.
        out_dim (`int`): Number of output channels.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
    ):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, in_dim),
            nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states
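
# Minimal usage sketch for TemporalConvLayer (illustrative; input is flattened as (batch * num_frames, channels, height, width)):
#
#     layer = TemporalConvLayer(in_dim=32)
#     x = torch.randn(2 * 8, 32, 16, 16)  # 2 clips of 8 frames each
#     y = layer(x, num_frames=8)          # -> (16, 32, 16, 16)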


class TemporalResnetBlock(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        temb_channels: int = 512,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        kernel_size = (3, 1, 1)
        padding = [k // 2 for k in kernel_size]

        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

        if temb_channels is not None:
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(0.0)
        self.conv2 = nn.Conv3d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

        self.nonlinearity = get_activation("silu")

        self.use_in_shortcut = self.in_channels != out_channels

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, :, None, None]
            temb = temb.permute(0, 2, 1, 3, 4)
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor
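
# Minimal usage sketch for TemporalResnetBlock (illustrative shapes; input is (batch, channels, frames, height, width)):
#
#     block = TemporalResnetBlock(in_channels=32, temb_channels=512)
#     x = torch.randn(2, 32, 8, 16, 16)
#     temb = torch.randn(2, 8, 512)  # one embedding per frame
#     y = block(x, temb)             # -> (2, 32, 8, 16, 16)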


# VideoResBlock
class SpatioTemporalResBlock(nn.Module):
    r"""
    A SpatioTemporal Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resnet.
        temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
        merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
            The merge strategy to use for the temporal mixing.
        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
            If `True`, switch the spatial and temporal mixing.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        temb_channels: int = 512,
        eps: float = 1e-6,
        temporal_eps: Optional[float] = None,
        merge_factor: float = 0.5,
        merge_strategy="learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()

        self.spatial_res_block = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=eps,
        )

        self.temporal_res_block = TemporalResnetBlock(
            in_channels=out_channels if out_channels is not None else in_channels,
            out_channels=out_channels if out_channels is not None else in_channels,
            temb_channels=temb_channels,
            eps=temporal_eps if temporal_eps is not None else eps,
        )

        self.time_mixer = AlphaBlender(
            alpha=merge_factor,
            merge_strategy=merge_strategy,
            switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
    ):
        num_frames = image_only_indicator.shape[-1]
        hidden_states = self.spatial_res_block(hidden_states, temb)

        batch_frames, channels, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        hidden_states_mix = (
            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        )
        hidden_states = (
            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        )

        if temb is not None:
            temb = temb.reshape(batch_size, num_frames, -1)

        hidden_states = self.temporal_res_block(hidden_states, temb)
        hidden_states = self.time_mixer(
            x_spatial=hidden_states_mix,
            x_temporal=hidden_states,
            image_only_indicator=image_only_indicator,
        )

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
        return hidden_states
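
# Minimal usage sketch for SpatioTemporalResBlock (illustrative; `hidden_states` is flattened as
# (batch * frames, channels, height, width); the DCU-modified ResnetBlock2D above is assumed to be usable here):
#
#     block = SpatioTemporalResBlock(in_channels=32, temb_channels=512)
#     x = torch.randn(2 * 8, 32, 16, 16)
#     temb = torch.randn(2 * 8, 512)
#     image_only_indicator = torch.zeros(2, 8)  # 0 -> video frame, 1 -> still image
#     y = block(x, temb, image_only_indicator=image_only_indicator)  # -> (16, 32, 16, 16)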


class AlphaBlender(nn.Module):
    r"""
    A module to blend spatial and temporal features.

    Parameters:
        alpha (`float`): The initial value of the blending factor.
        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
            The merge strategy to use for the temporal mixing.
        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
            If `True`, switch the spatial and temporal mixing.
    """

    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE

        if merge_strategy not in self.strategies:
            raise ValueError(f"merge_strategy needs to be in {self.strategies}")

        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
        else:
            raise ValueError(f"Unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
        if self.merge_strategy == "fixed":
            alpha = self.mix_factor

        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor)

        elif self.merge_strategy == "learned_with_images":
            if image_only_indicator is None:
                raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")

            alpha = torch.where(
                image_only_indicator.bool(),
                torch.ones(1, 1, device=image_only_indicator.device),
                torch.sigmoid(self.mix_factor)[..., None],
            )

            # (batch, channel, frames, height, width)
            if ndims == 5:
                alpha = alpha[:, None, :, None, None]
            # (batch*frames, height*width, channels)
            elif ndims == 3:
                alpha = alpha.reshape(-1)[:, None, None]
            else:
                raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")

        else:
            raise NotImplementedError

        return alpha

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
        alpha = alpha.to(x_spatial.dtype)

        if self.switch_spatial_to_temporal_mix:
            alpha = 1.0 - alpha

        x = alpha * x_spatial + (1.0 - alpha) * x_temporal
        return x
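
# Minimal usage sketch for AlphaBlender (illustrative; 5D tensors are (batch, channels, frames, height, width)):
#
#     blender = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
#     x_spatial = torch.randn(2, 32, 8, 16, 16)
#     x_temporal = torch.randn(2, 32, 8, 16, 16)
#     image_only_indicator = torch.zeros(2, 8)
#     mixed = blender(x_spatial, x_temporal, image_only_indicator)  # alpha * x_spatial + (1 - alpha) * x_temporal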