attention.py 282 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
import math
10
import os
11
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12
import warnings
13
import logging
14

15
from dataclasses import dataclass, fields
cyanguwa's avatar
cyanguwa committed
16
import numpy as np
17
from packaging.version import Version as PkgVersion
18
19

import torch
20
import torch.nn.functional as F
21

22
import transformer_engine_torch as tex
23
24
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
25
26
27
28
from transformer_engine.pytorch.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
)
29
30
31
32
33
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
34
35
    fused_attn_fwd,
    fused_attn_bwd,
36
37
38
39
40
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
41
42
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
43
from transformer_engine.pytorch.module import LayerNormLinear, Linear
44
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
45
46
47
48
49
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
50
    get_default_init_method,
51
52
53
54
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
55
    AttnBiasTypes,
56
    QKVLayouts,
57
    dist_group_type,
58
    TE_DType,
59
60
61
62
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
63
    get_distributed_rank,
64
    checkpoint,
65
66
67
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
68
69
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
70
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
71
72
from transformer_engine.pytorch.graph import is_graph_capturing

73

74
75
76
# Installed flash-attn version, plus the version bounds declared by this module.
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.5.8")

# Feature gates: each flag is True when the installed flash-attn is at least that version.
_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")

# The varlen entry points are only imported when the minimum required version is present.
if _flash_attn_version >= _flash_attn_version_required:
    from flash_attn.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_forward_func,
        _flash_attn_varlen_forward as _flash_attn_forward,
        _flash_attn_varlen_backward as _flash_attn_backward,
    )
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
88

89
# Aliases for the FP8 meta-tensor slots used by attention.
# NOTE(review): the QKV/O/S and dQKV/dO/dP pairing is inferred from the names -- confirm
# against the FP8 attention implementation.

# Forward-pass slots.
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT

# Backward-pass slots.
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
95

96
# Debug logging is controlled by two environment variables:
#   NVTE_DEBUG       = 0/1    master switch, default 0
#   NVTE_DEBUG_LEVEL = 0/1/2  increasing verbosity, default 0
# The effective level is their product, so nothing below WARNING is shown unless
# NVTE_DEBUG is set.
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}

# Any product outside {0, 1, 2} is treated like level 2 (DEBUG).
logging.basicConfig(
    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
    level=log_levels.get(log_level, log_levels[2]),
)

107
108
109
110
111
112
113
114
115
116
117
# Per-backend on/off switches read from the environment at import time; each defaults to on.
_NVTE_FLASH_ATTN = int(os.environ.get("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.environ.get("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.environ.get("NVTE_UNFUSED_ATTN", "1"))

# Module-level record of the backend selection. `get_attention_backend` fills in the
# `use_*` / `fused_attention_backend` entries and clears the update flag.
_attention_backends = dict(
    attention_params=None,
    use_flash_attention=None,
    use_fused_attention=None,
    fused_attention_backend=None,
    use_unfused_attention=None,
    backend_selection_requires_update=False,
)
119
120


121
122
@dataclass(eq=True)
class AttentionParams:
    """
    Attention parameters used to determine which backend to be used.

    Instances compare field by field (`eq=True`); because the dataclass is not frozen,
    instances are not hashable.

    Parameters
    ----------
    qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
        Type (i.e. the class, not an instance) of query/key/value tensors,
        {`torch.Tensor`, `Float8Tensor`}.
    qkv_dtype: torch.dtype, default = `torch.bfloat16`
        Data type of query/key/value tensors.
    qkv_layout: str, default = "sbh3d"
        Query/key/value tensor memory layout.
    batch_size: int, default = 1
        Batch size.
    num_heads: int, default = 16
        Number of attention heads in the query tensor.
    num_gqa_groups: int, default = 16
        Number of attention heads in key and value tensors.
    max_seqlen_q: int, default = 128
        Maximum sequence length of the query tensor.
    max_seqlen_kv: int, default = 128
        Maximum sequence length of the key and value tensors.
    head_dim: int, default = 64
        The size of each attention head.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
        `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
    window_size: Optional[Tuple[int, int]], default = `None`
        Sliding window attention size.
    alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
        Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
    core_attention_bias_type: str, default = `no_bias`
        Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
    core_attention_bias_shape: str, default = `1hss`
        Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
    core_attention_bias_requires_grad: bool, default = `True`
        Whether attention bias requires gradient.
    pad_between_seqs: bool, default = `False`
        Whether there is padding between sequences in a batch.
        This only applies to `qkv_format=thd`.
    attention_dropout: float, default = 0.0
        Attention dropout.
    context_parallel: bool, default = `False`
        Whether context parallelism is used or not.
    deterministic: bool, default = `False`
        Whether to run `DotProductAttention` with determinism or not.
    is_training: bool, default = `True`
        Whether in training mode (`True`) or inference mode (`False`)
    fp8: bool, default = `False`
        Whether `DotProductAttention` is in an `fp8_autocast` region.
    fp8_meta: Optional[Dict[str, Any]], default = `None`
        The FP8 metadata tensor of `DotProductAttention`.
    """

    # NOTE: field order is part of the positional constructor signature; do not reorder.
    qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
    qkv_dtype: torch.dtype = torch.bfloat16
    qkv_layout: str = "sbh3d"
    batch_size: int = 1
    num_heads: int = 16
    num_gqa_groups: int = 16
    max_seqlen_q: int = 128
    max_seqlen_kv: int = 128
    head_dim: int = 64
    attn_mask_type: str = "no_mask"
    window_size: Optional[Tuple[int, int]] = None
    alibi_slopes_shape: Optional[Union[torch.Size, List]] = None
    core_attention_bias_type: str = "no_bias"
    core_attention_bias_shape: str = "1hss"
    core_attention_bias_requires_grad: bool = True
    pad_between_seqs: bool = False
    attention_dropout: float = 0.0
    context_parallel: bool = False
    deterministic: bool = False
    is_training: bool = True
    fp8: bool = False
    fp8_meta: Optional[Dict[str, Any]] = None


# Module-level cache for ALiBi slopes/bias, together with the settings they were
# recorded for and flags marking either entry as needing a refresh.
# NOTE(review): the code that reads and refreshes this cache is outside this view.
_alibi_cache = dict(
    _num_heads=None,
    _alibi_slopes=None,
    _max_seqlen_q=None,
    _max_seqlen_kv=None,
    _bottom_right_alignment=True,
    _alibi_bias=None,
    _alibi_slopes_require_update=False,
    _alibi_bias_require_update=False,
)


# Names exported by `from <this module> import *`.
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]


def get_attention_backend(
    attention_params: Optional[AttentionParams] = None,
):
    """
    Select the appropriate attention backend/sub-backend based on user input and runtime environment.

    All three backends start as candidates (subject to the `NVTE_FLASH_ATTN`,
    `NVTE_FUSED_ATTN` and `NVTE_UNFUSED_ATTN` environment variables, which are re-read on
    every call) and are then eliminated by a sequence of filters. Among the survivors the
    preference order is FlashAttention, FusedAttention, UnfusedDotProductAttention, except
    that FusedAttention (sub-backend `F16_arbitrary_seqlen`) is preferred over
    FlashAttention on compute capability 9.0.

    Side effects: updates the module-level `_NVTE_*` flags and the `_attention_backends`
    cache, and may set `os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"`.

    Parameters
    ----------
    attention_params: Optional[AttentionParams], default = `None`
        See `AttentionParams`. When `None`, `AttentionParams()` (all defaults) is used.

    Returns
    ----------
    use_flash_attention: bool
        Whether the `FlashAttention` backend has been selected.
    use_fused_attention: bool
        Whether the `FusedAttention` backend has been selected.
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend
        The `FusedAttention` sub-backend that supports the input, or `None` if
        `FusedAttention` was ruled out by one of the filters.
    use_unfused_attention: bool
        Whether the `UnfusedDotProductAttention` backend has been selected.
    available_backends: List[bool]
        All available backends that could support the provided input. A list of Booleans
        in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
    """
    if attention_params is None:
        # The signature advertises an optional argument; fall back to the dataclass
        # defaults instead of failing on the first attribute access below.
        attention_params = AttentionParams()

    qkv_type = attention_params.qkv_type
    qkv_dtype = attention_params.qkv_dtype
    qkv_layout = attention_params.qkv_layout
    batch_size = attention_params.batch_size
    num_heads = attention_params.num_heads
    num_gqa_groups = attention_params.num_gqa_groups
    max_seqlen_q = attention_params.max_seqlen_q
    max_seqlen_kv = attention_params.max_seqlen_kv
    head_dim = attention_params.head_dim
    attn_mask_type = attention_params.attn_mask_type
    window_size = attention_params.window_size
    alibi_slopes_shape = attention_params.alibi_slopes_shape
    core_attention_bias_type = attention_params.core_attention_bias_type
    core_attention_bias_shape = attention_params.core_attention_bias_shape
    core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
    pad_between_seqs = attention_params.pad_between_seqs
    attention_dropout = attention_params.attention_dropout
    context_parallel = attention_params.context_parallel
    deterministic = attention_params.deterministic
    is_training = attention_params.is_training
    fp8 = attention_params.fp8
    fp8_meta = attention_params.fp8_meta

    # Run config
    logger = logging.getLogger("DotProductAttention")
    device_compute_capability = get_device_compute_capability()
    cudnn_version = get_cudnn_version()
    run_config = {
        "transformer_engine_version": te.__version__,
        "compute_capability": "sm"
        + str(device_compute_capability[0] * 10 + device_compute_capability[1]),
        "flash_attn_version": _flash_attn_version,
        "cudnn_version": ".".join(str(i) for i in cudnn_version),
    }
    # Shallow snapshot of every dataclass field (no deep copy of fp8_meta).
    run_config.update(
        {field.name: getattr(attention_params, field.name) for field in fields(attention_params)}
    )
    if fp8:
        run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
    logger.debug("Running with config=%s", run_config)

    # Filter: Environment variables
    # Re-read on every call so that changes made after import are honoured.
    global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN
    _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
    _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
    _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
    # Normalize to real booleans, matching the documented return types.
    use_flash_attention = bool(_NVTE_FLASH_ATTN)
    use_fused_attention = bool(_NVTE_FUSED_ATTN)
    use_unfused_attention = bool(_NVTE_UNFUSED_ATTN)
    if not use_flash_attention:
        logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
    if not use_fused_attention:
        logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
    if not use_unfused_attention:
        logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")

    # Filter: ONNX mode
    if is_in_onnx_export_mode():
        if use_flash_attention:
            logger.debug("Disabling FlashAttention due to ONNX mode")
            use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention due to ONNX mode")
            use_fused_attention = False

    # Filter: Compute capability
    if device_compute_capability < (8, 0):
        if use_flash_attention:
            logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
            use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
            use_fused_attention = False

    # Filter: Context parallelism
    if context_parallel and use_unfused_attention:
        logger.debug(
            "Disabling UnfusedDotProductAttention as it does not support context parallelism"
        )
        use_unfused_attention = False

    # Filter: Data type
    if use_flash_attention and (
        qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor
    ):
        logger.debug(
            "Disabling FlashAttention due to unsupported QKV data type. "
            "Supported: qkv_type = torch.Tensor, qkv_dtype = {torch.bfloat16, torch.float16}. "
            "Found: qkv_type = %s, qkv_dtype = %s.",
            qkv_type,
            qkv_dtype,
        )
        use_flash_attention = False
    if use_fused_attention and (qkv_dtype not in [torch.bfloat16, torch.float16]):
        logger.debug(
            "Disabling FusedAttention due to unsupported QKV data type. "
            "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
            "Found: qkv_dtype = %s.",
            qkv_dtype,
        )
        use_fused_attention = False

    # Filter: Execution type
    if fp8 and fp8_meta["recipe"].fp8_dpa:
        if use_flash_attention:
            logger.debug("Disabling FlashAttention as it does not support FP8")
            use_flash_attention = False
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
            use_unfused_attention = False

    # Filter: Head dimension
    if use_flash_attention and (
        head_dim > 256
        or head_dim % 8 != 0
        or (head_dim > 192 and device_compute_capability not in ((8, 0), (9, 0)))
    ):
        logger.debug(
            "Disabling FlashAttention due to unsupported head_dim. "
            "Supported: head_dim %%8 = 0, head_dim <= 256 (>192 requires sm80/90). "
            "Found: head_dim = %s on sm%s.",
            head_dim,
            ".".join([str(i) for i in device_compute_capability]),
        )
        use_flash_attention = False

    # Filter: QKV layout
    # qkv_format is the alphabetic part of the first "_"-separated group of qkv_layout.
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
            use_unfused_attention = False
        if use_flash_attention and pad_between_seqs:
            logger.debug(
                "Disabling FlashAttention for qkv_format = thd when there is "
                "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
            )
            use_flash_attention = False

    # Filter: Attention mask
    # attn_mask_type               |     supported backends
    # -------------------------------------------------------------------
    # no_mask                      |     All
    # padding                      |     FlashAttention, FusedAttention
    # causal                       |
    #     self-attention           |     All
    #     cross-attention          |     FusedAttention
    # padding_causal               |
    #     self-attention           |     FlashAttention, FusedAttention
    #     cross-attention          |     FusedAttention
    # causal_bottom_right          |     All
    # padding_causal_bottom_right  |     FlashAttention, FusedAttention
    # arbitrary                    |     UnfusedDotProductAttention
    if attn_mask_type == "arbitrary":
        if use_flash_attention:
            logger.debug("Disabling FlashAttention for arbitrary mask")
            use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention for arbitrary mask")
            use_fused_attention = False
    if use_unfused_attention and "padding" in attn_mask_type:
        logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type)
        use_unfused_attention = False
    if (
        use_flash_attention
        and _flash_attn_2_1_plus
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention as it only supports bottom-right-diagonal "
            "causal mask since flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        use_flash_attention = False
    if (
        use_flash_attention
        and not _flash_attn_2_1_plus
        and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention as it only supports top-left-diagonal "
            "causal mask before flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        use_flash_attention = False

    # Filter: Sliding window attention
    #    backend                 |      window_size       | diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | (-1, -1) or (>=0, >=0) | bottom right
    # FusedAttention             | (-1,  0) or (>=0, 0)   | top left
    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
    #                            |                        | converts window_size to an 'arbitrary' mask
    if window_size is None:
        window_size = check_set_window_size(attn_mask_type, window_size)
    else:
        if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention"
                    " for FP8"
                )
                use_fused_attention = False
            elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
                logger.debug(
                    "Disabling FusedAttention as it only supports sliding window attention "
                    "with causal mask, no dropout, and qkv_format = bshd/sbhd"
                )
                use_fused_attention = False
            elif context_parallel:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with context parallelism"
                )
                use_fused_attention = False
            elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
                "no_mask",
                "padding",
                "causal_bottom_right",
                "padding_causal_bottom_right",
            ]:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with attn_mask_type = %s for cross-attention",
                    attn_mask_type,
                )
                use_fused_attention = False
            elif "padding" in attn_mask_type:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with attn_mask_type = %s",
                    attn_mask_type,
                )
                use_fused_attention = False
        if (
            use_flash_attention
            and (window_size[0] != -1 or window_size[1] not in [-1, 0])
            and (not _flash_attn_2_3_plus or context_parallel)
        ):
            logger.debug(
                "Disabling FlashAttention as sliding window attention requires "
                "flash-attn 2.3+ and no context parallelism"
            )
            use_flash_attention = False

    # Filter: Attention bias
    #    backend                 |      bias types              | ALiBi diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | no_bias, alibi/alibi_slopes  | bottom right
    # FusedAttention             | no_bias, post_scale_bias     |
    #                            | alibi/alibi_slopes           | top left,
    #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
    # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
    #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
    if use_flash_attention and (
        core_attention_bias_type not in ["no_bias", "alibi"]
        or core_attention_bias_shape is not None
    ):
        logger.debug("Disabling FlashAttention for pre/post_scale_bias")
        use_flash_attention = False

    # "fu_" values describe the bias as FusedAttention would see it. ALiBi with custom
    # slopes or with max_seqlen_q != max_seqlen_kv is presented as a non-trainable
    # post_scale_bias.
    fu_core_attention_bias_type = core_attention_bias_type
    fu_core_attention_bias_shape = core_attention_bias_shape
    fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
    if (
        use_fused_attention
        and core_attention_bias_type == "alibi"
        and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
    ):
        fu_core_attention_bias_type = "post_scale_bias"
        fu_core_attention_bias_requires_grad = False
        if alibi_slopes_shape is None:
            fu_core_attention_bias_shape = "1hss"
        elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
            fu_core_attention_bias_shape = "1hss"
        elif (
            len(alibi_slopes_shape) == 2
            and alibi_slopes_shape[0] == batch_size
            and alibi_slopes_shape[1] == num_heads
        ):
            fu_core_attention_bias_shape = "bhss"

    if (
        use_fused_attention
        and fu_core_attention_bias_type == "post_scale_bias"
        and fu_core_attention_bias_shape != "1hss"
    ):
        if fu_core_attention_bias_requires_grad:
            # remove this line when cuDNN adds bwd support for
            # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
            logger.debug(
                "Disabling FusedAttention as dBias is only supported in [1, H, S, S] shape"
            )
            use_fused_attention = False
        else:
            # max512 backend will only support [1, h, s, s]
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

    # Filter: cuDNN support
    fused_attention_backend = None
    if use_fused_attention:
        q_type = TE_DType[qkv_dtype]
        kv_type = q_type
        if fp8 and fp8_meta["recipe"].fp8_dpa:
            q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            kv_type = q_type
        fused_attention_backend = tex.get_fused_attn_backend(
            q_type,
            kv_type,
            QKVLayout[qkv_layout],
            AttnBiasType[fu_core_attention_bias_type],
            AttnMaskType[attn_mask_type],
            attention_dropout,
            num_heads,
            num_gqa_groups,
            max_seqlen_q,
            max_seqlen_kv,
            head_dim,
            window_size[0],
            window_size[1],
        )
        if fused_attention_backend == FusedAttnBackend["No_Backend"]:
            logger.debug("Disabling FusedAttention as no backend supports the provided input")
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and context_parallel
            and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
        ):
            logger.debug(
                "Disabling FusedAttention as sub-backend %s does not support "
                "context parallelism",
                int(fused_attention_backend),
            )
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and window_size is not None
            and window_size[0] != -1
            and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
        ):
            logger.debug(
                "Disabling FusedAttention as sub-backend %s does not support "
                "sliding window attention",
                int(fused_attention_backend),
            )
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
            and fu_core_attention_bias_type == "post_scale_bias"
            and fu_core_attention_bias_shape != "1hss"
        ):
            logger.debug(
                "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
                " [1, H, S, S] shape"
            )
            use_fused_attention = False
            fused_attention_backend = None

    # Filter: Determinism
    # backend                      | deterministic
    # ---------------------------------------------
    # FlashAttention               |
    #     flash-attn >=2.0, <2.4.1 | no
    #     flash-attn >=2.4.1       | yes
    # FusedAttention               |
    #     sub-backend 0            | yes
    #     sub-backend 1            | workspace optimization path and sm90+: yes;
    #                              | otherwise: no
    #     sub-backend 2            | no
    # UnfusedDotProductAttention   | yes
    if use_flash_attention and deterministic and not _flash_attn_2_4_1_plus:
        logger.warning(
            "Disabling FlashAttention as version <2.4.1 does not support deterministic "
            "execution. To use FlashAttention with deterministic behavior, "
            "please install flash-attn >= 2.4.1."
        )
        use_flash_attention = False
    if use_fused_attention and deterministic:
        # The two sub-backend checks are mutually exclusive. As in the cuDNN-support
        # filters above, the sub-backend is cleared whenever FusedAttention is disabled so
        # that no stale value is logged, cached or returned.
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
            fused_attention_backend = None
        elif (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
            and (
                device_compute_capability < (9, 0)
                or core_attention_bias_requires_grad
                or cudnn_version < (8, 9, 5)
            )
        ):
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
            fused_attention_backend = None

    # All available backends
    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
    logger.debug(
        "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
        " UnfusedDotProductAttention=%s}",
        bool(available_backends[0]),
        bool(available_backends[1]),
        (
            f" (sub-backend {int(fused_attention_backend)})"
            if fused_attention_backend is not None
            else ""
        ),
        bool(available_backends[2]),
    )

    # Select FusedAttention for performance
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    ):
        if device_compute_capability == (9, 0):
            logger.debug(
                "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                "for performance reasons"
            )
            use_flash_attention = False

    # Selected backend
    # At most one use_* flag survives; `available_backends` above keeps the pre-selection view.
    if use_flash_attention:
        use_fused_attention = False
        use_unfused_attention = False
    elif use_fused_attention:
        use_unfused_attention = False
    selected_backend = "NoBackend"
    if use_flash_attention:
        selected_backend = "FlashAttention"
    elif use_fused_attention:
        selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
    elif use_unfused_attention:
        selected_backend = "UnfusedDotProductAttention"
    logger.debug("Selected backend = %s", selected_backend)

    # Record the selection in the module-level cache (mutated in place, so no `global`).
    _attention_backends["use_flash_attention"] = use_flash_attention
    _attention_backends["use_fused_attention"] = use_fused_attention
    _attention_backends["fused_attention_backend"] = fused_attention_backend
    _attention_backends["use_unfused_attention"] = use_unfused_attention
    _attention_backends["backend_selection_requires_update"] = False

    return (
        use_flash_attention,
        use_fused_attention,
        fused_attention_backend,
        use_unfused_attention,
        available_backends,
    )


699
class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Both offsets start at zero; this class itself never advances them.
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # Maps a layer number to its (key_memory, value_memory) pair.
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If no layer has stored any key/value memory yet.
        """
        if not self.key_value_memory_dict:
            raise ValueError("should not swap when dict is empty")

        # Only existing keys are re-assigned, so the dict size never changes while iterating.
        for layer_number, (key_memory, value_memory) in self.key_value_memory_dict.items():
            # The batch dimension is dim 1 of the cached tensors.
            assert len(batch_indices) == key_memory.shape[1], (
                "batch_indices must have one entry per cached batch element"
            )
            self.key_value_memory_dict[layer_number] = (
                key_memory[:, batch_indices],
                value_memory[:, batch_indices],
            )
743

744

745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
@torch.no_grad()
def get_swa_mask(
    window_size: Tuple[int, int],
    max_seqlen_q: int,
    max_seqlen_kv: int,
    attn_mask_type: str = "no_mask",
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
) -> Tuple[str, torch.Tensor]:
    """
    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
    For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
    and for other mask types, the bottom right corner.

    Parameters
    ----------
    window_size: Tuple[int, int]
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
        `attn_mask_type`.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
        "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
        default = `None`
        Boolean tensor(s) used to mask out attention softmax input (`True` means the
        position is masked out). Only a single tensor can be combined with the sliding
        window mask here; a tuple is passed to the combination step unchanged.

    Returns
    ----------
    attn_mask_type: str
        Always "`arbitrary`", since the returned mask already encodes the window.
    attention_mask: torch.Tensor
        Combined `attention_mask` (input) and sliding window attention mask.
        The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None;
        else, the broadcast of that shape with the input `attention_mask`.
    """
    # Build the window on the same device as the mask it will be combined with;
    # fall back to the default CUDA device when there is no tensor to follow.
    device = attention_mask.device if isinstance(attention_mask, torch.Tensor) else "cuda"
    ones = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device=device)
    if attn_mask_type in ["causal"]:
        # Top-left alignment: diagonals are measured from the main diagonal.
        left = window_size[0] if window_size[0] != -1 else max_seqlen_q
        right = window_size[1] if window_size[1] != -1 else max_seqlen_q
        lower_diag, upper_diag = -left, right
    else:
        # Bottom-right alignment: shift the diagonals by the q/kv length difference.
        left = window_size[0] if window_size[0] != -1 else max_seqlen_kv
        right = window_size[1] if window_size[1] != -1 else max_seqlen_kv
        offset = max_seqlen_kv - max_seqlen_q
        lower_diag, upper_diag = offset - left, offset + right
    inside_window = torch.tril(torch.triu(ones, diagonal=lower_diag), diagonal=upper_diag)
    attn_mask_type = "arbitrary"
    # True marks positions to mask out, i.e. everything outside the window.
    mask = inside_window.logical_not()
    if attention_mask is not None:
        # A position is masked if EITHER mask excludes it. Using a logical AND here
        # would silently drop the user mask wherever the window allows attention.
        mask = torch.logical_or(attention_mask, mask)
    return attn_mask_type, mask


803
804
805
806
807
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
    bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return the ALiBi slopes and bias held in the module-level `_alibi_cache`.

    The slopes are recomputed only when `_alibi_cache["_alibi_slopes_require_update"]`
    is set, and the bias only when `_alibi_cache["_alibi_bias_require_update"]` is set;
    otherwise the cached tensors are returned and the arguments are not consulted.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the ALiBi bias to the bottom right corner of
        the matrix (`True`) or top left (`False`).

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
        then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
        `alibi_slopes` is in [batch_size, num_heads], then the bias is in
        [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].

    Raises
    ------
    ValueError
        If the cached slopes are neither 1-D nor 2-D when the bias is rebuilt.
    """
    # The cache dict is only mutated in place, so no `global` statement is required.
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            # Default slopes: a geometric sequence for the largest power of two <= num_heads ...
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            # ... extended with odd powers of a second ratio for the remaining heads.
            if n < num_heads:
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        slopes = _alibi_cache["_alibi_slopes"]
        assert slopes is not None, "ALiBi slopes can not be None!"
        # Reshape the slopes so they broadcast over the [sq, skv] plane.
        if slopes.dim() == 1:
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        elif slopes.dim() == 2:
            slopes_shape = torch.Size([*slopes.shape[:], 1, 1])
        else:
            raise ValueError(
                "ALiBi slopes must be in shape [num_heads] or [batch_size, num_heads], "
                f"but got a {slopes.dim()}-D tensor."
            )

        # Key positions, expressed relative to the chosen alignment corner.
        if bottom_right_alignment:
            bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(
                1, 1, 1, max_seqlen_kv
            )
        else:
            bias = torch.arange(
                1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda"
            ).view(1, 1, 1, max_seqlen_kv)
        # Subtract query positions to get the signed query-key distance.
        bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(
            1, 1, max_seqlen_q, 1
        )
        # The bias is the negated absolute distance, scaled per head by the slope.
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
883
884
885
886
887
888
889
890
891


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch. `True` entries in `mask` are treated as padding.
    The result lives on the same device as `mask`.
    """
    mask = mask.squeeze(1).squeeze(1)
    # Valid tokens are the un-masked positions of each sample.
    seqlens = mask.logical_not().sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # Prepend the leading zero on the mask's own device so the concatenation
    # cannot hit a device mismatch.
    zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    return torch.cat((zero, cu_seqlens))

899

900
901
902
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an integer tensor (int64, as produced by `nonzero`)
    of shape [batch_size * max_seqlen, 1, 1] containing the indices for the valid tokens.

    `True` entries in `mask` are treated as padding. The indices of valid tokens come
    first; the remaining slots are filled with `batch_size * max_seqlen`, an index one
    past the last real token. Both results live on the same device as `mask`.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape
    valid = mask.logical_not()

    cu_seqlens = valid.sum(dim=1).cumsum(dim=0).to(torch.int32)
    # Keep the leading zero on the mask's device to avoid a device mismatch in `cat`.
    zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # Flat positions of the valid tokens: [num_valid, 1] -> [num_valid, 1, 1].
    indices = valid.reshape(-1).nonzero().unsqueeze(-1)

    # Pad along dim 0 up to the fixed length `bs * seqlen` with the out-of-range index.
    pad_amount = bs * seqlen - indices.shape[0]
    indices = F.pad(
        input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen)
    )

    return cu_seqlens, indices


928
929
930
931
932
933
934
935
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch.

    Sample `i` contributes indices `i * max_seqlen + j` for each of its valid positions
    `j`; the remaining slots are filled with `batch_size * max_seqlen`, an index one
    past the last real token. The result is created on the device of `cu_seqlens`.
    """
    bs = len(cu_seqlens) - 1
    # One host transfer for all lengths instead of reading tensor elements one by one.
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    indices = [i * max_seqlen + j for i, seqlen in enumerate(seqlens) for j in range(seqlen)]
    # Build the tensor directly as int64: going through the default float32 dtype
    # would corrupt indices above 2**24.
    indices = (
        torch.tensor(indices, dtype=torch.int64, device=cu_seqlens.device)
        .unsqueeze(1)
        .unsqueeze(1)
    )

    # Pad along dim 0 up to the fixed length `bs * max_seqlen` with the out-of-range index.
    pad_amount = bs * max_seqlen - indices.shape[0]
    indices = F.pad(
        input=indices,
        pad=(0, 0, 0, 0, 0, pad_amount),
        mode="constant",
        value=float(bs * max_seqlen),
    )

    return indices

950

951
_cu_seqlens_cache = {}
952
953


954
955
956
957
958
959
960
961
962
963
def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length.

    """
964
965
966
967
968
969
970
971
972
973
    global _cu_seqlens_cache
    if (batch_size, max_seqlen) not in _cu_seqlens_cache:
        _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[(batch_size, max_seqlen)]
974
975


976
977
978
979
980
981
982
983
984
@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Packs the given tensor using the `indices`.

    A single all-zero row is appended to `tensor` along dim 0 before gathering, so an
    index equal to the original `tensor.shape[0]` selects zeros instead of being out
    of range (presumably the padding index used by the index builders in this file).
    `indices` is expected to have size 1 in its last two dims; it is repeated to
    match `tensor`'s last two dims before the gather along dim 0.
    """
    zero_row = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
    )
    padded = torch.cat((tensor, zero_row), dim=0)

    gather_indices = indices.repeat(1, padded.shape[1], padded.shape[2])
    packed = torch.gather(padded, 0, gather_indices)
    return packed


@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Applies `pack_tensor` with the same `indices` to each of the two tensors
    and returns the results in argument order.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies `pack_tensor` with the same `indices` to each of the three tensors
    and returns the results in argument order.
    """
    first = pack_tensor(indices, t1)
    second = pack_tensor(indices, t2)
    third = pack_tensor(indices, t3)
    return first, second, third


@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`.

    Scatters the rows of `tensor` along dim 0 into a zero-initialised buffer with
    `dim0 + 1` rows, then drops the last row. The extra row absorbs every entry whose
    index equals `dim0`, so those entries do not reach the returned tensor of
    `dim0` rows.
    """
    scatter_indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    unpacked = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
    )
    unpacked.scatter_(0, scatter_indices, tensor)
    # Discard the sink row that collected the padded entries.
    unpacked = unpacked[0:-1, :, :]
    return unpacked


@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`: applies `unpack_tensor` with the same `indices`
    and `dim0` to each tensor and returns the results in argument order.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`: applies `unpack_tensor` with the same `indices`
    and `dim0` to each tensor and returns the results in argument order.
    """
    first = unpack_tensor(indices, dim0, t1)
    second = unpack_tensor(indices, dim0, t2)
    third = unpack_tensor(indices, dim0, t3)
    return first, second, third


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    The forward pass packs one, two or three tensors with the same `indices`;
    the backward pass unpacks the incoming gradients back to the original
    leading dimension. No gradient is returned for `indices`.
    """

    @staticmethod
    def forward(
        ctx, indices: torch.Tensor, *tensors: torch.Tensor
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        """Pack 1-3 tensors; returns a single tensor for one input, else a tuple."""
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        ctx.save_for_backward(indices)
        # Remember the unpacked leading dimension so backward can restore it.
        ctx.dim0 = tensors[0].shape[0]
        if len(tensors) == 1:
            return pack_tensor(indices, *tensors)
        if len(tensors) == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_3_tensors(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: torch.Tensor):
        """Unpack the gradients; the leading `None` is the gradient for `indices`."""
        (indices,) = ctx.saved_tensors
        if len(grad_outputs) == 1:
            return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
        if len(grad_outputs) == 2:
            return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs)
        return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs)
1100
1101
1102
1103
1104
1105


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    The forward pass unpacks `tensor` to `dim0` rows using `indices`; the backward
    pass packs the incoming gradient with the same `indices`. No gradient is
    returned for `indices` or `dim0`.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Unpack `tensor` and keep `indices` for the backward pass."""
        ctx.save_for_backward(indices)
        return unpack_tensor(indices, dim0, tensor)

    @staticmethod
    def backward(ctx, grad_output):
        """Pack the gradient; the two leading `None`s are for `indices` and `dim0`."""
        (indices,) = ctx.saved_tensors
        return None, None, pack_tensor(indices, grad_output)
1121
1122


1123
1124
1125
def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Sends `send_tensor` to `send_dst` and receives into `recv_tensor` from `recv_src`
    over `cp_group`, without waiting for completion.

    Even ranks issue the send before the receive and odd ranks the receive before
    the send (presumably so that neighbouring ranks post matching operations first).
    With `batch_p2p_comm`, both operations are submitted together through
    `torch.distributed.batch_isend_irecv`; otherwise they are issued one by one.

    Returns the list of request handles; the caller is responsible for waiting on them.
    """
    send_first = rank % 2 == 0

    if batch_p2p_comm:
        send_op = torch.distributed.P2POp(
            torch.distributed.isend, send_tensor, send_dst, cp_group
        )
        recv_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_tensor, recv_src, cp_group
        )
        # Building a P2POp does not communicate; only the list order matters here.
        send_recv_ops = [send_op, recv_op] if send_first else [recv_op, send_op]
        return torch.distributed.batch_isend_irecv(send_recv_ops)

    # Unbatched: each call starts communicating immediately, so the call order
    # itself must follow the rank parity.
    if send_first:
        send_req = torch.distributed.isend(send_tensor, send_dst, cp_group)
        recv_req = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
        return [send_req, recv_req]
    recv_req = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
    send_req = torch.distributed.isend(send_tensor, send_dst, cp_group)
    return [recv_req, send_req]


1165
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Attention with context parallelism

    Rescales `out_per_step` by `exp(softmax_lse_per_step - softmax_lse)` and
    accumulates it into `out` in place; nothing is returned. Dim 2 of the
    log-sum-exp tensors is moved to `seq_dim` and a trailing dim is added so the
    factor broadcasts against `out_per_step` (assumes dim 2 is the sequence dim of
    the softmax stats -- TODO confirm against callers).
    """
    softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
    softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
    out_corrected = out_per_step * softmax_lse_corrected_exp
    out.add_(out_corrected)


1174
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism

    Overwrites `softmax_lse` in place with
    `log(exp(softmax_lse) + exp(softmax_lse_per_step))`; nothing is returned.
    The sum is evaluated as `max + log1p(exp(min - max))`, which keeps the exponent
    non-positive (no overflow) and stays accurate when the two values differ widely.
    """
    max_scale = torch.max(softmax_lse, softmax_lse_per_step)
    min_scale = torch.min(softmax_lse, softmax_lse_per_step)
    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
    softmax_lse.copy_(new_scale)
1181
1182


1183
class AttnFuncWithCP(torch.autograd.Function):
1184
    """
1185
1186
    Attention implementation with context parallelism.
    Split attention compute into multiple steps, and overlap current-step
1187
1188
1189
1190
    compute with next-step communication.
    """

    @staticmethod
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
1201
1202
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
        dropout_p,
        cp_group,
        cp_global_ranks,
        cp_stream,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    ):
1215
1216
1217
1218
1219
1220
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
        send_dst = cp_global_ranks[(rank + 1) % cp_size]
1221
        recv_src = cp_global_ranks[(rank - 1) % cp_size]
1222
1223
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1224
1225
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
1226

1227
1228
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

1229
        if causal:
1230
1231
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
1232
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
1233
1234
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
1235
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
1236
        if attn_bias is not None:
1237
            assert len(attn_bias.shape) == 4, (
1238
1239
1240
1241
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
1242
1243
1244
1245
1246
1247
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
1248
1249
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
1250
1251
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
1252
            )
1253
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
1254
1255
1256
1257
1258
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_3_plus:
            fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None
1259

1260
1261
1262
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
1263
        attn_bias_inputs = [None, None]
1264
1265
1266
1267
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
1268
        attn_biases = [None for _ in range(cp_size)]
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
        p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
        send_recv_reqs = [[], []]

1279
        for i in range(cp_size + 1):
1280
            if i < cp_size:
1281
                with torch.cuda.stream(flash_attn_streams[i % 2]):
1282
                    # wait until KV is received
1283
                    for req in send_recv_reqs[(i + 1) % 2]:
1284
1285
                        req.wait()

1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

                    kv_inputs[i % 2] = p2p_comm_buffers[i]
1299
1300
                    if causal:
                        if i == 0:
1301
                            if use_fused_attention:
1302
1303
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
1304
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
1305
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
1306
1307
1308
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        2, k.shape[0], -1, *k.shape[-2:]
                                    )
1309
1310
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
1311
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
1312
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
1313
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:])
1314
                                elif qkv_format == "thd":
1315
                                    q_inputs[i % 2] = q
1316
1317
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1318
1319
1320
1321
1322
1323
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
1324
                                    ).contiguous()
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                    fused_attn_fwd(
                                        is_training,
                                        max_seqlen_q,
                                        max_seqlen_k,
                                        cu_seqlens_q,
                                        cu_seqlens_k,
                                        q_inputs[i % 2],
                                        kv_inputs[i % 2][0],
                                        kv_inputs[i % 2][1],
                                        TE_DType[q.dtype],
                                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                        attn_scale=softmax_scale,
                                        dropout=dropout_p,
                                        qkv_layout=qkv_layout,
                                        attn_mask_type=attn_mask_type,
                                        attn_bias_type=attn_bias_type,
                                        attn_bias=attn_bias_inputs[i % 2],
1343
1344
                                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
1345
                                    )
1346
                                )
1347
1348
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
1349
1350
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
1351
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1352
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
                                    cu_seqlens_q,
                                    cu_seqlens_k,
                                    max_seqlen_q,
                                    max_seqlen_k,
                                    dropout_p,
                                    softmax_scale,
                                    causal=True,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1376
                                )
1377
                        elif i <= rank:
1378
                            if use_fused_attention:
1379
1380
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
1381
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
1382
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
1383
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
1384
1385
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
1386
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
1387
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
1388
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous()
1389
                                elif qkv_format == "thd":
1390
                                    q_inputs[i % 2] = q
1391
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
1392
1393
1394
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
                                        kv_inputs[i % 2], cu_seqlens_k, 0
                                    )
1395
1396
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                    fused_attn_fwd(
                                        is_training,
                                        max_seqlen_q,
                                        max_seqlen_k // 2,
                                        cu_seqlens_q,
                                        cu_seqlens_k // 2,
                                        q_inputs[i % 2],
                                        kv_inputs[i % 2][0],
                                        kv_inputs[i % 2][1],
                                        TE_DType[q.dtype],
                                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                        attn_scale=softmax_scale,
                                        dropout=dropout_p,
                                        qkv_layout=qkv_layout,
                                        attn_mask_type="padding" if padding else "no_mask",
                                        attn_bias_type=attn_bias_type,
                                        attn_bias=attn_bias_inputs[i % 2],
1416
1417
1418
1419
1420
                                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                                        cu_seqlens_kv_padded=(
                                            None
                                            if cu_seqlens_kv_padded is None
                                            else cu_seqlens_kv_padded // 2
1421
1422
                                        ),
                                    )
1423
                                )
1424
1425
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
1426
1427
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
1428
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1429
1430
                                if qkv_format == "thd":
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
1431
1432
1433
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
                                        kv_inputs[i % 2], cu_seqlens_k, 0
                                    )
1434
1435
                                else:
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
1436
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
1437
                                # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
1438
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
1439
1440
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
                                    cu_seqlens_q,
                                    cu_seqlens_k // 2,
                                    max_seqlen_q,
                                    max_seqlen_k // 2,
                                    dropout_p,
                                    softmax_scale,
                                    causal=False,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1463
1464
1465
                                )
                        else:
                            if use_fused_attention:
1466
1467
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
1468
                                    q_inputs[i % 2] = q[:, 1, ...].contiguous()
1469
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
1470
1471
1472
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        2, k.shape[0], -1, *k.shape[-2:]
                                    )
1473
1474
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
1475
                                    q_inputs[i % 2] = q[1].contiguous()
1476
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
1477
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:])
1478
1479
                                elif qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
1480
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
1481
1482
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1483
1484
1485
1486
1487
1488
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
1489
                                    ).contiguous()
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                    fused_attn_fwd(
                                        is_training,
                                        max_seqlen_q // 2,
                                        max_seqlen_k,
                                        cu_seqlens_q // 2,
                                        cu_seqlens_k,
                                        q_inputs[i % 2],
                                        kv_inputs[i % 2][0],
                                        kv_inputs[i % 2][1],
                                        TE_DType[q.dtype],
                                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                        attn_scale=softmax_scale,
                                        dropout=dropout_p,
                                        qkv_layout=qkv_layout,
                                        attn_mask_type="padding" if padding else "no_mask",
                                        attn_bias_type=attn_bias_type,
                                        attn_bias=attn_bias_inputs[i % 2],
1508
1509
1510
1511
                                        cu_seqlens_q_padded=(
                                            None
                                            if cu_seqlens_q_padded is None
                                            else cu_seqlens_q_padded // 2
1512
                                        ),
1513
                                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
1514
                                    )
1515
                                )
1516
1517
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
1518
                            else:
1519
1520
                                if qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
1521
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
1522
1523
                                else:
                                    # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
1524
                                    q_inputs[i % 2] = (
1525
                                        q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
1526
                                    )
1527
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
1528
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
1529
1530
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
                                    cu_seqlens_q // 2,
                                    cu_seqlens_k,
                                    max_seqlen_q // 2,
                                    max_seqlen_k,
                                    dropout_p,
                                    softmax_scale,
                                    causal=False,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1553
1554
1555
                                )
                    else:
                        if use_fused_attention:
1556
1557
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
1558
1559
1560
1561
1562
1563
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
1564
                                ).contiguous()
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
                            out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_k,
                                    cu_seqlens_q,
                                    cu_seqlens_k,
                                    q,
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
                                    TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
1583
1584
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
1585
                                )
1586
                            )
1587
1588
                            if len(rest) > 0:
                                attn_biases[i] = rest[0]
1589
                        else:
1590
                            # [b, sq, np, hn] -> [b*sq, np, hn]
1591
                            q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1592
                            # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
                            kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                            (
                                _,
                                _,
                                _,
                                _,
                                out_per_step[i],
                                softmax_lse_per_step[i],
                                _,
                                rng_states[i],
                            ) = _flash_attn_forward(
                                q_inputs[i % 2],
                                kv_inputs[i % 2][0],
                                kv_inputs[i % 2][1],
                                cu_seqlens_q,
                                cu_seqlens_k,
                                max_seqlen_q,
                                max_seqlen_k,
                                dropout_p,
                                softmax_scale,
                                causal=False,
                                return_softmax=False,
                                **fa_optional_forward_kwargs,
1616
                            )
1617
1618
1619
1620

            if i > 0:
                # wait until fwd results correction of last step is done
                if i > 1:
1621
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
1622

1623
1624
                if use_fused_attention:
                    # [b, np, sq, 1] -> [b, np, sq]
1625
                    softmax_lse_per_step[i - 1].squeeze_(-1)
1626

1627
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
1628
1629
1630
                    if i == 1:
                        out = torch.empty_like(q).zero_()
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
1631
                        if causal and qkv_format != "thd":
1632
1633
                            # [b, np, sq] -> [b, np, 2, sq//2]
                            softmax_lse_ = softmax_lse.view(
1634
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
1635
                            )
1636
1637
1638
1639
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
1640
                    else:
1641
                        if qkv_format == "thd":
1642
1643
1644
                            tex.thd_second_half_lse_correction(
                                softmax_lse, softmax_lse_per_step[i - 1], cu_seqlens_q, q.size(0)
                            )
1645
                        else:
1646
1647
1648
                            flash_attn_fwd_softmax_lse_correction(
                                softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1]
                            )
1649
1650

                if i < cp_size:
1651
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
1652
1653
1654
1655

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        softmax_lse = softmax_lse.to(torch.float)
1656
1657
        if qkv_format in ["bshd", "sbhd"]:
            seq_dim = qkv_format.index("s")
1658
        for i in range(cp_size):
1659
1660
1661
1662
1663
1664
            if qkv_format == "bshd":
                out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
                out_ = out[:, 1, ...]
            elif qkv_format == "sbhd":
                out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
                out_ = out[1]
1665

1666
            if i <= rank or not causal:
1667
                if qkv_format in ["bshd", "sbhd"]:
1668
1669
1670
1671
1672
1673
1674
                    flash_attn_fwd_out_correction(
                        out.view(*out_per_step[i].shape),
                        out_per_step[i],
                        seq_dim,
                        softmax_lse,
                        softmax_lse_per_step[i],
                    )
1675
                elif qkv_format == "thd":
1676
1677
1678
1679
1680
1681
1682
1683
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
                        cu_seqlens_q,
                        False,
                    )
1684
1685
                else:
                    assert False, f"{qkv_format} is an unsupported qkv_format!"
1686
            else:
1687
                if qkv_format in ["bshd", "sbhd"]:
1688
1689
1690
1691
1692
1693
1694
                    flash_attn_fwd_out_correction(
                        out_,
                        out_per_step[i],
                        seq_dim,
                        softmax_lse_[..., 1, :],
                        softmax_lse_per_step[i],
                    )
1695
                elif qkv_format == "thd":
1696
1697
1698
1699
1700
1701
1702
1703
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
                        cu_seqlens_q,
                        True,
                    )
1704
1705
                else:
                    assert False, f"{qkv_format} is an unsupported qkv_format!"
1706
1707

        kv = p2p_comm_buffers[-1]
1708
        if use_fused_attention:
1709
1710
1711
1712
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
1713
1714
        else:
            out = out.view(-1, *out.shape[-2:])
1715

1716
        ctx.save_for_backward(
1717
1718
1719
1720
1721
1722
            q,
            kv,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
1723
1724
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
1725
1726
            *rng_states,
            *attn_biases,
1727
        )
1728
1729
1730
1731
1732
1733
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
1734
        ctx.qkv_format = qkv_format
1735
        ctx.attn_mask_type = attn_mask_type
1736
1737
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
1738
        ctx.deterministic = deterministic
1739
        ctx.use_fused_attention = use_fused_attention
1740
1741
1742
1743
        return out

    @staticmethod
    def backward(ctx, dout):
1744
        (q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6]
1745
        (cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[6:8]
1746
        cp_size = get_distributed_world_size(ctx.cp_group)
1747
1748
        rng_states = ctx.saved_tensors[8 : 8 + cp_size]
        attn_biases = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
1749

1750
        rank = get_distributed_rank(ctx.cp_group)
1751
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
1752
1753
1754
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1755
1756
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
1757
1758
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

1759
        if attn_biases[0] is not None:
1760
1761
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
1762
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
1763
1764
1765
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
1766
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
1767
1768
1769
1770
            )
        else:
            attn_dbias = None

1771
        if causal:
1772
1773
1774
1775
            if ctx.qkv_format == "thd":
                softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
1776
1777
1778
                softmax_lse_ = softmax_lse.view(
                    *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
                )
1779
1780
1781
1782
1783
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
                if ctx.use_fused_attention:
                    # [b, np, sq//2] -> [b, np, sq//2, 1]
                    softmax_lse_.unsqueeze_(-1)

1784
1785
1786
        if ctx.use_fused_attention:
            # [b, np, sq] -> [b, np, sq, 1]
            softmax_lse.unsqueeze_(-1)
1787
1788
1789
1790
1791
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        # Flash Attn outputs
        dq = torch.empty_like(q)

1792
1793
1794
1795
        p2p_comm_buffers = [
            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
        ]
1796
1797
1798
        p2p_comm_buffers[0][0].copy_(kv)
        send_recv_reqs = []

1799
1800
1801
1802
1803
1804
        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

1805
1806
1807
1808
1809
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

1810
1811
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
1812
1813
1814
            if i == 0:
                send_tensor = send_tensor[0]
                recv_tensor = recv_tensor[0]
1815
            if i == (cp_size - 1):
1816
1817
1818
                send_tensor = send_tensor[1]
                recv_tensor = recv_tensor[1]

1819
1820
1821
            send_recv_reqs = flash_attn_p2p_communicate(
                rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
            )
1822

1823
            kv = p2p_comm_buffers[i % 2][0]
1824
            # In reversed order of fwd
1825
            if causal:
1826
                if i == (cp_size - 1):
1827
                    if ctx.use_fused_attention:
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                            kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
                            # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
1844
1845
                        elif ctx.qkv_format == "thd":
                            q_, kv_, out_, dout_ = q, kv, out, dout
1846
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
1847
                        if attn_dbias is not None:
1848
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1849
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_k,
                            cu_seqlens_q,
                            cu_seqlens_k,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[kv.dtype],
                            aux_ctx_tensors,
1862
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
1863
1864
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
1865
1866
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1867
                            qkv_layout=qkv_layout,
1868
                            attn_mask_type=ctx.attn_mask_type,
1869
                            attn_bias_type=ctx.attn_bias_type,
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, 0]
                        _flash_attn_backward(
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
                            cu_seqlens_q,
                            cu_seqlens_k,
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_k,
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            True,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
1902
                        )
1903
                elif i >= (cp_size - rank - 1):
1904
                    if ctx.use_fused_attention:
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous()
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
                            # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
                            kv_ = kv[:, 0, ...].contiguous()
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
1921
1922
1923
1924
                        elif ctx.qkv_format == "thd":
                            q_, out_, dout_ = q, out, dout
                            # [2, t, np, hn] -> [2, t/2, np, hn]
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
1925
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
1926
                        if attn_dbias is not None:
1927
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1928
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_k // 2,
                            cu_seqlens_q,
                            cu_seqlens_k // 2,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[kv.dtype],
                            aux_ctx_tensors,
1941
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
1942
1943
1944
1945
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=(
                                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
                            ),
1946
1947
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1948
                            qkv_layout=qkv_layout,
1949
                            attn_mask_type="padding" if padding else "no_mask",
1950
                            attn_bias_type=ctx.attn_bias_type,
1951
1952
1953
1954
1955
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
1956
1957
1958
1959
1960
1961
                        if ctx.qkv_format == "thd":
                            # [2, t, np, hn] -> [2, t/2, np, hn]
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
                        else:
                            # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
1962
1963
1964
1965
1966
1967
1968
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
                            cu_seqlens_q,
                            cu_seqlens_k // 2,
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_k // 2,
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            False,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
1987
1988
1989
                        )
                else:
                    if ctx.use_fused_attention:
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous()
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                            kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous()
                            dout_ = dout[:, 1, ...].contiguous()
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            q_ = q[1].contiguous()
                            # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            out_ = out[1].contiguous()
                            dout_ = dout[1].contiguous()
2006
2007
2008
2009
2010
2011
                        elif ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
                            kv_ = kv
2012
                        aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
2013
                        if attn_dbias is not None:
2014
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2015
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
                            ctx.max_seqlen_q // 2,
                            ctx.max_seqlen_k,
                            cu_seqlens_q // 2,
                            cu_seqlens_k,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[kv.dtype],
                            aux_ctx_tensors,
2028
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
2029
2030
2031
2032
                            cu_seqlens_q_padded=(
                                None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
                            ),
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2033
2034
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2035
                            qkv_layout=qkv_layout,
2036
                            attn_mask_type="padding" if padding else "no_mask",
2037
                            attn_bias_type=ctx.attn_bias_type,
2038
2039
                        )
                    else:
2040
2041
2042
2043
2044
2045
                        if ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
2046
2047
2048
2049
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
2050
2051
2052
2053
2054
2055
2056
                        if ctx.qkv_format == "thd":
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                            dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
2057
2058
2059
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse_,
                            dq_,
                            dkv_[0],
                            dkv_[1],
                            cu_seqlens_q // 2,
                            cu_seqlens_k,
                            ctx.max_seqlen_q // 2,
                            ctx.max_seqlen_k,
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            False,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
2078
2079
2080
                        )
            else:
                if ctx.use_fused_attention:
2081
                    aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2082
                    if attn_dbias is not None:
2083
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2084
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_k,
                        cu_seqlens_q,
                        cu_seqlens_k,
                        q,
                        kv[0],
                        kv[1],
                        out,
                        dout,
                        TE_DType[q.dtype],
                        TE_DType[kv.dtype],
                        aux_ctx_tensors,
2097
                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
2098
2099
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2100
2101
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
2102
                        qkv_layout=qkv_layout,
2103
                        attn_mask_type=ctx.attn_mask_type,
2104
                        attn_bias_type=ctx.attn_bias_type,
2105
2106
2107
                    )
                else:
                    # [b, sq, np, hn] -> [b*sq, np, hn]
2108
2109
                    q_ = q.view(-1, *q.shape[-2:])
                    dq_ = torch.empty_like(q_)
2110
                    # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
2111
2112
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
2113
                    # [b, sq, np, hn] -> [b*sq, np, hn]
2114
2115
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
2116
2117
                    if _flash_attn_2_3_plus:
                        fa_optional_backward_kwargs["window_size"] = [-1, -1]
2118
                    _flash_attn_backward(
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
                        dout_,
                        q_,
                        kv_[0],
                        kv_[1],
                        out_,
                        softmax_lse,
                        dq_,
                        dkv_[0],
                        dkv_[1],
                        cu_seqlens_q,
                        cu_seqlens_k,
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_k,
                        ctx.dropout_p,
                        ctx.softmax_scale,
                        False,
2135
                        rng_state=rng_states[cp_size - i - 1],
2136
                        **fa_optional_backward_kwargs,
2137
2138
                    )

2139
            if i >= (cp_size - rank - 1) or not causal:
2140
2141
2142
2143
                # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                dq_ = dq_.view(*dq.shape)
            else:
2144
2145
2146
2147
2148
2149
                if ctx.qkv_format == "bshd":
                    # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                    dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
                elif ctx.qkv_format == "sbhd":
                    # [b*sq//2, np, hn] -> [sq//2, b, np, hn]
                    dq_ = dq_.view(-1, *dq.shape[-3:])
2150

2151
            if causal:
2152
                if i > (cp_size - rank - 1):
2153
                    dq.add_(dq_)
2154
2155
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
2156
2157
                        dq.copy_(dq_)
                    else:
2158
2159
2160
2161
2162
2163
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
2164
2165
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add")
2166
                elif i > 0:
2167
2168
2169
2170
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
2171
2172
                    elif ctx.qkv_format == "thd":
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add")
2173
                else:
2174
2175
2176
2177
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
2178
2179
                    elif ctx.qkv_format == "thd":
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy")
2180
2181
2182
2183
2184
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
2185

2186
            if attn_dbias is not None:
2187
                idx = (rank + i + 1) % cp_size
2188
                if i == (cp_size - 1) or not causal:
2189
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
2190
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2191
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
2192
2193
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
2194
2195
2196
2197
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
2198
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2199
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
2200
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
2201

2202
2203
2204
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
2205

2206
            dkv = p2p_comm_buffers[(i + 1) % 2][1]
2207
2208
            if ctx.use_fused_attention:
                dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
2209
            if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
2210
2211
2212
2213
2214
2215
                if ctx.qkv_format == "bshd":
                    # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                    dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
                elif ctx.qkv_format == "sbhd":
                    # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:])
2216
2217
2218
2219
            else:
                # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
                # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                dkv_ = dkv_.view(*dkv.shape)
2220

2221
            if causal:
2222
                if i == (cp_size - 1):
2223
                    if rank == 0:
2224
2225
2226
2227
2228
2229
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
2230
2231
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy")
2232
2233
                    else:
                        dkv.add_(dkv_)
2234
2235
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
2236
2237
2238
2239
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
2240
2241
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none")
2242
                    else:
2243
2244
2245
2246
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
2247
2248
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none")
2249
2250
2251
2252
2253
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
2254
2255
2256
2257
2258
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

2259
        if causal:
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                dq = dq.view(q.shape[0], -1, *q.shape[-2:])
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                dq = dq.view(-1, *q.shape[-3:])
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                dkv = dkv.view(kv.shape[0], -1, *kv.shape[-3:])

        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)

2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
        return (
            None,
            dq,
            dkv[0],
            dkv[1],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            attn_dbias,
            None,
            None,
        )
2298
2299
2300


def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    cu_seqlens_q_padded,
    cu_seqlens_kv_padded,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
) -> torch.Tensor:
    """Attention implementation with context parallelism.

    Validates the requested configuration and then forwards every argument,
    positionally and in declaration order, to `AttnFuncWithCP.apply`.

    The following combinations are rejected (via `assert`):

    * `qkv_format` other than "bshd", "sbhd" or "thd";
    * `qkv_format == "sbhd"` without `use_fused_attention`;
    * `qkv_format == "thd"` with `use_fused_attention` unless `attn_mask_type`
      is "padding" or "padding_causal";
    * a non-None `attn_bias` unless `use_fused_attention` is set and
      `attn_mask_type` does not contain "padding".

    All other arguments are passed through unchanged; their meaning is defined
    by `AttnFuncWithCP`.

    Returns
    -------
    The result of `AttnFuncWithCP.apply` (annotated as `torch.Tensor`).

    Raises
    ------
    AssertionError
        If any of the configuration checks above fails.
    """
    assert qkv_format in [
        "bshd",
        "sbhd",
        "thd",
    ], f"QKV format of {qkv_format} is not supported with context parallelism!"
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "FlashAttention does not support sbhd format!"
    assert (
        qkv_format != "thd"
        or not use_fused_attention
        or attn_mask_type in ["padding", "padding_causal"]
    ), (
        f"Context parallelism is not supported for {attn_mask_type} mask type and "
        f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!"
    )
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        """Attention bias is only supported with FusedAttention and "causal" """
        """or "no_mask" mask types!"""
    )
    # The argument order below must match the parameter order of
    # AttnFuncWithCP.forward (after ctx), since everything is passed positionally.
    out = AttnFuncWithCP.apply(
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        cp_group,
        cp_global_ranks,
        cp_stream,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    )
    return out


2370
2371
2372
2373
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float
            Percent of rotary dimension to use for rotary position embeddings.
        seq_len_interpolation_factor: int
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int
            pre-trained max_position_embeddings before position interpolation
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        # Inverse frequencies 1 / 10000^(2i / dim), created directly on the
        # current CUDA device.
        inv_freq = 1.0 / (
            10000
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            tensor of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]` holding the
            per-position frequencies, duplicated along the last dimension.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        # Position interpolation is only applied when both settings are provided.
        if (
            self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None
        ):
            if (
                max_seq_len
                > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
            ):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        # outer product: freqs[i, j] = seq[i] * inv_freq[j]
        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        # first part even vector components, second part odd vector components,
        #  2 * dim in dimension size
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))

2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463

class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Apply the fused RoPE kernel to `t` for the given `tensor_format`.

        `freqs` is converted to float32 if it has another dtype. `freqs` and
        `cu_seqlens` are saved for the backward pass.

        Raises
        ------
        ValueError
            If `tensor_format` is not one of "sbhd", "bshd" or "thd".
        """
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            # Present the tensor with dims 0 and 1 swapped to the kernel, then
            # swap them back on the result.
            output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute the gradient with respect to `t`.

        Returns one entry per `forward` input (`t`, `freqs`, `tensor_format`,
        `cu_seqlens`); only `t` receives a gradient.

        Raises
        ------
        ValueError
            If the tensor format stored by `forward` is not supported.
        """
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # forward takes four inputs, so exactly four gradients are returned.
        return grad_input, None, None, None


2496
2497
2498
2499
2500
2501
2502
2503
2504
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


2505
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.

    Returns
    -------
    torch.Tensor
        For the unfused path, a tensor with the same shape as `t` whose first
        `d2` channels of the last dimension are rotated and whose remaining
        channels are passed through. For the fused path, the result of
        `FusedRoPEFunc.apply`.

    Raises
    ------
    AssertionError
        If `tensor_format` is 'thd' without `cu_seqlens` (fused), if
        `tensor_format` is not 'sbhd'/'bshd' (unfused), or if the sequence
        length of `t` exceeds `freqs.shape[0]` (unfused).
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)


cyanguwa's avatar
cyanguwa committed
2568
class _SplitAlongDim(torch.autograd.Function):
    """Split a tensor along one dimension, with a copy-free backward when possible.

    The forward pass is `torch.split` (applied to `_data` and re-wrapped for
    `Float8Tensor` inputs). The backward pass checks whether the incoming
    gradients already sit back to back in a single storage; if so it returns a
    tensor over that storage instead of concatenating.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
    ) -> Tuple[torch.Tensor, ...]:
        """Split `mixed_x_layer` along `split_dim`.

        `split_dim` and `split_size_or_sections` are stored on `ctx` for the
        backward pass. Returns the tuple of chunks.
        """
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        if isinstance(mixed_x_layer, Float8Tensor):
            # Split the underlying data and wrap every chunk like the input.
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Merge the per-chunk gradients back into one gradient tensor.

        Returns a 3-tuple `(grad, None, None)`, one entry per `forward` input.
        """
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        elif isinstance(ctx.split_size_or_sections, int):
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        # Normalize a possibly negative split dimension.
        split_dim = (ctx.split_dim + dims) % dims

        if isinstance(grad_outputs[0], Float8Tensor):
            # Same no-op detection as the plain-tensor path below, except that
            # the storage pointer and the merged view come from `_data`.
            noop_ok = True
            strides = grad_outputs[0].stride()
            data_ptr = grad_outputs[0]._data.untyped_storage().data_ptr()
            shape = list(grad_outputs[0].shape)
            for i, tensor in enumerate(grad_outputs):
                # Copy so that the reference shape is not mutated by the loop.
                shape_i = list(shape)
                shape_i[split_dim] = split_sizes[i]
                # Expected storage offset of chunk i: sizes of all preceding
                # chunks times the number of elements in the trailing dims.
                offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
                if (
                    tensor.stride() != strides
                    or list(tensor.shape) != shape_i
                    or tensor._data.untyped_storage().data_ptr() != data_ptr
                    or tensor.storage_offset() != offset_size
                ):
                    noop_ok = False
                    break
            if noop_ok:
                ret = torch.Tensor().to(
                    device=grad_outputs[0].device, dtype=grad_outputs[0]._data.dtype
                )
                new_shape = list(shape)
                new_shape[split_dim] = sum(split_sizes)
                # Point `ret` at the shared storage with the merged shape.
                ret.set_(
                    grad_outputs[0]._data.untyped_storage(),
                    grad_outputs[0]._data.storage_offset(),
                    new_shape,
                    strides,
                )
                return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None

            # Fallback: concatenate the underlying data.
            grad_outputs_data = [x._data for x in grad_outputs]
            return (
                Float8Tensor.make_like(
                    grad_outputs[0], data=torch.cat(grad_outputs_data, dim=split_dim)
                ),
                None,
                None,
            )

        noop_ok = True
        strides = grad_outputs[0].stride()
        data_ptr = grad_outputs[0].untyped_storage().data_ptr()
        shape = list(grad_outputs[0].shape)
        for i, tensor in enumerate(grad_outputs):
            # Copy so that the reference shape is not mutated by the loop.
            shape_i = list(shape)
            shape_i[split_dim] = split_sizes[i]
            # Expected storage offset of chunk i: sizes of all preceding chunks
            # times the number of elements in the trailing dims.
            offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
            if (
                tensor.stride() != strides
                or list(tensor.shape) != shape_i
                or tensor.untyped_storage().data_ptr() != data_ptr
                or tensor.storage_offset() != offset_size
            ):
                noop_ok = False
                break
        if noop_ok:
            ret = torch.Tensor().to(device=grad_outputs[0].device, dtype=grad_outputs[0].dtype)
            new_shape = list(shape)
            new_shape[split_dim] = sum(split_sizes)
            # Point `ret` at the shared storage with the merged shape.
            ret.set_(
                grad_outputs[0].untyped_storage(),
                grad_outputs[0].storage_offset(),
                new_shape,
                strides,
            )
            return ret, None, None

        # Fallback: the gradients are not laid out contiguously, so copy.
        return torch.cat(grad_outputs, dim=split_dim), None, None
2676
2677
2678
2679
2680
2681
2682
2683
2684


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2

    Plain PyTorch implementation of scaled dot-product attention: the
    attention scores are materialized with batched matmuls, passed through
    `FusedScaleMaskSoftmax`, dropped out and multiplied with the values.

    Parameters
    ----------
    softmax_scale: float
        Multiplier applied to the raw `Q @ K^T` scores.
    attention_dropout: float, default = 0.0
        Dropout probability applied to the attention probabilities.
    attention_dropout_ctx: Optional[Callable], default = `nullcontext`
        Factory for the context manager under which dropout is executed.
    layer_number: Optional[int], default = `None`
        Layer index; only used for QK layer scaling (see
        `NVTE_APPLY_QK_LAYER_SCALING`).
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        # Enabled only when the environment variable is set to a non-zero
        # integer AND a layer number was provided.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            4D tensors in the format implied by `qkv_layout`; `bshd` inputs
            are transposed to `sbhd` and processed as `[s, b, np, hn]`.
        qkv_layout: str, default = `sbh3d`
            Must be a member of `QKVLayouts`; only its leading format
            (letters of the first `_`-separated group) is used here.
        cu_seqlens_q, cu_seqlens_kv: Optional[torch.Tensor]
            Unused by this implementation.
        attn_mask_type: str, default = `causal`
            Forwarded to the softmax; also selects the ALiBi alignment.
        attention_mask: optional tensor or pair of tensors
            Forwarded unchanged to the softmax.
        core_attention_bias_type: str, default = `no_bias`
            One of `no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`.
        core_attention_bias: Optional[torch.Tensor]
            Required for `pre_scale_bias` and `post_scale_bias`; must be
            broadcastable against `[b, np, sq, sk]`.
        alibi_slopes: Optional[torch.Tensor]
            Forwarded to `get_alibi` when the bias type is `alibi`.

        Returns
        -------
        torch.Tensor
            Context tensor of shape `[sq, b, np * hn]` for `sbhd` input and
            `[b, sq, np * hn]` for `bshd` input.
            NOTE(review): for any other format the tensor is returned as
            `[b, np, sq, hn]` because neither final reshape branch runs.
        """
        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        # Keep only the letters of the first group, e.g. "sbh3d" -> "sbhd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]

        # Remembered for the final reshape of the context tensor.
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        # Layer scaling is only applied to FP16 keys.
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # Grouped-query attention: replicate K/V heads so that their head
        # count matches the number of query heads.
        if key_layer.shape[2] != query_layer.shape[2]:
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                int(query_layer.shape[2] / key_layer.shape[2]), dim=2
            )
            value_layer = value_layer.repeat_interleave(
                int(query_layer.shape[2] / value_layer.shape[2]), dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        # NOTE(review): the buffer is allocated on the current CUDA device,
        # not on `query_layer.device` -- confirm these always coincide.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        # Local copy, so the in-place division below never touches the
        # module attribute `self.softmax_scale`.
        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            # beta=0.0: the uninitialized contents of `matmul_result` are ignored.
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            # The bias is added to the unscaled scores, then everything is scaled.
            matmul_result = (
                matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias
            ).view(-1, output_size[2], output_size[3])
            matmul_result *= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                # Any user-provided bias is replaced by the generated ALiBi bias.
                _, core_attention_bias = get_alibi(
                    output_size[1],
                    output_size[2],
                    output_size[3],
                    alibi_slopes=alibi_slopes,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            # The bias is added after scaling; cast back to the query dtype
            # in case the addition promoted the result.
            matmul_result = (
                (
                    matmul_result.view(
                        output_size[0], output_size[1], output_size[2], output_size[3]
                    )
                    + core_attention_bias
                )
                .view(-1, output_size[2], output_size[3])
                .to(dtype=query_layer.dtype)
            )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        # With layer scaling the scores were divided by `layer_number` above;
        # the softmax is handed `layer_number` as its scale (presumably to
        # undo that division -- confirm against FusedScaleMaskSoftmax).
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            attention_scores, attention_mask, attn_mask_type, softmax_scale
        )

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)

        if qkv_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """Autograd wrapper that rearranges an interleaved (s, b, ...) QKV
    tensor into three separate q, k, v tensors in (b, s, ...) layout.

    The heavy lifting is delegated to the `tex.fa_prepare_fwd` and
    `tex.fa_prepare_bwd` extension kernels.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (q, k, v) extracted from the packed buffer behind `query_layer`."""
        # All inputs received are non-contiguous tensors; only `query_layer`
        # is handed to the kernel because it gives access to the full memory
        # region of the QKV tensor.
        packed = tex.fa_prepare_fwd(query_layer)
        q_part, k_part, v_part = split_tensor_along_dim(packed, 0, 3)
        # Drop the leading dimension of each of the three chunks.
        return (
            torch.squeeze(q_part, 0),
            torch.squeeze(k_part, 0),
            torch.squeeze(v_part, 0),
        )

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Pack the three gradients with `tex.fa_prepare_bwd` and return its
        three chunks along the last dimension."""
        packed_grad = tex.fa_prepare_bwd(dq, dk, dv)
        grad_q, grad_k, grad_v = split_tensor_along_dim(packed_grad, -1, 3)
        return grad_q, grad_k, grad_v

2913

2914
def get_qkv_layout(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qkv_format: str = "sbhd",
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    The layout is inferred from how `q`, `k` and `v` sit in memory: whether
    they share the same underlying storage, and whether their strides, shapes
    and storage offsets match one of the supported interleaving patterns. If
    no pattern matches, the three tensors are made contiguous and the
    detection is attempted once more.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of sequences in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    ----------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
       The query tensor the layout refers to. This is the input `q` itself, or a
       contiguous copy of it if the inputs had to be made contiguous.
    k: torch.Tensor
       Key tensor, same convention as `q`.
    v: torch.Tensor
       Value tensor, same convention as `q`.

    Raises
    ------
    AssertionError
        If `q`, `k` or `v` does not have stride 1 in its last dimension.
    ValueError
        If the layout is not supported even after making the tensors contiguous.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def run_iteratively(q, k, v):
        # Same underlying storage: all three tensors, or only k and v.
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        stride = k.stride()
        check_strides_kv = all(stride == x.stride() for x in [k, v])

        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        check_shapes_kv = all(shape == x.shape for x in [k, v])

        # Interleaving right before `d`: the i-th tensor starts i * d
        # elements into the shared storage.
        last_dim_size = q.shape[-1]
        check_last_dim_offsets_qkv = all(
            i * last_dim_size == x.storage_offset() for i, x in enumerate([q, k, v])
        )
        last_dim_size = k.shape[-1]
        check_last_dim_offsets_kv = all(
            i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v])
        )

        # Interleaving right before `hd`: the i-th tensor starts i * h * d
        # elements into the shared storage.
        last_two_dims_size = q.shape[-1] * q.shape[-2]
        check_last_two_dims_offsets_qkv = all(
            i * last_two_dims_size == x.storage_offset() for i, x in enumerate([q, k, v])
        )
        last_two_dims_size = k.shape[-1] * k.shape[-2]
        check_last_two_dims_offsets_kv = all(
            i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v])
        )

        if (
            check_ptrs_qkv
            and check_strides_qkv
            and check_shapes_qkv
            and check_last_two_dims_offsets_qkv
            and not check_last_dim_offsets_qkv
        ):
            # sb3hd, bs3hd, t3hd
            qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
        elif (
            check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv
        ):
            # sbh3d, bsh3d, th3d
            qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
        elif (
            check_ptrs_kv
            and check_strides_kv
            and check_shapes_kv
            and check_last_two_dims_offsets_kv
            and not check_last_dim_offsets_kv
        ):
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_last_dim_offsets_kv:
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
        elif check_strides_kv and check_shapes_kv:
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            qkv_layout = "_".join([qkv_format] * 3)
        else:
            qkv_layout = "not_supported"

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # ValueError is a subclass of Exception, so callers that catch
        # Exception keep working.
        raise ValueError("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v
3031

3032

3033
def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Check if sliding window size is compliant with attention mask type.
    If not, set it to the appropriate size.

         attn_mask_type                              |   window_size
    -------------------------------------------------------------------------
    no_mask, padding, arbitrary                      | (-1, -1) or (>=0, >=0)
    causal, padding_causal                           | (-1,  0) or (>=0, 0)
    causal_bottom_right, padding_causal_bottom_right | (-1,  0) or (>=0, 0)

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type. Any type containing `causal` is treated as causal.
    window_size: Optional[Tuple[int, int]], default = `None`
        Requested (left, right) window. `None` selects the default window for
        the given mask type.

    Returns
    -------
    window_size: Tuple[int, int]
        A window size that is compliant with `attn_mask_type`.

    Warns
    -----
    UserWarning
        Only when an explicitly provided `window_size` had to be changed.
        Nothing is emitted for `None` or for a value that is already compliant.

    Raises
    ------
    AssertionError
        If `window_size` cannot be mapped to a compliant value, or if
        `attn_mask_type` is not recognized. Raised explicitly rather than via
        `assert` so that the validation also runs under `python -O`.
    """
    orig_window_size = window_size
    if "causal" in attn_mask_type:
        message = "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
        if orig_window_size is None or (
            orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
        ):
            window_size = (-1, 0)
        elif orig_window_size[0] >= 0:
            window_size = (orig_window_size[0], 0)
        else:
            raise AssertionError(message)
        # Warn only if the caller's explicit request was actually altered.
        if orig_window_size is not None and tuple(orig_window_size) != window_size:
            warnings.warn(message)
    elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
        message = (
            "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
        )
        if orig_window_size is None or (
            orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
        ):
            window_size = (-1, -1)
            if orig_window_size is not None and tuple(orig_window_size) != window_size:
                warnings.warn(message)
        elif orig_window_size[0] < 0 or orig_window_size[1] < 0:
            raise AssertionError(message)
        # Otherwise (>=0, >=0): already compliant, returned unchanged.
    else:
        raise AssertionError("Invalid attn_mask_type: " + attn_mask_type)
    return window_size
3079

3080

3081
class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention

    Parameters
    ----------
    softmax_scale: float
        Scale forwarded to the flash-attn kernel.
    attention_dropout: float, default = 0.0
        Dropout probability; only applied while the module is in training mode.
    attention_dropout_ctx: Optional[Callable], default = `nullcontext`
        Factory for the context manager under which the kernel is invoked.
    attention_type: str, default = `self`
        `self` makes Q and KV share sequence lengths and indices in the
        padding path; any other value treats them separately.
    layer_number: Optional[int], default = `None`
        Stored as 1 when not provided.
    deterministic: bool, default = `False`
        Forwarded to the kernel when the installed flash-attn supports it.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        # The installed flash-attn version must lie within the supported range.
        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
        assert (
            _flash_attn_version <= _flash_attn_max_version
        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
    ) -> torch.Tensor:
        """flash-attn fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            FP16/BF16 CUDA tensors in the format implied by `qkv_layout`.
        attention_mask: optional tensor or pair of tensors
            Only consulted for padding masks when the cumulative sequence
            lengths are not provided; a pair (q mask, kv mask) is expected
            when `attention_type` is not `self`.
        qkv_layout: str, default = `sbh3d`
            Must be a member of `QKVLayouts`.
        cu_seqlens_q, cu_seqlens_kv: Optional[torch.Tensor]
            Cumulative sequence lengths; mandatory for the `thd` format,
            derived when missing for `sbhd`/`bshd`.
        max_seqlen_q, max_seqlen_kv: Optional[int]
            Overwritten from the tensor shapes for `sbhd`/`bshd`; for `thd`
            they are computed from the cumulative lengths when `None`.
        attn_mask_type: str, default = `causal`
            The kernel runs causally iff the string contains `causal`;
            inputs are packed iff it contains `padding`.
        window_size: Optional[Tuple[int, int]]
            Forwarded to the kernel when supported by the installed
            flash-attn; must be (-1, -1) or (-1, 0) with context parallelism.
        alibi_slopes: Optional[torch.Tensor]
            Forwarded to the kernel when supported; must be `None` with
            context parallelism.
        cp_group, cp_global_ranks, cp_stream:
            Context-parallel configuration. Context parallelism is active
            when `cp_group` is set and its world size is not 1.

        Returns
        -------
        torch.Tensor
            `[s, b, h * d]` for `sbhd`, `[b, s, h * d]` for `bshd` and
            `[t, h * d]` for `thd` input.
        """

        assert (
            query_layer.dtype in [torch.float16, torch.bfloat16]
            and key_layer.dtype in [torch.float16, torch.bfloat16]
            and value_layer.dtype in [torch.float16, torch.bfloat16]
        ), "FlashAttention currently only supports FP16 and BF16."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        # Keep only the letters of the first group, e.g. "sbh3d" -> "sbhd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        # Bring q, k, v into contiguous, batch-first (or thd) form.
        if qkv_format == "sbhd":
            # For now just 128, will make it more general in the future
            if (
                query_layer.shape[-1] == 128
                and query_layer.shape[0] * query_layer.shape[1] >= 512
                and qkv_layout == "sbh3d"
            ):
                # Dedicated kernel path for the interleaved sbh3d layout.
                query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                    query_layer, key_layer, value_layer
                )
            else:
                query_layer, key_layer, value_layer = [
                    x.transpose(0, 1).contiguous() for x in (query_layer, key_layer, value_layer)
                ]
        elif qkv_format in ["bshd", "thd"]:
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]

        # For thd input this is the leading (token) dimension; it is only
        # used further down in the sbhd/bshd branches.
        batch_size = query_layer.shape[0]

        if qkv_format in ["sbhd", "bshd"]:
            # Caller-provided max sequence lengths are replaced by the actual shapes.
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
            if not context_parallel:
                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.view(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    # Self-attention: KV shares Q's sequence lengths and indices.
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer)
            else:
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            # Longest individual sequence = largest difference between
            # consecutive cumulative lengths.
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel:
            # NOTE(review): a `window_size` of None fails this assertion;
            # callers presumably normalize it beforehand -- confirm.
            assert window_size in (
                (-1, -1),
                (-1, 0),
            ), "Sliding window attention is not supported with context parallelism."
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    None,
                    None,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    softmax_scale=self.softmax_scale,
                    # sbhd input was already transposed to batch-first above.
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic,
                )
        else:

            # Function-scope import (presumably to avoid a circular import
            # at module load time -- TODO confirm).
            from .cpu_offload import CPUOffloadEnabled

            if CPUOffloadEnabled:
                # Tag the kernel inputs for activation offloading.
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            with self.attention_dropout_ctx():
                # Only pass keyword arguments that the installed flash-attn
                # version understands.
                fa_optional_forward_kwargs = {}
                if _flash_attn_2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                output = flash_attn_forward_func(
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    softmax_scale=self.softmax_scale,
                    causal="causal" in attn_mask_type,
                    **fa_optional_forward_kwargs,
                )

        # Scatter the packed tokens back to their padded [b * s, ...] positions.
        if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous()
        elif qkv_format == "bshd":
            # (bs)hd -> bs(hd)
            output = output.view(batch_size, max_seqlen_q, -1).contiguous()
        elif qkv_format == "thd":
            # thd -> t(hd)
            output = output.view(output.shape[0], -1).contiguous()

        return output
3307

3308

3309
def _combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a particular dimension

    No data is copied: the result is a new tensor that aliases the storage of
    `tensors[0]`, starting at its storage offset, with an extra dimension of
    size `len(tensors)` inserted at position `dim`. The stride of the new
    dimension is the stride of the dimension before it divided by the number
    of tensors. Only `tensors[0]` is inspected, so this presumably expects
    the inputs to be interleaved slices of one buffer -- verify against callers.

    For a `Float8Tensor` input the view is built on its `_data` tensor and
    re-wrapped with `Float8Tensor.make_like`.
    """

    first = tensors[0]
    count = len(tensors)

    combined_shape = list(first.shape)
    combined_shape.insert(dim, count)
    combined_stride = list(first.stride())
    combined_stride.insert(dim, int(combined_stride[dim - 1] / count))

    is_fp8 = isinstance(first, Float8Tensor)
    # The tensor whose storage and dtype back the combined view.
    backing = first._data if is_fp8 else first

    view = torch.Tensor().to(device=first.device, dtype=backing.dtype)
    view.set_(
        backing.untyped_storage(),
        backing.storage_offset(),
        combined_shape,
        combined_stride,
    )

    if is_fp8:
        return Float8Tensor.make_like(first, data=view)
    return view
3336

3337

3338
3339
3340
3341
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input

    ``forward`` dispatches to ``fused_attn_fwd_qkvpacked`` and ``backward`` to
    either ``flash_attn_cuda_bwd`` (when ``use_FAv2_bwd`` is set) or
    ``fused_attn_bwd_qkvpacked``. Both directions have an FP8 path and a
    high-precision path.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen,
        cu_seqlens,
        cu_seqlens_padded,
        qkv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run fused attention forward on a packed QKV tensor.

        Most arguments are forwarded unchanged to ``fused_attn_fwd_qkvpacked``.
        The ones interpreted here are:

        * ``fp8``: selects the FP8 path. The backend is then forced to
          ``FusedAttnBackend["FP8"]`` and ``qkv_layout`` must consist of a
          single group (no ``"_"`` separator).
        * ``fp8_meta``: provides ``fp8_meta["recipe"]`` and
          ``fp8_meta["scaling_fwd"]``. With ``recipe.fp8_mha`` set, ``qkv``
          must already be a ``Float8Tensor`` and the output is returned as a
          ``Float8Tensor``; otherwise ``qkv`` is cast to FP8 here and the
          output is cast back to ``qkv_dtype``.
        * ``use_FAv2_bwd`` and ``deterministic``: only stored on ``ctx`` for
          ``backward``.

        The environment variable ``NVTE_FP8_DPA_BWD`` (default ``"1"``)
        decides whether ``backward`` runs in FP8. When it is ``0`` (or ``fp8``
        is false) the high-precision ``qkv`` and output are saved instead, and
        the backend recorded for backward is
        ``FusedAttnBackend["F16_arbitrary_seqlen"]``.

        Returns the attention output (``Float8Tensor`` for FP8 MHA, otherwise
        a regular tensor).
        """
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if fp8:
            logger.debug("Running forward in FP8")
            if fp8_meta["recipe"].fp8_mha:
                assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
                # Use the scale-inverse carried by the incoming FP8 tensor
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.split("_"))
            assert qkv_group == 1, (
                "qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found"
                f" {qkv_layout}."
            )
            if fp8_meta["recipe"].fp8_mha:
                qkv_fp8 = qkv._data
            else:
                # Flatten the last three dims to 2D for the cast, then restore the shape
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv_fp8 = cast_to_fp8(
                    qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(qkv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=qkv.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # Backward will not run in FP8: save high-precision copies of qkv and out
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv = cast_from_fp8(
                    qkv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV,
                    fp8_dtype_forward,
                    TE_DType[qkv.dtype],
                ).view(qkv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            # Scales are cloned so backward sees the values used by this forward pass
            fp8_tensors = (
                qkv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            logger.debug("Running forward in %s", qkv.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            fp8_tensors = (None, None, None, None)
            out_save = out_ret

        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # High-precision tensors are only needed when backward does not run in FP8
        qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
        ctx.save_for_backward(
            *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute the gradient w.r.t. ``qkv`` (and ``attn_bias`` if applicable).

        With ``ctx.use_FAv2_bwd`` the gradient is produced by
        ``flash_attn_cuda_bwd``; otherwise by ``fused_attn_bwd_qkvpacked``,
        in FP8 when ``ctx.fp8`` is set and in high precision otherwise.

        Returns a tuple aligned with ``forward``'s arguments in which only the
        entries for ``qkv`` (index 4) and ``attn_bias`` (index 6) can be
        non-``None``. The bias gradient is ``None`` for the ``"no_bias"`` and
        ``"alibi"`` bias types.
        """
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            # Keep the wrapper for its scale/dtype; work on the raw FP8 data below
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            qkv,
            out,
            cu_seqlens,
            cu_seqlens_padded,
            qkv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dqkv = torch.empty_like(qkv)

            def maybe_contiguous(x):
                # Copy only when the innermost dimension is not densely packed
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [
                maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out)
            ]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dqkv[:, 0],
                dqkv[:, 1],
                dqkv[:, 2],
                cu_seqlens,
                cu_seqlens,
                ctx.max_seqlen,
                ctx.max_seqlen,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            dqkv = dqkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dqkv = Float8Tensor(
                            data=dqkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        dqkv_c_fp8 = dqkv_fp8.view(
                            -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                        )
                        dqkv = cast_from_fp8(
                            dqkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dqkv_fp8.shape)
                else:
                    logger.debug("Running backward in %s", qkv.dtype)
                    # uint8 data means d_out was unwrapped from a Float8Tensor above
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(qkv.dtype)
                    dqkv, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # if no_bias or alibi, return dqkv only; else, return (dqkv, dbias)
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            dbias = None
        else:
            dbias = rest[0]

        # Index 4 is qkv, index 6 is attn_bias; every other forward argument gets None.
        # The tuple keeps its historical length of 22 (two surplus trailing None entries).
        return (None,) * 4 + (dqkv, None, dbias) + (None,) * 15
3709

3710

3711
3712
3713
3714
class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input

    ``forward`` dispatches to ``fused_attn_fwd_kvpacked`` and ``backward`` to
    either ``flash_attn_cuda_bwd`` (when ``use_FAv2_bwd`` is set) or
    ``fused_attn_bwd_kvpacked``. Both directions have an FP8 path and a
    high-precision path.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        q,
        kv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run fused attention forward on a Q tensor and a packed KV tensor.

        Most arguments are forwarded unchanged to ``fused_attn_fwd_kvpacked``.
        The ones interpreted here are:

        * ``fp8``: selects the FP8 path; the backend is then forced to
          ``FusedAttnBackend["FP8"]``.
        * ``fp8_meta``: provides ``fp8_meta["recipe"]`` and
          ``fp8_meta["scaling_fwd"]``. With ``recipe.fp8_mha`` set, ``q`` and
          ``kv`` must already be ``Float8Tensor`` and the output is returned
          as a ``Float8Tensor``; otherwise ``qkv_layout`` must consist of two
          ``"_"``-separated groups, ``q``/``kv`` are cast to FP8 here and the
          output is cast back to ``qkv_dtype``.
        * ``use_FAv2_bwd`` and ``deterministic``: only stored on ``ctx`` for
          ``backward``.

        The environment variable ``NVTE_FP8_DPA_BWD`` (default ``"1"``)
        decides whether ``backward`` runs in FP8. When it is ``0`` (or ``fp8``
        is false) the high-precision ``q``, ``kv`` and output are saved
        instead, and the backend recorded for backward is
        ``FusedAttnBackend["F16_arbitrary_seqlen"]``.

        Returns the attention output (``Float8Tensor`` for FP8 MHA, otherwise
        a regular tensor).
        """
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        if fp8:
            logger.debug("Running forward in FP8")
            if fp8_meta["recipe"].fp8_mha:
                assert isinstance(q, Float8Tensor) and isinstance(
                    kv, Float8Tensor
                ), "q/kv must be Float8Tensors for FP8 MHA."
                # Use the scale-inverse carried by the incoming FP8 tensor
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
                q_fp8, kv_fp8 = q._data, kv._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                assert qkv_group == 2, (
                    "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
                    f"but found {qkv_layout}."
                )
                q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view(
                    q.shape
                )
                # Flatten the last three dims to 2D for the cast, then restore the shape
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv_fp8 = cast_to_fp8(
                    kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(kv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                kv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # Backward will not run in FP8: save high-precision copies of q, kv and out
                q = cast_from_fp8(
                    q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype]
                ).view(q.shape)
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv = cast_from_fp8(
                    kv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV,
                    fp8_dtype_forward,
                    TE_DType[kv.dtype],
                ).view(kv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            # Scales are cloned so backward sees the values used by this forward pass
            fp8_tensors = (
                q_fp8,
                kv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            logger.debug("Running forward in %s", q.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                kv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None)

        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # High-precision tensors are only needed when backward does not run in FP8
        qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute the gradients w.r.t. ``q``, ``kv`` (and ``attn_bias`` if applicable).

        With ``ctx.use_FAv2_bwd`` the gradients are produced by
        ``flash_attn_cuda_bwd``; otherwise by ``fused_attn_bwd_kvpacked``,
        in FP8 when ``ctx.fp8`` is set and in high precision otherwise.

        Returns a tuple aligned with ``forward``'s arguments in which only the
        entries for ``q`` (index 7), ``kv`` (index 8) and ``attn_bias``
        (index 10) can be non-``None``. The bias gradient is ``None`` for the
        ``"no_bias"`` and ``"alibi"`` bias types.
        """
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            # Keep the wrapper for its scale/dtype; work on the raw FP8 data below
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            q,
            kv,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            kv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)

            def maybe_contiguous(x):
                # Copy only when the innermost dimension is not densely packed
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dkv[:, 0],
                dkv[:, 1],
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            dq = dq[..., : d_out.shape[-1]]
            dkv = dkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        kv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                        dkv = Float8Tensor(
                            data=dkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        dq = cast_from_fp8(
                            dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dq_fp8.shape)
                        dkv_c_fp8 = dkv_fp8.view(
                            -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                        )
                        dkv = cast_from_fp8(
                            dkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dkv_fp8.shape)
                else:
                    logger.debug("Running backward in %s", q.dtype)
                    # uint8 data means d_out was unwrapped from a Float8Tensor above
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dkv, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        kv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # if no_bias or alibi, return (dq, dkv) only; else, return (dq, dkv, dbias)
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            dbias = None
        else:
            dbias = rest[0]

        # Index 7 is q, index 8 is kv, index 10 is attn_bias; all other forward
        # arguments get None. The tuple keeps its historical length of 26 (two
        # surplus trailing None entries).
        return (None,) * 7 + (dq, dkv, None, dbias) + (None,) * 15

4146

4147
4148
4149
4150
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors

    `forward` dispatches to `fused_attn_fwd` (optionally on FP8 data) and saves the
    tensors and settings needed by `backward`; `backward` dispatches either to
    `flash_attn_cuda_bwd` (when `use_FAv2_bwd` is set) or to `fused_attn_bwd`.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        q,
        k,
        v,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run the fused attention forward pass and save state for backward.

        Parameters
        ----------
        ctx
            Autograd context; receives the saved tensors and the non-tensor settings.
        is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
        cu_seqlens_q_padded, cu_seqlens_kv_padded, attn_bias, attn_scale, dropout_p,
        fast_zero_fill, attn_bias_type, attn_mask_type, window_size, rng_gen
            Forwarded unchanged to `fused_attn_fwd`.
        q, k, v
            Query, key and value tensors. When `fp8` is set and
            `fp8_meta["recipe"].fp8_mha` is true they must be `Float8Tensor`s.
        qkv_dtype
            Dtype handed to `fused_attn_fwd` in the non-FP8 path and used as the target
            dtype when the FP8 output is cast back.
        qkv_layout
            Layout string; the number of "_"-separated groups selects the packed-QKV (1),
            packed-KV (2) or separate (3) handling. It is overwritten with
            "sbhd_sbhd_sbhd" before being stored on `ctx` when CPU offloading is enabled.
        fused_attention_backend
            Backend passed to `fused_attn_fwd`; replaced by `FusedAttnBackend["FP8"]` when
            `fp8` is set.
        use_FAv2_bwd
            Stored on `ctx`; when truthy, `backward` calls `flash_attn_cuda_bwd`.
        fp8
            Whether to run the forward pass on FP8 data.
        fp8_meta
            FP8 metadata dictionary; the "recipe" and "scaling_fwd" entries are read here
            and the whole dictionary is stored on `ctx`.
        deterministic
            Stored on `ctx` and forwarded to `fused_attn_bwd` in `backward`.

        Returns
        -------
        The attention output: a `Float8Tensor` when `fp8` is set and the recipe enables
        `fp8_mha`, otherwise the tensor produced by `fused_attn_fwd` (cast back from FP8
        to `qkv_dtype` in the FP8 path).
        """
        logger = logging.getLogger("FusedAttnFunc")
        if fp8:
            logger.debug("Running forward in FP8")
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
                assert (
                    isinstance(q, Float8Tensor)
                    and isinstance(k, Float8Tensor)
                    and isinstance(v, Float8Tensor)
                ), "q/k/v must be Float8Tensors for FP8 MHA."
                # Inputs are already FP8: reuse their raw data and adopt q's scale_inv
                # as the QKV scale_inv.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                if qkv_group == 1:
                    # Cast the combined QKV tensor in one call, then split it again.
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv_fp8 = cast_to_fp8(
                        qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(qkv.shape)
                    q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1])
                    q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
                elif qkv_group == 2:
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    dim = qkv_layout.split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_fp8 = cast_to_fp8(
                        kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(kv.shape)
                    k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1])
                    k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
                elif qkv_group == 3:
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    k_fp8 = cast_to_fp8(
                        k, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(k.shape)
                    v_fp8 = cast_to_fp8(
                        v, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(v.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if fp8_meta["recipe"].fp8_mha:
                # Keep the output in FP8 and wrap it so callers receive a Float8Tensor.
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret

            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # FP8 inputs but a non-FP8 backward (NVTE_FP8_DPA_BWD=0): cast q/k/v and
                # the output back from FP8 so that non-FP8 tensors are saved for backward.
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                if qkv_group == 1:
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv_no_fp8 = cast_from_fp8(
                        qkv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[qkv.dtype],
                    ).view(qkv.shape)
                    q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1])
                    q, k, v = [x.squeeze(dim) for x in [q, k, v]]
                elif qkv_group == 2:
                    q = cast_from_fp8(
                        q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[q.dtype],
                    ).view(q.shape)
                    dim = qkv_layout.split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_no_fp8 = cast_from_fp8(
                        kv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[kv.dtype],
                    ).view(kv.shape)
                    k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1])
                    k, v = [x.squeeze(dim) for x in [k, v]]
                elif qkv_group == 3:
                    q = cast_from_fp8(
                        q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[q.dtype],
                    ).view(q.shape)
                    k = cast_from_fp8(
                        k._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[k.dtype],
                    ).view(k.shape)
                    v = cast_from_fp8(
                        v._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[v.dtype],
                    ).view(v.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)

            # Snapshot the forward scales (clones) so backward reads the values that were
            # current at this point rather than whatever fp8_meta holds later.
            fp8_tensors = (
                q_fp8,
                k_fp8,
                v_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            logger.debug("Running forward in %s", q.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                # The six FP8 scaling/amax arguments are unused in the non-FP8 path.
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None, None)

        from .cpu_offload import CPUOffloadEnabled

        if CPUOffloadEnabled:
            tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
            # NOTE: this overrides the layout that is recorded on ctx below.
            qkv_layout = "sbhd_sbhd_sbhd"
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

        # Backward runs in FP8 only if forward did and NVTE_FP8_DPA_BWD is non-zero.
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # In the FP8-backward case only the FP8 copies are needed, so q/k/v/out are not saved.
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
        # The order here must match the unpacking of ctx.saved_tensors in backward.
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        # A non-FP8 backward always uses the F16_arbitrary_seqlen backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients w.r.t. q, k, v and, for bias types other than
        "no_bias"/"alibi", the attention bias.

        Parameters
        ----------
        ctx
            Autograd context populated by `forward`.
        d_out
            Gradient of the attention output. Must be a `Float8Tensor` when the recipe
            enables `fp8_mha`.

        Returns
        -------
        A 27-entry tuple that is None everywhere except the slots of q, k, v and
        (when a bias gradient is produced) attn_bias.
        """
        logger = logging.getLogger("FusedAttnFunc")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            # Keep the wrapper for its scale_inv/dtype and continue with the raw data.
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        # Star-unpacking yields a list, so the first auxiliary tensor can be replaced
        # in place by a contiguous copy.
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            # Only copy tensors whose last dimension is not unit-stride.
            d_out, q, k, v, out = [
                x.contiguous() if x.stride(-1) != 1 else x for x in (d_out, q, k, v, out)
            ]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            dq = dq[..., : d_out.shape[-1]]
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        # d_out already holds FP8 data; adopt its scale_inv for dO.
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

                    if ctx.fp8_meta["recipe"].fp8_mha:
                        # Return the gradients as Float8Tensors without casting back.
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                        dk = Float8Tensor(
                            data=dk_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                        dv = Float8Tensor(
                            data=dv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        # 1: qkv packed, 2: kv packed, 3: qkv separate
                        qkv_group = len(ctx.qkv_layout.split("_"))
                        if qkv_group == 1:
                            dim = ctx.qkv_layout.find("3")
                            dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim)
                            dqkv_c_fp8 = dqkv_fp8.view(
                                -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                            )
                            dqkv = cast_from_fp8(
                                dqkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dqkv_fp8.shape)
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1])
                            dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
                        elif qkv_group == 2:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dq_fp8.shape)
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(
                                -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                            )
                            dkv = cast_from_fp8(
                                dkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dkv_fp8.shape)
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1])
                            dk, dv = [x.squeeze(dim) for x in [dk, dv]]
                        elif qkv_group == 3:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dq_fp8.shape)
                            dk = cast_from_fp8(
                                dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dk_fp8.shape)
                            dv = cast_from_fp8(
                                dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dv_fp8.shape)
                else:
                    logger.debug("Running backward in %s", q.dtype)
                    # uint8 here means d_out is raw FP8 data taken from a Float8Tensor
                    # above; convert it to q's dtype for the non-FP8 backward.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        # The ten FP8 scaling/amax arguments are unused in the non-FP8 path.
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # if no_bias or alibi, return dqkv only; else, return (dqkv, dbias).
        # `rest` is evaluated only when a bias gradient is required.
        # NOTE(review): `rest` is not assigned in the use_FAv2_bwd path, so that path
        # raises NameError for other bias types, as before - confirm callers never
        # combine the two.
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]

        # Layout: 7 leading non-differentiable inputs, then q, k, v, qkv_dtype,
        # attn_bias, then 15 trailing None entries (27 in total, as before).
        # NOTE(review): forward takes 25 inputs; the two surplus trailing None entries
        # are kept from the original - confirm PyTorch's handling of extra None grads.
        return (None,) * 7 + (dq, dk, dv, None, dbias) + (None,) * 15
4723

4724

4725
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Store the attention configuration and register the checkpoint-loading hook.

        Parameters
        ----------
        softmax_scale : float
            scale forwarded to the fused attention kernels in `forward`.
        attention_dropout : float, default = 0.0
            dropout probability; `forward` only applies it while `self.training` is set.
        attention_dropout_ctx : Optional[Callable], default = `nullcontext`
            callable returning a context manager that wraps the attention call.
        attention_type : str, default = `self`
            "`self`" makes `forward` read `attention_mask` as one tensor; any other
            value makes it read `attention_mask` as a (q, kv) pair.
        layer_number : Optional[int], default = `None`
            stored as 1 when `None`.
        deterministic : bool, default = `False`
            forwarded to `FusedAttnFunc` in `forward`.
        """
        super().__init__()

        self.logger = logging.getLogger("FusedAttention")
        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # The flash-attn backward is opt-in through the environment and is only
        # considered when the device compute capability is exactly (9, 0).
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading TransformerEngine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in TransformerEngine 2.0.
            """
            # Iterate over snapshots: calling remove() on a list while looping
            # over that same list skips the entry right after each removal, so
            # matching keys could be left behind. The original lists are still
            # edited in place so the caller observes the filtered result.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """fused attention fprop

        Validates the inputs, fills in any missing sequence-length metadata and
        dispatches to `attn_forward_func_with_cp` (when `cp_group` spans more
        than one rank) or to `FusedAttnFunc` otherwise.

        Notes on argument handling, as implemented below:

        * For `sbhd`/`bshd` formats, `max_seqlen_q` and `max_seqlen_kv` are always
          re-derived from the tensor shapes, overriding any values passed in.
        * For padding masks, `cu_seqlens_q`/`cu_seqlens_kv` are computed from
          `attention_mask` unless both are supplied; for other masks, each one
          that is missing is generated for full-length sequences.
        * For the `thd` format, `max_seqlen_q/kv` and `cu_seqlens_q/kv` are required.
        * If either `cu_seqlens_q_padded` or `cu_seqlens_kv_padded` is `None`,
          both are replaced by `cu_seqlens_q` and `cu_seqlens_kv`.
        * `window_size`, `fast_zero_fill`, `fp8` and `fp8_meta` are only forwarded
          on the non-context-parallel path.

        Returns
        -------
        torch.Tensor
            the attention output with its last two dimensions flattened into one.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        # NOTE(review): torch.uint8 is accepted in addition to FP16/BF16 although
        # the message does not mention it -- presumably raw FP8 storage; confirm.
        assert (
            (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
        ), "FusedAttention only supports FP16 and BF16 data types."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        # Context parallelism is active only when the group has more than one rank.
        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        # Keep only the letters of the first "_"-separated component of the
        # layout, e.g. "sb3hd" -> "sbhd" and "bshd_bs2hd" -> "bshd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format in ["sbhd", "bshd"]:
            if qkv_format == "sbhd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1],
                    query_layer.shape[0],
                    key_layer.shape[0],
                )
            if qkv_format == "bshd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0],
                    query_layer.shape[1],
                    key_layer.shape[1],
                )
            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        # Self-attention: a single mask describes both q and kv.
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        # Otherwise the mask is indexed as a (q_mask, kv_mask) pair.
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        if qkv_format == "thd":
            assert (
                max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
            ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"

        # Fall back to the unpadded offsets; note that both are replaced even
        # when only one of the padded tensors is missing.
        if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
            cu_seqlens_q_padded = cu_seqlens_q
            cu_seqlens_kv_padded = cu_seqlens_kv

        qkv_dtype = TE_DType[query_layer.dtype]

        # The flash-attn backward is only used without bias and with the
        # arbitrary-seqlen backend.
        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if context_parallel:
            assert (
                fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    use_fused_attention=True,
                )
        else:
            with self.attention_dropout_ctx():
                if fp8:
                    assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                        f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                        " is required for FP8 attention!"
                    )
                    assert (
                        fp8_meta is not None
                    ), "FP8 metadata fp8_meta is required for FP8 attention!"
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_dtype,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                    self.deterministic,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
4975
4976


4977
class DotProductAttention(TransformerEngineBaseModule):
4978
4979
4980
4981
4982
4983
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

4984
        Argument :attr:`attention_mask` in the `forward` call is only used when
4985
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
4986
4987
4988

    .. warning::

4989
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
4990
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
4991
4992
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
4993
4994
4995
4996
4997
4998

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels : int
4999
                number of key-query-value channels per attention head.
5000
5001
5002
5003
5004
5005
5006
5007
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
5008
5009
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
5010
    attn_mask_type: str, default = `causal`
5011
                   type of attention mask passed into softmax operation, options are "`no_mask`",
5012
5013
5014
5015
5016
5017
5018
5019
5020
5021
5022
5023
5024
5025
5026
5027
5028
5029
5030
5031
5032
5033
5034
5035
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
                   "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
                   "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
                   are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
                   `forward` method. It is useful for cases involving compilation/tracing, e.g.
                   ONNX export, and the forward arg is useful for dynamically changing mask types,
                   e.g. a different mask for training and inference.
                   1. For "`no_mask`", no attention mask is applied.
                   2. For "`causal`", "`causal_bottom_right`", or the causal mask in
                   "`padding_causal`" and "`padding_causal_bottom_right`", TransformerEngine
                   calculates and applies an upper triangular mask to the softmax input.
                   No user input is needed. Causal masks without the "`bottom_right`" appendix align
                   the diagonal line to the top left corner of the softmax matrix. With
                   "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
                   often used in inference/KV caching.
                   3. For "`padding`", or the padding mask in "`padding_causal`" and
                   "`padding_causal_bottom_right`", users need to provide the locations of padded
                   tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
                   [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
                   in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
                   cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
                   [batch_size, 1, 1, max_seqlen_kv]).
                   4. For "`arbitrary`", users need to provide a mask that is broadcastable to
                   the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
5036
5037
5038
5039
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
5040
5041
5042
                window and causal mask respectively. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
5043
                be overridden by :attr:`window_size` in `forward` as well.
5044
5045
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
5046
5047
5048
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
5049
5050
5051
5052
5053
5054
5055
5056
5057
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
               `h` the number of heads, `d` head size, and `t` the total number of tokens
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
5058
               For that, please use `get_qkv_layout` to gain the layout information.
5059
5060
5061
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
                `1.0 / math.sqrt(kv_channels)`.
5062
5063
5064
5065
5066
5067
5068
5069
5070

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
5071
5072
5073
5074
5075
5076
5077
5078
5079
    cp_group : ProcessGroup, default = `None`
              context parallel process group.
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
5080
5081
5082
5083
5084
5085
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        softmax_scale: Optional[float] = None,
    ) -> None:
        """Configure the module and build its flash, fused and unfused backends.

        The arguments are described in the class docstring. Besides storing them,
        this constructor normalizes `attn_mask_type`, resolves the tensor-parallel
        size, derives the determinism setting from the environment and may set
        cuDNN-related environment variables for the current process.
        """
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.qkv_format = qkv_format
        # Normalize the mask spelling: "padding,causal" and "causal,padding"
        # both end up as "padding_causal".
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        if tp_group is None:
            # Without a group, trust the caller-provided tp_size; the (None)
            # group is only registered here when tp_size == 1.
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

        self.hidden_size_per_attention_head = kv_channels

        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"

        # Dropout runs under the tracker's fork() context only when a tracker
        # factory is given and sequence parallelism is off.
        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(kv_channels)

        # Deterministic when NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 or when PyTorch's
        # global deterministic-algorithms switch is enabled.
        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )
        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = get_cudnn_version()
        if (8, 9, 5) <= cudnn_version < (9, 0, 0):
            if self.deterministic:
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        self.flash_attention = FlashAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        self.fused_attention = FusedAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale, **attn_kwargs, layer_number=layer_number
        )

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
            when loading older TransformerEngine checkpoints. Will phase out
            this hook in TransformerEngine 2.0.
            """
            # Iterate over a snapshot: calling remove() on a list while looping
            # over that same list skips the entry right after each removal, so
            # matching keys could be left behind. The original list is still
            # edited in place so the caller observes the filtered result.
            for key in list(incompatible_keys.missing_keys):
                if "core_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)

        self.register_load_state_dict_post_hook(remove_extra_states_check)

5213
5214
5215
5216
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Run `attention_func` through `checkpoint` and return its result.

        `forward_args` and `forward_kwargs` are handed to `attention_func`
        unchanged; the module's RNG-state-tracker getter and tensor-parallel
        group are supplied to `checkpoint` alongside them.
        """

        def custom_forward(*args, **kwargs):
            # Thin wrapper so `checkpoint` receives a plain function.
            return attention_func(*args, **kwargs)

        return checkpoint(
            custom_forward,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

5235
5236
5237
5238
5239
5240
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Record the context parallel settings on this module so that
        they are in place before the forward pass runs.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        self.cp_group, self.cp_global_ranks, self.cp_stream = (
            cp_group,
            cp_global_ranks,
            cp_stream,
        )

5258
    @no_torch_dynamo(recursive=False)
5259
5260
5261
5262
5263
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
5264
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
5265
5266
5267
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
5268
5269
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
5270
5271
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
5272
        attn_mask_type: Optional[str] = None,
5273
        window_size: Optional[Tuple[int, int]] = None,
5274
        checkpoint_core_attention: bool = False,
5275
5276
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
5277
        alibi_slopes: Optional[torch.Tensor] = None,
5278
        fast_zero_fill: bool = True,
5279
        inference_params: Optional[InferenceParams] = None,
5280
        is_first_microbatch: Optional[bool] = None,
5281
5282
5283
5284
5285
5286
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

5287
5288
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
5289
5290
5291

        .. note::

5292
5293
5294
            Input tensor :attr:`query_layer` must be of shape
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`,
            :attr:`kv_channels`) and the tensors :attr:`key_layer` and :attr:`value_layer`
5295
            must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
5296
            :attr:`num_gqa_groups`, :attr:`kv_channels`). Output of shape
5297
5298
5299
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
            * :attr:`kv_channels`) is returned.

5300
5301
        .. note::

5302
5303
5304
5305
5306
5307
5308
5309
5310
5311
5312
5313
5314
5315
5316
5317
5318
5319
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://arxiv.org/abs/2205.14135>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
            and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
5320
5321
5322
5323
5324
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
            optimizations in FusedAttention. When unset, TransformerEngine determines the code path
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
5325

5326
5327
5328
5329
5330
5331
5332
5333
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
5334
5335
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
5336
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
5337
5338
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
5339
5340
5341
5342
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
5343
5344
5345
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
5346
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
5347
5348
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
5349
5350
5351
5352
5353
5354
5355
5356
5357
5358
5359
5360
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for
                   `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_q_padded = cu_seqlens_q`.
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_kv_padded = cu_seqlens_kv`.
5361
5362
5363
5364
5365
5366
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
5367
5368
5369
5370
5371
5372
5373
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
                       'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
                       'arbitrary'}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
                       are equivalent. By default, causal masks are aligned to the top left corner
                       of the softmax matrix. When "`bottom_right`" is specified in the mask type,
                       causal masks are aligned to the bottom right corner.
5374
        window_size: Optional[Tuple[int, int]], default = `None`
5375
                    Sliding window size for local attention.
5376
5377
5378
5379
5380
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
5381
        core_attention_bias_type: str, default = `no_bias`
5382
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
5383
        core_attention_bias: Optional[torch.Tensor], default = `None`
5384
5385
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be `None` for `no_bias` and `alibi` bias types.
5386
5387
5388
5389
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
5390
        fast_zero_fill: bool, default = `True`
5391
                    Whether to use the fast path to set output tensors to 0 or not.
5392
5393
5394
5395
5396
5397
5398
5399
5400
5401
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
5402
5403
5404
5405
5406
5407
5408
5409
5410
5411
5412
5413
5414
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
5415
        """
5416
5417
5418
5419
5420
5421
5422
5423
5424
5425
5426
        with self.prepare_forward(
            query_layer,
            is_first_microbatch,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:

            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
5427
5428
5429
5430
                        self.logger.WARNING(
                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                            """fp8_meta["recipe"].fp8_mha=True"""
                        )
5431
5432
5433
5434
5435
5436
5437
5438
5439
5440
5441

            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
5442

5443
5444
5445
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
5446
5447
5448
            assert (
                query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
5449
            assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
5450

5451
5452
5453
5454
5455
5456
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
5457
            assert (
5458
5459
5460
5461
5462
5463
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
5464

5465
5466
5467
5468
            if window_size is None:
                window_size = self.window_size
            window_size = check_set_window_size(attn_mask_type, window_size)

5469
5470
5471
5472
5473
5474
5475
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
5476

5477
5478
            if qkv_format is None:
                qkv_format = self.qkv_format
5479

5480
5481
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
5482

5483
5484
5485
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
5486

5487
5488
5489
5490
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]
5491

5492
5493
5494
                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)
5495

5496
5497
5498
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)
5499

5500
5501
5502
5503
5504
5505
5506
5507
5508
                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
5509

5510
5511
5512
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
5513

5514
5515
                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()
5516
5517

            assert (
5518
5519
5520
5521
5522
5523
5524
5525
5526
5527
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
5528
                assert all(
5529
5530
5531
5532
5533
5534
5535
5536
5537
5538
5539
5540
5541
5542
5543
5544
5545
5546
5547
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
                if max_seqlen_q is None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
                if max_seqlen_kv is None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
5548
                batch_size = len(cu_seqlens_q) - 1
5549
5550

            if qkv_format in ["sbhd", "bshd"]:
5551
                assert all(
5552
5553
5554
5555
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
5556
                    batch_size = query_layer.shape[1]
5557
5558
                if qkv_format == "bshd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
5559
                    batch_size = query_layer.shape[0]
5560
5561
5562
5563
5564
5565
5566
5567
5568
5569
5570
5571
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                        the sequence dimention in 'query_layer'!"""
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                        the sequence dimention in 'key_layer' and 'value_layer'!"""
5572
5573
5574
5575
5576
5577
5578
5579
5580
5581
5582
5583
5584
5585
5586
5587
5588
5589
5590
5591
5592
5593
                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        if max_seqlen_q == max_seqlen_kv:
                            cu_seqlens_q = get_cu_seqlens(attention_mask)
                            cu_seqlens_kv = cu_seqlens_q
                        else:
                            cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                            cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
                    else:
                        cu_seqlens_q = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
                        cu_seqlens_kv = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
5594

5595
5596
5597
5598
5599
            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
5600
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
5601
5602
5603
                    query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                )
            else:
5604
                qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
5605
5606
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )
5607

5608
5609
5610
5611
5612
5613
5614
5615
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
5616
            bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
5617
5618
5619
5620
5621
5622
5623
5624
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
5625
                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
5626
5627
5628
5629
5630
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

5631
5632
5633
            context_parallel = (
                self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
            )
5634

5635
5636
            core_attention_bias_shape = None
            if core_attention_bias is not None:
5637
                if (
5638
5639
                    core_attention_bias.shape[0] == batch_size
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
5640
                ):
5641
5642
5643
5644
5645
5646
5647
5648
5649
5650
5651
5652
5653
5654
5655
5656
5657
5658
5659
5660
5661
5662
5663
5664
                    core_attention_bias_shape = "bhss"
                elif (
                    core_attention_bias.shape[0] == 1
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
                ):
                    core_attention_bias_shape = "1hss"
                elif (
                    core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
                ):
                    core_attention_bias_shape = "b1ss"
                elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
                    core_attention_bias_shape = "11ss"
                else:
                    assert (
                        False
                    ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

            pad_between_seqs = (
                cu_seqlens_q_padded is not None
                and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
            ) or (
                cu_seqlens_kv_padded is not None
                and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
            )
5665

5666
            attention_params = AttentionParams(
5667
5668
5669
5670
5671
5672
5673
5674
5675
5676
5677
5678
5679
5680
5681
5682
5683
5684
5685
5686
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
                num_heads=query_layer.shape[-2],
                num_gqa_groups=key_layer.shape[-2],
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                head_dim=query_layer.shape[-1],
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias_shape=core_attention_bias_shape,
                core_attention_bias_requires_grad=(
                    core_attention_bias.requires_grad if core_attention_bias is not None else False
                ),
                pad_between_seqs=pad_between_seqs,
                attention_dropout=self.attention_dropout,
                context_parallel=context_parallel,
5687
5688
                deterministic=self.deterministic,
                is_training=self.training,
5689
5690
5691
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
            )
5692
5693
5694
5695
5696
5697
5698
5699
5700
5701
5702
5703
5704
5705
5706
5707
5708
5709
5710
5711
5712
            global _attention_backends
            if (
                _attention_backends["attention_params"] is None
                or attention_params != _attention_backends["attention_params"]
            ):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
                (
                    use_flash_attention,
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
                    _,
                ) = get_attention_backend(attention_params)
                if use_flash_attention:
                    self.logger.info("Running with FlashAttention backend")
                elif use_fused_attention:
                    self.logger.info(
                        "Running with FusedAttention backend (sub-backend %s)",
                        int(fused_attention_backend),
5713
                    )
5714
5715
5716
5717
5718
5719
5720
                elif use_unfused_attention:
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]
5721

5722
5723
5724
5725
5726
5727
5728
5729
5730
5731
5732
5733
5734
5735
5736
5737
5738
5739
5740
5741
5742
5743
5744
5745
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
5746
                )
5747

5748
            if use_fused_attention:
5749
5750
                fu_core_attention_bias_type = core_attention_bias_type
                fu_core_attention_bias = core_attention_bias
5751
5752
5753
                if core_attention_bias_type == "alibi" and (
                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
                ):
5754
5755
5756
5757
5758
5759
5760
                    fu_core_attention_bias_type = "post_scale_bias"
                    _, fu_core_attention_bias = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                        bias_dtype=query_layer.dtype,
5761
                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
5762
                    )
5763
5764
5765
5766
5767
5768
5769
5770
5771
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
5772
5773
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
5774
5775
5776
5777
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
5778
                        window_size=window_size,
5779
5780
5781
5782
5783
5784
5785
5786
5787
5788
5789
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
                    )
                return self.fused_attention(
5790
5791
5792
5793
5794
5795
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
5796
5797
                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
5798
5799
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
5800
5801
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
5802
                    window_size=window_size,
5803
                    fused_attention_backend=fused_attention_backend,
5804
5805
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
5806
5807
5808
5809
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
5810
5811
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
5812
                )
5813

5814
            from .cpu_offload import CPUOffloadEnabled
5815

5816
5817
5818
5819
5820
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
5821

5822
            if use_unfused_attention:
5823
5824
5825
5826
5827
5828
                if window_size is not None and (
                    window_size[0] != -1 or window_size[1] not in [-1, 0]
                ):
                    attn_mask_type, attention_mask = get_swa_mask(
                        window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
                    )
5829
5830
5831
5832
5833
5834
5835
5836
5837
5838
5839
5840
5841
5842
5843
5844
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
5845
5846
5847
                    query_layer,
                    key_layer,
                    value_layer,
5848
5849
5850
5851
5852
5853
5854
5855
5856
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )
5857

5858
            raise Exception("No dot product attention support for the provided inputs!")
5859
5860


5861
5862
5863
5864
5865
5866
5867
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

5868
5869
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
5870

5871
5872
5873
5874
5875
5876
5877
5878
5879
5880
5881
5882
5883
5884
5885
5886
5887
5888
5889
5890
5891
5892
5893
5894
5895
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
5896
5897
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                   'padding_causal_bottom_right','arbitrary'},
5898
                   default = `causal`
5899
5900
5901
5902
5903
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
5904
5905
5906
5907
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
5908
5909
5910
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
5911
                be overridden by :attr:`window_size` in `forward` as well.
5912
5913
5914
5915
5916
5917
5918
5919
5920
5921
5922
5923
5924
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
5925
5926
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
5927
5928
5929
5930
5931
5932
5933
5934
5935
5936
5937
5938
5939
5940
5941
5942
5943
5944
5945
5946
5947
5948
5949
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
5950
5951
5952
5953
5954
5955
5956
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
5957
            For that, please use `get_qkv_layout` to gain the layout information.
5958
5959
5960
5961
5962
5963
5964
5965
5966
5967
5968
5969
5970
5971
5972
5973
5974
5975
5976
5977
5978
5979
5980
5981
5982
5983
5984
5985
5986
5987
5988
5989
5990
5991
5992
5993
5994
5995
5996
5997

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
5998
5999
6000
6001
6002
6003
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """Build the QKV projection(s), the core attention and the output projection.

        The arguments are described in the class docstring. Depending on
        `attention_type` and `input_layernorm`, exactly one of the following
        input projection sets is created:

        * self-attention: `layernorm_qkv` (with input layernorm) or `qkv`;
        * cross-attention: `layernorm_query` (with input layernorm) or
          `query_layer`, plus `key_value` in both cases.

        `core_attention` and `proj` are always created.

        Raises
        ------
        AssertionError
            If `attention_type` is not in `AttnTypes`, if `layer_number` is given
            but not positive, if `num_attention_heads` is not divisible by the
            number of GQA groups, or if the number of GQA groups is not divisible
            by the tensor parallel size.
        """
        super().__init__()

        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

        # Fall back to an even split of the hidden size across heads when the
        # per-head channel count is not given (None or 0).
        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Interleaving only applies to a fused QKV parameter; with separate
        # query/key/value parameters the flag is forced to False.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # An explicit process group takes precedence over the `tp_size` argument.
        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
        # Without an explicit group count every head is its own group.
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)

        self.hidden_size_per_attention_head = kv_channels
        # Output widths of the query projection and of each of the key/value projections.
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

        # Keyword arguments shared by every Linear/LayerNormLinear created below.
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            # A single GEMM produces query, key and value; when the parameters
            # are not fused, the weight is exposed as named query/key/value splits.
            parameters_split = None
            if not fuse_qkv_params:
                parameters_split = collections.OrderedDict(
                    [
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    ]
                )
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=parameters_split,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=parameters_split,
                    **common_gemm_kwargs,
                )
        elif self.attention_type == "cross":
            # Query and key/value come from different inputs, so they get
            # separate projections.
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query",) if not fuse_qkv_params else None,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    self.hidden_size_q,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key", "value") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            self.hidden_size_per_attention_head,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # Linear
        # Output projection; this is the only GEMM that honours `return_bias`.
        self.proj = Linear(
            self.hidden_size_q,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            ub_name="proj",
            **common_gemm_kwargs,
        )

    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Allocate an uninitialized key or value cache buffer for inference.

        Parameters
        ----------
        inference_max_sequence_len : int
                                    size of the first (sequence) dimension.
        batch_size : int
                    size of the second (batch) dimension.
        dtype : torch.dtype
               element type of the buffer.

        Returns
        -------
        torch.Tensor
            Tensor of shape `[inference_max_sequence_len, batch_size,
            num_gqa_groups_per_partition, hidden_size_per_attention_head]` on the
            current CUDA device. Its contents are uninitialized (`torch.empty`).
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Only the `tp_group` attribute of this module is updated; this method
        does not itself forward the group to any submodule.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group

6233
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        The arguments are not stored on this module; they are forwarded to every
        descendant module that defines `set_context_parallel_group`.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        # Deep iterate but skip self to avoid infinite recursion.
        # `self.modules()` yields this module first, hence the index check.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
6258

6259
6260
6261
    def forward(
        self,
        hidden_states: torch.Tensor,
6262
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
6263
        encoder_output: Optional[torch.Tensor] = None,
6264
        attn_mask_type: Optional[str] = None,
6265
        window_size: Optional[Tuple[int, int]] = None,
6266
6267
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
6268
        inference_params: Optional[InferenceParams] = None,
6269
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
6270
6271
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
6272
        alibi_slopes: Optional[torch.Tensor] = None,
6273
        fast_zero_fill: bool = True,
6274
    ) -> Tuple[Union[torch.Tensor, None], ...]:
6275
6276
6277
6278
6279
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

6280
6281
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
6282
6283
6284
6285
6286

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
6287
6288
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
6289
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
6290
6291
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
6292
6293
6294
6295
6296
6297
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                       'padding_causal_bottom_right','arbitrary'},
6298
                       default = `None`
6299
6300
6301
6302
                       type of attention mask passed into softmax operation. By default,
                       causal masks are aligned to the top left corner of the softmax matrix.
                       When "`bottom_right`" is specified in the mask type, causal masks are
                       aligned to the bottom right corner.
6303
6304
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
6305
6306
6307
6308
6309
6310
6311
6312
6313
6314
6315
6316
6317
6318
6319
6320
6321
6322
6323
6324
6325
6326
6327
6328
6329
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
6330
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
6331
        core_attention_bias: Optional[torch.Tensor], default = `None`
6332
6333
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
6334
6335
6336
6337
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
6338
6339
6340
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        """
6341
6342
        # hidden_states: [sq, b, h]

6343
        if attn_mask_type is None:
6344
            attn_mask_type = self.attn_mask_type
6345
6346
        if window_size is None:
            window_size = self.window_size
6347
        window_size = check_set_window_size(attn_mask_type, window_size)
6348

6349
        if "padding" in attn_mask_type and attention_mask is not None:
6350
            for i, _ in enumerate(attention_mask):
6351
6352
6353
                assert (
                    attention_mask[i].dtype == torch.bool
                ), "Attention mask must be in boolean type!"
6354

6355
6356
6357
        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
6358

6359
        # =================================================
6360
        # Pre-allocate memory for key-values for inference
6361
6362
6363
6364
        # =================================================

        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
6365
                inf_max_seq_len = inference_params.max_sequence_length
6366
6367
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
6368
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
6369
6370
                )
                inference_value_memory = self._allocate_memory(
6371
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
6372
6373
6374
6375
6376
6377
6378
6379
6380
6381
6382
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

6383
        # ======================
6384
        # Query, Key, and Value
6385
        # ======================
6386

cyanguwa's avatar
cyanguwa committed
6387
6388
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
6389
6390
6391
6392
6393
6394
6395
6396
6397
6398
6399
6400
6401
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
6402
                    is_first_module_in_mha=True,  # specific to FP8 MHA
6403
6404
                )

6405
6406
6407
            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
6408
            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
6409
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
6410
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
cyanguwa's avatar
cyanguwa committed
6411
6412
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
6413
6414
6415
6416
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
cyanguwa's avatar
cyanguwa committed
6417
6418
6419
6420
6421
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
6422
                    self.hidden_size_per_attention_head,
cyanguwa's avatar
cyanguwa committed
6423
6424
6425
                )
                # split along third last dimension
                split_dim = -3
6426
6427
6428

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
6429
6430
6431
6432
6433
6434
6435
6436
6437
            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, ng, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
6438
                )
6439
            else:
cyanguwa's avatar
cyanguwa committed
6440
                query_layer, key_layer, value_layer = torch.split(
6441
6442
6443
6444
                    mixed_x_layer,
                    (num_queries_per_key_value, 1, 1),
                    dim=split_dim,
                )
cyanguwa's avatar
cyanguwa committed
6445
6446
6447

            # query: -> [sq, b, np, hn]
            # key, value: -> [sq, b, ng, hn]
6448
6449
6450
6451
            query_layer, key_layer, value_layer = (
                x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                for x in (query_layer, key_layer, value_layer)
            )
cyanguwa's avatar
cyanguwa committed
6452
6453
6454

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
6455
            mixed_kv_layer = self.key_value(
cyanguwa's avatar
cyanguwa committed
6456
                encoder_output,
6457
                is_first_microbatch=is_first_microbatch,
6458
                is_first_module_in_mha=True,  # specific to FP8 MHA
6459
6460
6461
            )

            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
6462
                # [sk, b, (ng * 2 * hn)] --> [sk, b, ng, 2 * hn]
6463
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
6464
                    self.num_gqa_groups_per_partition,
6465
6466
6467
6468
6469
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
cyanguwa's avatar
cyanguwa committed
6470
                # [sk, b, (ng * 2 * hn)] --> [sk, b, 2 * ng, hn]
6471
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
6472
                    2 * self.num_gqa_groups_per_partition,
6473
6474
6475
6476
6477
6478
6479
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
6480
6481
6482
            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
6483
6484
6485
                    mixed_kv_layer,
                    split_dim,
                    mixed_kv_layer.shape[split_dim] // 2,
cyanguwa's avatar
cyanguwa committed
6486
                )
6487
            else:
cyanguwa's avatar
cyanguwa committed
6488
                key_layer, value_layer = torch.split(
6489
6490
6491
                    mixed_kv_layer,
                    mixed_kv_layer.shape[split_dim] // 2,
                    dim=split_dim,
cyanguwa's avatar
cyanguwa committed
6492
                )
6493
6494
6495
6496
6497
6498
6499
6500
6501
            key_layer, value_layer = (
                x.reshape(
                    x.size(0),
                    x.size(1),
                    -1,
                    self.hidden_size_per_attention_head,
                )
                for x in (key_layer, value_layer)
            )
6502
6503
6504
6505
6506
6507
6508
6509
6510
6511
6512
6513
6514
6515
6516

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
6517
                    is_first_module_in_mha=True,  # specific to FP8 MHA
6518
6519
6520
6521
6522
6523
6524
6525
6526
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

6527
6528
6529
        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================
6530

6531
        if rotary_pos_emb is not None:
6532
6533
6534
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
6535
            # duplicate the pos_emb for self attention
6536
            if not isinstance(rotary_pos_emb, tuple):
6537
                rotary_pos_emb = (rotary_pos_emb,) * 2
6538
6539

            q_pos_emb, k_pos_emb = rotary_pos_emb
6540
6541
6542
6543
6544
6545
6546
6547
6548
6549
6550
6551
6552
6553

            # slice the rotary embeddings to the current sequence window for inference
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

6554
6555
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
6556

6557
6558
6559
6560
        # ===========================
        # Core attention computation
        # ===========================

6561
6562
6563
6564
        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
6565
            qkv_format=self.qkv_format,
6566
6567
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
6568
6569
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
6570
            window_size=window_size,
6571
6572
6573
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
6574
            alibi_slopes=alibi_slopes,
6575
            fast_zero_fill=fast_zero_fill,
6576
            inference_params=inference_params,
6577
6578
        )

6579
        # ===================
6580
        # Output. [sq, b, h]
6581
        # ===================
6582

6583
        projection_output = self.proj(
6584
6585
            context_layer,
            is_first_microbatch=is_first_microbatch,
6586
6587
        )

6588
6589
6590
6591
6592
6593
6594
6595
        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
6596
        if self.input_layernorm and self.return_layernorm_output:
6597
6598
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]