attention.py 71.4 KB
Newer Older
Aryan's avatar
Aryan committed
1
# Copyright 2025 The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
16
17

import torch
18
import torch.nn as nn
Will Berman's avatar
Will Berman committed
19
import torch.nn.functional as F
20

21
from ..utils import deprecate, logging
22
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
Dhruv Nair's avatar
Dhruv Nair committed
23
from ..utils.torch_utils import maybe_allow_in_graph
Aryan's avatar
Aryan committed
24
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
25
from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
Dhruv Nair's avatar
Dhruv Nair committed
26
from .embeddings import SinusoidalPositionalEmbedding
YiYi Xu's avatar
YiYi Xu committed
27
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
28
29


30
31
32
33
34
35
if is_xformers_available():
    # `memory_efficient_attention` is exposed by the `xformers.ops` submodule, not by the top-level
    # `xformers` package, so alias the submodule itself.
    import xformers.ops as xops
else:
    xops = None


logger = logging.get_logger(__name__)


39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
class AttentionMixin:
    """
    Model-level helpers to inspect and replace the attention processors of all attention submodules, and to
    fuse/unfuse their QKV projections.

    Submodules are discovered by walking `named_children()` recursively. Any submodule exposing
    `get_processor`/`set_processor` takes part in processor inspection/replacement. Any submodule that is an
    `AttentionModuleMixin` takes part in projection fusion.
    """

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            keyed by `"<module path>.processor"`.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. The dict passed
                by the caller is not modified.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
            KeyError: If a dict is passed that lacks the key for one of the attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        if isinstance(processor, dict):
            # Entries are popped while walking the module tree; work on a shallow copy so the caller's dict
            # is left intact (e.g. so it can be reused for another model).
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        Raises:
            ValueError: If any attention processor's class name contains "Added" (added KV projections).
        """
        for attn_processor in self.attn_processors.values():
            if "Added" in attn_processor.__class__.__name__:
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        for module in self.modules():
            if isinstance(module, AttentionModuleMixin):
                module.fuse_projections()

    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        for module in self.modules():
            if isinstance(module, AttentionModuleMixin):
                module.unfuse_projections()


class AttentionModuleMixin:
    """
    Helpers shared by attention modules: processor management, attention-backend selection, QKV projection
    fusion, and the reshaping / masking / scoring utilities used by attention processors.

    The host class is expected to be a `torch.nn.Module` that provides the attributes the individual methods read,
    e.g. `processor`, `heads`, `scale`, `to_q`/`to_k`/`to_v`, `use_bias`, `upcast_attention`, `upcast_softmax`
    and `norm_cross`.
    """

    # Processor class instantiated by `set_attention_slice` when no sliced processor is selected; hosts are
    # expected to override this.
    _default_processor_cls = None
    # NOTE(review): mutable class-level default shared by every host that does not override it.
    _available_processors = []
    fused_projections = False

    def set_processor(self, processor: AttentionProcessor) -> None:
        """
        Set the attention processor to use.

        If the current processor is an `nn.Module` (registered as a submodule, possibly with trained weights) and
        the new one is not, the old one is first removed from `self._modules` so the plain object can take its
        place.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        """
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to request the deprecated LoRA attention processor. This is not supported by this
                implementation and yields `None`.

        Returns:
            "AttentionProcessor": The attention processor in use (or `None` if `return_deprecated_lora=True`).
        """
        if not return_deprecated_lora:
            return self.processor
        return None

    def set_attention_backend(self, backend: str):
        """
        Select the attention backend used by this module's processor.

        Args:
            backend (`str`):
                Name of the backend. It is lower-cased and must then equal one of the values of
                `AttentionBackendName`.

        Raises:
            ValueError: If `backend` does not name a known backend.
        """
        from .attention_dispatch import AttentionBackendName

        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
        # Normalise before validating so that the check and the enum lookup agree.
        normalized_backend = backend.lower()
        if normalized_backend not in available_backends:
            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(sorted(available_backends)))

        self.processor._attention_backend = AttentionBackendName(normalized_backend)

    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
        """
        Set whether to use NPU flash attention from `torch_npu` or not.

        Args:
            use_npu_flash_attention (`bool`):
                Whether to use NPU flash attention. `False` leaves the current backend unchanged, matching
                `set_use_memory_efficient_attention_xformers`.

        Raises:
            ImportError: If enabling is requested but `torch_npu` is not available.
        """
        if not use_npu_flash_attention:
            return

        if not is_torch_npu_available():
            raise ImportError("torch_npu is not available")

        self.set_attention_backend("_native_npu")

    def set_use_xla_flash_attention(
        self,
        use_xla_flash_attention: bool,
        partition_spec: Optional[Tuple[Optional[str], ...]] = None,
        is_flux=False,
    ) -> None:
        """
        Set whether to use XLA flash attention from `torch_xla` or not.

        Args:
            use_xla_flash_attention (`bool`):
                Whether to use pallas flash attention kernel from `torch_xla` or not. `False` leaves the current
                backend unchanged, matching `set_use_memory_efficient_attention_xformers`.
            partition_spec (`Tuple[]`, *optional*):
                Specify the partition specification if using SPMD. Otherwise None. Not used by this implementation.
            is_flux (`bool`, *optional*, defaults to `False`):
                Whether the model is a Flux model. Not used by this implementation.

        Raises:
            ImportError: If enabling is requested but `torch_xla` is not available.
        """
        if not use_xla_flash_attention:
            return

        if not is_torch_xla_available():
            raise ImportError("torch_xla is not available")

        self.set_attention_backend("_native_xla")

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ) -> None:
        """
        Set whether to use memory efficient attention from `xformers` or not.

        Args:
            use_memory_efficient_attention_xformers (`bool`):
                Whether to use memory efficient attention from `xformers` or not. `False` is a no-op.
            attention_op (`Callable`, *optional*):
                The attention operation to use. Defaults to `None` which uses the default attention operation from
                `xformers`. When given, it is unpacked as a `(forward_op, backward_op)` pair and one of the forward
                op's `SUPPORTED_DTYPES` is used for the smoke test.

        Raises:
            ModuleNotFoundError: If `xformers` is not installed.
            ValueError: If CUDA is not available.
        """
        if not use_memory_efficient_attention_xformers:
            return

        if not is_xformers_available():
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
                name="xformers",
            )
        if not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                " only available for GPU "
            )

        # `memory_efficient_attention` lives in the `xformers.ops` submodule.
        import xformers.ops

        # Make sure we can run the memory efficient attention before switching the backend.
        dtype = None
        if attention_op is not None:
            op_fw, _ = attention_op
            dtype, *_ = op_fw.SUPPORTED_DTYPES
        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
        _ = xformers.ops.memory_efficient_attention(q, q, q)

        self.set_attention_backend("xformers")

    @torch.no_grad()
    def fuse_projections(self):
        """
        Fuse the query, key, and value projections into a single projection for efficiency.

        For cross-attention modules (`is_cross_attention` truthy) only key/value are fused into `to_kv`; otherwise
        query/key/value are fused into `to_qkv`. If `add_q_proj`, `add_k_proj` and `add_v_proj` are all present they
        are additionally fused into `to_added_qkv`. The original separate projections are kept. Calling this on an
        already fused module is a no-op.
        """
        # Skip if already fused
        if getattr(self, "fused_projections", False):
            return

        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        if getattr(self, "is_cross_attention", False):
            # Fuse cross-attention key-value projections
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_kv.weight.copy_(concatenated_weights)
            if self.use_bias:
                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                self.to_kv.bias.copy_(concatenated_bias)
        else:
            # Fuse self-attention projections
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_qkv.weight.copy_(concatenated_weights)
            if self.use_bias:
                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
                self.to_qkv.bias.copy_(concatenated_bias)

        # Handle added projections for models like SD3, Flux, etc.
        if (
            getattr(self, "add_q_proj", None) is not None
            and getattr(self, "add_k_proj", None) is not None
            and getattr(self, "add_v_proj", None) is not None
        ):
            concatenated_weights = torch.cat(
                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_added_qkv = nn.Linear(
                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
            )
            self.to_added_qkv.weight.copy_(concatenated_weights)
            if self.added_proj_bias:
                concatenated_bias = torch.cat(
                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                )
                self.to_added_qkv.bias.copy_(concatenated_bias)

        self.fused_projections = True

    @torch.no_grad()
    def unfuse_projections(self):
        """
        Unfuse the query, key, and value projections back to separate projections.

        Removes the fused layers created by `fuse_projections` (`to_qkv`, `to_kv`, `to_added_qkv`). The separate
        projections were never removed, so nothing needs to be restored. A no-op if not currently fused.
        """
        # Skip if not fused
        if not getattr(self, "fused_projections", False):
            return

        for fused_name in ("to_qkv", "to_kv", "to_added_qkv"):
            if hasattr(self, fused_name):
                delattr(self, fused_name)

        self.fused_projections = False

    def set_attention_slice(self, slice_size: int) -> None:
        """
        Set the slice size for attention computation.

        Args:
            slice_size (`int`):
                The slice size for attention computation. `None` selects the default (non-sliced) processor.

        Raises:
            ValueError: If `slice_size` exceeds `self.sliceable_head_dim` (when that attribute exists).
        """
        if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        processor = None

        # Try to get a compatible processor for sliced attention.
        # NOTE(review): `_get_compatible_processor` is not defined by this mixin; the host class must provide it.
        if slice_size is not None:
            processor = self._get_compatible_processor("sliced")

        # If no processor was found or slice_size is None, use default processor
        if processor is None:
            processor = self._default_processor_cls()

        self.set_processor(processor)

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.

        The incoming leading dimension is treated as `(batch, heads)` flattened together; the head axis is moved
        next to the feature axis and merged into it.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        """
        Reshape the tensor for multi-head attention processing.

        A 3-D input `[batch, seq_len, dim]` (or a 4-D input `[batch, extra, seq_len, dim]`, whose `extra` and
        `seq_len` axes are merged) is split into `heads` chunks of `dim // heads` features.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`):
                `3` returns `[batch * heads, seq_len * extra, dim // heads]`; any other value returns
                `[batch, heads, seq_len * extra, dim // heads]`.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        if tensor.ndim == 3:
            batch_size, seq_len, dim = tensor.shape
            extra_dim = 1
        else:
            batch_size, extra_dim, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)

        return tensor

    def get_attention_scores(
        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute the attention scores.

        Computes `softmax(scale * query @ key^T + attention_mask)` with `torch.baddbmm`, so `query` and `key`
        are 3-D batched tensors. Optionally upcasts to float32 (`upcast_attention` / `upcast_softmax`); the result
        is cast back to the query's original dtype.

        Args:
            query (`torch.Tensor`): The query tensor.
            key (`torch.Tensor`): The key tensor.
            attention_mask (`torch.Tensor`, *optional*): Additive mask added to the scaled scores.

        Returns:
            `torch.Tensor`: The attention probabilities/scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            # With beta=0 the (uninitialised) input tensor is ignored; it only supplies the output shape.
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(
        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ) -> torch.Tensor:
        """
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`): The attention mask to prepare. `None` is returned unchanged.
            target_length (`int`): The target length of the attention mask.
            batch_size (`int`): The batch size for repeating the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                `3` repeats the mask per head along dim 0 (if not already expanded); `4` inserts a head axis at dim 1
                and repeats along it.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Normalize the encoder hidden states.

        Args:
            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder, `(batch, seq_len, hidden)`.

        Returns:
            `torch.Tensor`: The normalized encoder hidden states.

        Raises:
            ValueError: If `self.norm_cross` is `None` or is neither `nn.LayerNorm` nor `nn.GroupNorm`.
        """
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if self.norm_cross is None:
            raise ValueError("self.norm_cross must be defined to call self.norm_encoder_hidden_states")

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            raise ValueError(
                f"Unsupported `norm_cross` type {type(self.norm_cross).__name__}; expected `nn.LayerNorm` or"
                " `nn.GroupNorm`."
            )

        return encoder_hidden_states


512
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
Suraj Patil's avatar
Suraj Patil committed
513
514
515
516
517
518
519
    # "feed_forward_chunk_size" can be used to save memory
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
        )

    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
520
521
522
523
    ff_output = torch.cat(
        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
Suraj Patil's avatar
Suraj Patil committed
524
525
526
    return ff_output


527
528
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    r"""
    A gated self-attention dense layer that combines visual features and object features.

    Object features are projected to the visual width and appended to the visual tokens along the sequence axis.
    The joint sequence goes through self-attention, and only the visual positions of the result are kept. Both the
    attention and the feed-forward residuals are scaled by `tanh` of a learnable scalar that starts at zero.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()

        # Object features must be projected to `query_dim` before they can be concatenated with visual features.
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Scalar gates initialised to 0; since tanh(0) == 0 both residual branches start out switched off.
        self.alpha_attn = nn.Parameter(torch.tensor(0.0))
        self.alpha_dense = nn.Parameter(torch.tensor(0.0))

        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        """Return `x` refined by gated attention over `[x, objs]` and a gated feed-forward (or `x` if disabled)."""
        if not self.enabled:
            return x

        num_visual_tokens = x.shape[1]
        projected_objs = self.linear(objs)

        # Attend over visual + object tokens jointly, then keep only the visual positions.
        joint_tokens = self.norm1(torch.cat([x, projected_objs], dim=1))
        attended_visual = self.attn(joint_tokens)[:, :num_visual_tokens, :]
        x = x + self.alpha_attn.tanh() * attended_visual

        dense_update = self.ff(self.norm2(x))
        return x + self.alpha_dense.tanh() * dense_update


Dhruv Nair's avatar
Dhruv Nair committed
569
570
571
572
573
@maybe_allow_in_graph
class JointTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://huggingface.co/papers/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`, defaults to `False`):
            When `True`, the context (`encoder_hidden_states`) stream is normalised with `AdaLayerNormContinuous`
            and only feeds the joint attention. No context feed-forward is built, and `forward` returns `None` in
            place of the updated context.
        qk_norm (`str`, *optional*): Forwarded as `qk_norm` to the attention layer(s).
        use_dual_attention (`bool`, defaults to `False`):
            When `True`, the image stream is modulated with `SD35AdaLayerNormZeroX` and goes through an additional
            self-attention layer (`attn2`).
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        context_pre_only: bool = False,
        qk_norm: Optional[str] = None,
        use_dual_attention: bool = False,
    ):
        super().__init__()

        self.use_dual_attention = use_dual_attention
        self.context_pre_only = context_pre_only

        # Image-stream modulation; the dual-attention variant additionally yields the input and gate for `attn2`.
        image_norm_cls = SD35AdaLayerNormZeroX if use_dual_attention else AdaLayerNormZero
        self.norm1 = image_norm_cls(dim)

        # Context-stream modulation.
        if context_pre_only:
            self.norm1_context = AdaLayerNormContinuous(
                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
            )
        else:
            self.norm1_context = AdaLayerNormZero(dim)

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        shared_processor = JointAttnProcessor2_0()

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=context_pre_only,
            bias=True,
            processor=shared_processor,
            qk_norm=qk_norm,
            eps=1e-6,
        )

        # Optional second, image-only self-attention (shares the processor instance with `attn`).
        self.attn2 = (
            Attention(
                query_dim=dim,
                cross_attention_dim=None,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                processor=shared_processor,
                qk_norm=qk_norm,
                eps=1e-6,
            )
            if use_dual_attention
            else None
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        if context_pre_only:
            self.norm2_context = None
            self.ff_context = None
        else:
            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        # Feed-forward chunking is off until `set_chunk_feed_forward` is called.
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def _run_feed_forward(self, ff_module: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
        """Apply `ff_module` to `inputs`, chunked along `self._chunk_dim` when a chunk size is configured."""
        if self._chunk_size is None:
            return ff_module(inputs)
        # "feed_forward_chunk_size" can be used to save memory
        return _chunked_feed_forward(ff_module, inputs, self._chunk_dim, self._chunk_size)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Run one MMDiT block over the image and context streams.

        Args:
            hidden_states: Image-stream tokens.
            encoder_hidden_states: Context-stream tokens.
            temb: Conditioning embedding passed to the adaptive norms.
            joint_attention_kwargs: Extra keyword arguments forwarded to the attention call(s).

        Returns:
            `(encoder_hidden_states, hidden_states)`; the first element is `None` when `context_pre_only` is set.
        """
        attn_kwargs = joint_attention_kwargs or {}

        # Adaptive modulation of the image stream.
        if self.use_dual_attention:
            (
                norm_hidden_states,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp,
                norm_hidden_states2,
                gate_msa2,
            ) = self.norm1(hidden_states, emb=temb)
        else:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        # Adaptive modulation of the context stream.
        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
                encoder_hidden_states, emb=temb
            )

        # Joint attention over both streams.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            **attn_kwargs,
        )
        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output

        if self.use_dual_attention:
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **attn_kwargs)
            hidden_states = hidden_states + gate_msa2.unsqueeze(1) * attn_output2

        # Image-stream feed-forward.
        ff_input = self.norm2(hidden_states)
        ff_input = ff_input * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * self._run_feed_forward(self.ff, ff_input)

        if self.context_pre_only:
            return None, hidden_states

        # Context-stream residual attention update and feed-forward.
        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * context_attn_output

        context_ff_input = self.norm2_context(encoder_hidden_states)
        context_ff_input = context_ff_input * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
        context_ff_output = self._run_feed_forward(self.ff_context, context_ff_input)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return encoder_hidden_states, hidden_states


741
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    The block applies, in order: self-attention, an optional second attention (cross-attention, or self-attention
    when `double_self_attention=True`), and a feed-forward layer. Each stage is preceded by its own normalization
    layer and added back onto the residual stream.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`. Must be set when
            `norm_type` is `"ada_norm"` or `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*, defaults to `False`):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. One of `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"`,
            `"ada_norm_single"`, `"ada_norm_continuous"` or `"layer_norm_i2vgen"`.
        norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon passed to the normalization layers.
        final_dropout (`bool`, *optional*, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. The gated
            variants create a `GatedSelfAttentionDense` fuser that consumes GLIGEN inputs.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply. Only `"sinusoidal"` creates an embedding module; any other
            value leaves `pos_embed` as `None`.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply. Required when `positional_embeddings` is set.
        ada_norm_continous_conditioning_embedding_dim (`int`, *optional*):
            Conditioning dimension forwarded to `AdaLayerNormContinuous` (used with `"ada_norm_continuous"`).
        ada_norm_bias (*optional*): Bias setting forwarded to `AdaLayerNormContinuous`.
        ff_inner_dim (`int`, *optional*): Inner dimension forwarded to the `FeedForward` layer.
        ff_bias (`bool`, *optional*, defaults to `True`): Whether the `FeedForward` layer uses bias.
        attention_out_bias (`bool`, *optional*, defaults to `True`):
            Whether the attention output projections use bias.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
        ada_norm_bias: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.double_self_attention = double_self_attention
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.only_cross_attention = only_cross_attention

        # We keep these boolean flags for backward-compatibility.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"
        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        self.norm_type = norm_type
        self.num_embeds_ada_norm = num_embeds_ada_norm

        if positional_embeddings and (num_positional_embeddings is None):
            # The message names the real constructor arguments.
            raise ValueError(
                "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # NOTE: sub-module creation order is kept stable so parameter registration/initialization order is unchanged.
        # 1. Self-Attn
        if norm_type == "ada_norm":
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif norm_type == "ada_norm_zero":
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        elif norm_type == "ada_norm_continuous":
            self.norm1 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "rms_norm",
            )
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            if norm_type == "ada_norm":
                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
            elif norm_type == "ada_norm_continuous":
                self.norm2 = AdaLayerNormContinuous(
                    dim,
                    ada_norm_continous_conditioning_embedding_dim,
                    norm_elementwise_affine,
                    norm_eps,
                    ada_norm_bias,
                    "rms_norm",
                )
            else:
                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
            )  # is self-attn if encoder_hidden_states is none
        else:
            if norm_type == "ada_norm_single":  # For Latte
                # Not used for cross-attention here; forward() reuses norm2 as the pre-FF norm for this type.
                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
            else:
                self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        if norm_type == "ada_norm_continuous":
            self.norm3 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "layer_norm",
            )
        elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
            self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
        elif norm_type == "layer_norm_i2vgen":
            # i2vgen has no pre-feed-forward norm; forward() feeds the residual stream to the FF directly.
            self.norm3 = None

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # 5. Scale-shift for PixArt-Alpha.
        if norm_type == "ada_norm_single":
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        """Enable (or, with `chunk_size=None`, disable) chunked feed-forward along dimension `dim`."""
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Apply self-attention, optional second attention, and feed-forward with residual connections.

        Args:
            hidden_states: Input tensor; `hidden_states.shape[0]` is used as the batch size.
            attention_mask: Mask passed to the first attention layer.
            encoder_hidden_states: Context for the second attention layer (and for the first one when
                `only_cross_attention=True`).
            encoder_attention_mask: Mask passed to the second attention layer.
            timestep: Conditioning for `"ada_norm"` / `"ada_norm_zero"` norms. For `"ada_norm_single"` it is
                reshaped to `(batch_size, 6, -1)` and added to `scale_shift_table`.
            cross_attention_kwargs: Extra kwargs forwarded to the attention layers. A `"gligen"` entry is removed
                and its `"objs"` value is passed to the GLIGEN fuser. A `"scale"` entry only triggers a warning.
            class_labels: Passed to `AdaLayerNormZero` when `norm_type="ada_norm_zero"`.
            added_cond_kwargs: For `"ada_norm_continuous"`, must contain `"pooled_text_emb"`.

        Returns:
            The updated hidden states.

        Raises:
            ValueError: If `norm_type` is not one of the supported values.
        """
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.norm_type == "ada_norm_zero":
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.norm_type == "ada_norm_single":
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Prepare GLIGEN inputs
        # Copy so popping "gligen" does not mutate the caller's dict.
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        if self.norm_type == "ada_norm_zero":
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.norm_type == "ada_norm_single":
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 1.2 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.norm_type == "ada_norm":
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.norm_type == "ada_norm_single":
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.norm_type == "ada_norm_continuous":
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        if self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.norm_type == "layer_norm_i2vgen":
            # i2vgen doesn't have this norm (norm3 is None), so the FF sees the un-normalized residual stream.
            # Previously this fell through to `self.norm3(...)` and raised a TypeError.
            norm_hidden_states = hidden_states
        elif self.norm_type != "ada_norm_single":
            norm_hidden_states = self.norm3(hidden_states)

        if self.norm_type == "ada_norm_zero":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.norm_type == "ada_norm_single":
            # norm2 doubles as the pre-FF norm for this type.
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.norm_type == "ada_norm_zero":
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
Patrick von Platen's avatar
Patrick von Platen committed
1071
1072


1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
class LuminaFeedForward(nn.Module):
    r"""
    A gated (SwiGLU-style) feed-forward layer.

    Computes `linear_2(silu(linear_1(x)) * linear_3(x))`. All three projections are bias-free, and the activation
    is an `FP32SiLU` module.

    Parameters:
        dim (`int`):
            Dimensionality of the input and output features.
        inner_dim (`int`):
            Requested hidden dimension, before `ffn_dim_multiplier` and rounding are applied.
        multiple_of (`int`, *optional*, defaults to 256):
            The hidden dimension is rounded up to the nearest multiple of this value. Pass `None` to disable
            rounding.
        ffn_dim_multiplier (`float`, *optional*):
            Optional factor applied to `inner_dim` (truncated to `int`) before rounding. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        inner_dim = self._compute_inner_dim(inner_dim, multiple_of, ffn_dim_multiplier)

        self.linear_1 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            inner_dim,
            dim,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.silu = FP32SiLU()

    @staticmethod
    def _compute_inner_dim(inner_dim: int, multiple_of: Optional[int], ffn_dim_multiplier: Optional[float]) -> int:
        """Apply the optional multiplier, then round up to a multiple of `multiple_of` (skipped when it is None)."""
        # custom hidden_size factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        # `multiple_of` is typed Optional; None previously crashed with a TypeError, now it means "no rounding".
        if multiple_of is None:
            return inner_dim
        return multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

    def forward(self, x):
        """Gate `linear_1(x)` through SiLU, multiply by `linear_3(x)`, and project back with `linear_2`."""
        return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))


Suraj Patil's avatar
Suraj Patil committed
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block for video like data.

    The input is laid out as `(batch_size * num_frames, seq_length, dim)`. The block regroups it so that each
    sequence position becomes its own sequence of `num_frames` tokens. It then applies an input feed-forward,
    self-attention, optional cross-attention and an output feed-forward along that frame axis. Finally it restores
    the `(batch_size * num_frames, seq_length, channels)` layout.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        time_mix_inner_dim (`int`): The number of channels for temporal attention.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
    """

    def __init__(
        self,
        dim: int,
        time_mix_inner_dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        # Residual connections around ff_in / ff are only possible when the channel counts match.
        self.is_res = dim == time_mix_inner_dim

        self.norm_in = nn.LayerNorm(dim)

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.ff_in = FeedForward(
            dim,
            dim_out=time_mix_inner_dim,
            activation_fn="geglu",
        )

        self.norm1 = nn.LayerNorm(time_mix_inner_dim)
        self.attn1 = Attention(
            query_dim=time_mix_inner_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            cross_attention_dim=None,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = nn.LayerNorm(time_mix_inner_dim)
            self.attn2 = Attention(
                query_dim=time_mix_inner_dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(time_mix_inner_dim)
        self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = None

    def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
        """Enable (or, with `chunk_size=None`, disable) chunked feed-forward; extra kwargs are ignored."""
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
        self._chunk_dim = 1

    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: int,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Mix information across frames.

        Args:
            hidden_states: Tensor of shape `(batch_size * num_frames, seq_length, dim)`.
            num_frames: Number of frames per batch item.
            encoder_hidden_states: Optional context for the cross-attention layer (ignored if it was not built).

        Returns:
            Tensor of shape `(batch_size * num_frames, seq_length, channels)`. `channels` equals `dim` when
            `dim == time_mix_inner_dim`, and `time_mix_inner_dim` otherwise.
        """
        batch_frames, seq_length, channels = hidden_states.shape
        batch_size = batch_frames // num_frames

        # (batch * frames, seq, C) -> (batch * seq, frames, C): each position becomes a sequence over frames.
        hidden_states = hidden_states.reshape(batch_size, num_frames, seq_length, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)

        # 0. Input feed-forward (dim -> time_mix_inner_dim)
        residual = hidden_states
        hidden_states = self.norm_in(hidden_states)

        if self._chunk_size is not None:
            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
        else:
            hidden_states = self.ff_in(hidden_states)

        if self.is_res:
            hidden_states = hidden_states + residual

        # 1. Self-Attention (along the frame axis)
        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)
            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.is_res:
            hidden_states = ff_output + hidden_states
        else:
            hidden_states = ff_output

        # Restore the (batch * frames, seq, C) layout. Use the *current* channel count: when
        # dim != time_mix_inner_dim the output has time_mix_inner_dim channels, and reshaping with the
        # input `channels` would fail.
        out_channels = hidden_states.shape[-1]
        hidden_states = hidden_states.reshape(batch_size, seq_length, num_frames, out_channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, out_channels)

        return hidden_states


Will Berman's avatar
Will Berman committed
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
class SkipFFTransformerBlock(nn.Module):
    r"""
    A transformer block made of two pre-normalized attention layers and no feed-forward layer.

    Both attention layers receive `encoder_hidden_states`. When `kv_input_dim != dim`, that context is first passed
    through SiLU and a linear projection to `dim` channels. Each attention output is added back onto the residual
    stream. The normalization layers are `RMSNorm(dim, 1e-06)`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        kv_input_dim: int,
        kv_input_dim_proj_use_bias: bool,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        # Projection for the key/value context; only needed when its width differs from the block width.
        self.kv_mapper = (
            nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) if kv_input_dim != dim else None
        )

        # Both attention layers share the same configuration.
        attn_config = {
            "heads": num_attention_heads,
            "dim_head": attention_head_dim,
            "dropout": dropout,
            "bias": attention_bias,
            "cross_attention_dim": cross_attention_dim,
            "out_bias": attention_out_bias,
        }

        # Sub-modules are created in the order norm1, attn1, norm2, attn2.
        self.norm1 = RMSNorm(dim, 1e-06)
        self.attn1 = Attention(query_dim=dim, **attn_config)
        self.norm2 = RMSNorm(dim, 1e-06)
        self.attn2 = Attention(query_dim=dim, **attn_config)

    def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
        """
        Apply the two residual attention stages.

        Args:
            hidden_states: Input tensor for the residual stream.
            encoder_hidden_states: Context passed to both attention layers (projected first if `kv_mapper` exists).
            cross_attention_kwargs: Optional dict of extra kwargs for the attention layers; it is copied, not mutated.

        Returns:
            The updated hidden states.
        """
        attn_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs.copy()

        if self.kv_mapper is not None:
            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))

        for norm_layer, attn_layer in ((self.norm1, self.attn1), (self.norm2, self.attn2)):
            attended = attn_layer(
                norm_layer(hidden_states),
                encoder_hidden_states=encoder_hidden_states,
                **attn_kwargs,
            )
            hidden_states = attended + hidden_states

        return hidden_states


Aryan's avatar
Aryan committed
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
@maybe_allow_in_graph
class FreeNoiseTransformerBlock(nn.Module):
    r"""
    A FreeNoise Transformer block.

    The input is processed in overlapping temporal windows of `context_length` frames that start every
    `context_stride` frames. Self-attention and (if configured) cross-attention are applied to each window. The
    per-window results are blended with a per-frame weighted average (see `weighting_scheme`), and a feed-forward
    layer is then applied to the blended sequence.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        cross_attention_dim (`int`, *optional*):
            The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, defaults to `False`):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, defaults to `False`):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, defaults to `False`):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. Note that this
            block always builds `nn.LayerNorm` layers. `norm_type` is validated and stored, and otherwise only
            controls whether positional embeddings are re-applied before cross-attention.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value used by the normalization layers.
        final_dropout (`bool` defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        positional_embeddings (`str`, *optional*):
            The type of positional embeddings to apply. Only `"sinusoidal"` creates an embedding layer; any other
            value leaves positional embeddings disabled.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply. Required when `positional_embeddings` is set.
        ff_inner_dim (`int`, *optional*):
            Hidden dimension of feed-forward MLP.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in feed-forward MLP.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in attention output project layer.
        context_length (`int`, defaults to `16`):
            The maximum number of frames that the FreeNoise block processes at once.
        context_stride (`int`, defaults to `4`):
            The number of frames to be skipped before starting to process a new batch of `context_length` frames.
        weighting_scheme (`str`, defaults to `"pyramid"`):
            The weighting scheme used for the weighted averaging of processed latent frames. Can be `"flat"`,
            `"pyramid"` or `"delayed_reverse_sawtooth"`. As described in Equation 9. of the
            [FreeNoise](https://huggingface.co/papers/2310.15169) paper, `"pyramid"` is the default setting used.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        context_length: int = 16,
        context_stride: int = 4,
        weighting_scheme: str = "pyramid",
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.double_self_attention = double_self_attention
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.only_cross_attention = only_cross_attention

        self.set_free_noise_properties(context_length, context_stride, weighting_scheme)

        # We keep these boolean flags for backward-compatibility.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"
        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        self.norm_type = norm_type
        self.num_embeds_ada_norm = num_embeds_ada_norm

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
            )  # is self-attn if encoder_hidden_states is none
        else:
            # `forward` checks `self.attn2 is not None`, so both attributes must always exist.
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
        # Returns the (start, end) indices of every full window of `context_length` frames, with `end` exclusive.
        # A new window starts every `context_stride` frames. Trailing frames that no full window covers are
        # handled in `forward`.
        frame_indices = []
        for i in range(0, num_frames - self.context_length + 1, self.context_stride):
            window_start = i
            window_end = min(num_frames, i + self.context_length)
            frame_indices.append((window_start, window_end))
        return frame_indices

    def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
        # Returns exactly `num_frames` per-frame blending weights for one window.
        # Raises ValueError for an unknown `weighting_scheme`.
        if weighting_scheme == "flat":
            weights = [1.0] * num_frames

        elif weighting_scheme == "pyramid":
            if num_frames % 2 == 0:
                # num_frames = 4 => [1, 2, 2, 1]
                mid = num_frames // 2
                weights = list(range(1, mid + 1))
                weights = weights + weights[::-1]
            else:
                # num_frames = 5 => [1, 2, 3, 2, 1]
                mid = (num_frames + 1) // 2
                weights = list(range(1, mid))
                weights = weights + [mid] + weights[::-1]

        elif weighting_scheme == "delayed_reverse_sawtooth":
            if num_frames % 2 == 0:
                # num_frames = 4 => [0.01, 2, 2, 1]
                mid = num_frames // 2
                weights = [0.01] * (mid - 1) + [mid]
                weights = weights + list(range(mid, 0, -1))
            else:
                # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
                # `range(mid, 0, -1)` already contributes `mid` values, so only `mid - 1` small weights precede it;
                # this keeps the list length equal to `num_frames`.
                mid = (num_frames + 1) // 2
                weights = [0.01] * (mid - 1)
                weights = weights + list(range(mid, 0, -1))

        else:
            raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")

        return weights

    def set_free_noise_properties(
        self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
    ) -> None:
        self.context_length = context_length
        self.context_stride = context_stride
        self.weighting_scheme = weighting_scheme

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

        # hidden_states: [B x H x W, F, C]
        device = hidden_states.device
        dtype = hidden_states.dtype

        num_frames = hidden_states.size(1)
        # Validate before building windows. With fewer frames than `context_length` there is no full window, and
        # indexing the empty window list below would otherwise fail with an opaque IndexError.
        if num_frames < self.context_length:
            raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")

        frame_indices = self._get_frame_indices(num_frames)
        frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
        frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
        is_last_frame_batch_complete = frame_indices[-1][1] == num_frames

        # Handle the case where the last regular window does not end exactly at `num_frames`.
        # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
        #    [(0, 16), (4, 20), (8, 24), (9, 25)]
        # The extra window is anchored to the end of the sequence. Only its last `last_frame_batch_length` frames,
        # which no earlier window covers, are accumulated below.
        if not is_last_frame_batch_complete:
            last_frame_batch_length = num_frames - frame_indices[-1][1]
            frame_indices.append((num_frames - self.context_length, num_frames))

        num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
        accumulated_values = torch.zeros_like(hidden_states)

        for i, (frame_start, frame_end) in enumerate(frame_indices):
            # Per-frame weights for this window. Every window spans exactly `context_length` frames, including the
            # end-anchored trailing window, so this matches the shape of `frame_weights`.
            weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
            weights *= frame_weights

            hidden_states_chunk = hidden_states[:, frame_start:frame_end]

            # Notice that normalization is always applied before the real computation in the following blocks.
            # 1. Self-Attention
            norm_hidden_states = self.norm1(hidden_states_chunk)

            if self.pos_embed is not None:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )

            hidden_states_chunk = attn_output + hidden_states_chunk
            if hidden_states_chunk.ndim == 4:
                hidden_states_chunk = hidden_states_chunk.squeeze(1)

            # 2. Cross-Attention
            if self.attn2 is not None:
                norm_hidden_states = self.norm2(hidden_states_chunk)

                if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                    norm_hidden_states = self.pos_embed(norm_hidden_states)

                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states_chunk = attn_output + hidden_states_chunk

            if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
                # Numerator and denominator must use the same per-frame weight slice; otherwise the weighted
                # average of the trailing frames is skewed.
                accumulated_values[:, -last_frame_batch_length:] += (
                    hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
                )
                num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
            else:
                accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
                num_times_accumulated[:, frame_start:frame_end] += weights

        # TODO(aryan): Maybe this could be done in a better way.
        #
        # Previously, this was:
        # hidden_states = torch.where(
        #    num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
        # )
        #
        # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
        # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
        # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
        # looked into this deeply because other memory optimizations led to more pronounced reductions.
        hidden_states = torch.cat(
            [
                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
                for accumulated_split, num_times_split in zip(
                    accumulated_values.split(self.context_length, dim=1),
                    num_times_accumulated.split(self.context_length, dim=1),
                )
            ],
            dim=1,
        ).to(dtype)

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states


Patrick von Platen's avatar
Patrick von Platen committed
1672
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    The layer is built as `activation module -> dropout -> linear projection [-> dropout]`. The activation module is
    chosen by `activation_fn`, is constructed with input size `dim` and output size `inner_dim`, and is labelled
    "project in" in the code. The final `nn.Linear` projects from `inner_dim` to `dim_out`.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`, `"swiglu"` or `"linear-silu"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        inner_dim (`int`, *optional*): The hidden dimension. If not given, defaults to `int(dim * mult)`.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.

    Raises:
        ValueError: If `activation_fn` is not one of the supported values.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # A single if/elif chain with an explicit failure. Previously an unknown name left `act_fn` unbound and
        # surfaced as an UnboundLocalError further down.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        elif activation_fn == "swiglu":
            act_fn = SwiGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "linear-silu":
            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
        else:
            raise ValueError(f"Unsupported value for activation_fn={activation_fn}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Extra positional args or a `scale` kwarg are accepted only for backward compatibility; they are ignored
        # after emitting a deprecation.
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states