# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
import math
10
import os
11
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12
import warnings
13
import logging
14

15
from dataclasses import dataclass, fields
cyanguwa's avatar
cyanguwa committed
16
import numpy as np
17
from packaging.version import Version as PkgVersion
18
19

import torch
20
import torch.nn.functional as F
21

22
import transformer_engine_torch as tex
23
24
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
25
26
27
28
from transformer_engine.pytorch.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
)
29
30
31
32
33
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
34
35
    fused_attn_fwd,
    fused_attn_bwd,
36
37
38
39
40
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
41
42
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
43
from transformer_engine.pytorch.module import LayerNormLinear, Linear
44
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
45
46
47
48
49
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
50
    get_default_init_method,
51
52
53
54
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
55
    AttnBiasTypes,
56
    QKVLayouts,
57
    dist_group_type,
58
    TE_DType,
59
60
61
62
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
63
    get_distributed_rank,
64
    checkpoint,
65
66
67
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
68
69
    gather_along_first_dim,
    reduce_scatter_along_first_dim,
70
71
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
72
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
73
74
from transformer_engine.pytorch.graph import is_graph_capturing

# flash-attn version detection. The boolean feature flags below gate version-specific
# behavior elsewhere in this module (e.g. causal-mask alignment, sliding window,
# determinism checks in get_attention_backend).
# NOTE(review): importlib.metadata.version() raises PackageNotFoundError when
# flash-attn is not installed, so importing this module currently requires flash-attn.
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
# presumably the newest flash-attn version validated with this module; it is not
# checked within the visible code — confirm where it is used.
_flash_attn_max_version = PkgVersion("2.5.8")
_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")

# Bind the flash-attn entry points only when the installed version meets the minimum
# requirement; below it, these names are left undefined.
if _flash_attn_version >= _flash_attn_version_required:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd

# Slots in the FP8 metadata (scaling / amax-history buffers) that attention uses
# for its tensors. Forward-pass tensors:
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT

# Backward-pass tensors:
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_DP = tex.FP8BwdTensors.GRAD_INPUT3

# Context parallelism (CP): otherwise-unused amax history buffers are repurposed to
# hold partial results of the CP forward and backward passes.
META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT
META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))

# Translate the effective verbosity (debug flag times level) into a logging level:
# 0 -> WARNING, 1 -> INFO, 2 -> DEBUG. Any other value, including negative ones,
# is clamped to 2 (DEBUG).
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
if _log_level not in _log_levels:
    _log_level = 2
_log_level = _log_levels[_log_level]

# Shared console handler that loggers in this module attach (at most once) to emit
# their records.
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

# Backend enable switches (1 = allowed, 0 = disabled), read once at import time.
# get_attention_backend() re-reads these environment variables on every call.
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))

# Module-level record of an attention-backend selection. The keys mirror the outputs
# of get_attention_backend(); presumably it caches the last selection for the given
# "attention_params" and is refreshed when "backend_selection_requires_update" is set
# — usage is outside this view, confirm against callers.
_attention_backends = {
    "attention_params": None,
    "use_flash_attention": None,
    "use_fused_attention": None,
    "fused_attention_backend": None,
    "use_unfused_attention": None,
    "backend_selection_requires_update": False,
}


# eq=True: the dataclass-generated __eq__ compares instances field by field. Because the
# class is not frozen, dataclasses also sets __hash__ to None, so instances are unhashable.
@dataclass(eq=True)
class AttentionParams:
    """
    Attention parameters used to determine which backend to be used.

    Parameters
    ----------
    qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
        Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
    qkv_dtype: torch.dtype, default = `torch.bfloat16`
        Data type of query/key/value tensors.
    qkv_layout: str, default = "sbh3d"
        Query/key/value tensor memory layout.
    batch_size: int, default = 1
        Batch size.
    num_heads: int, default = 16
        Number of attention heads in the query tensor.
    num_gqa_groups: int, default = 16
        Number of attention heads in key and value tensors.
    max_seqlen_q: int, default = 128
        Maximum sequence length of the query tensor.
    max_seqlen_kv: int, default = 128
        Maximum sequence length of the key and value tensors.
    head_dim_qk: int, default = 64
        The size of each attention head in query and key tensors.
    head_dim_v: int, default = 64
        The size of each attention head in the value tensor.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
        `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}.
    window_size: Optional[Tuple[int, int]], default = `None`
        Sliding window attention size.
    alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
        Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
    core_attention_bias_type: str, default = `no_bias`
        Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
    core_attention_bias_shape: str, default = `1hss`
        Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
        Note: `get_attention_backend` disables FlashAttention whenever this is not `None`.
    core_attention_bias_requires_grad: bool, default = `True`
        Whether attention bias requires gradient.
    pad_between_seqs: bool, default = `False`
        Whether there is padding between sequences in a batch.
        This only applies to `qkv_format=thd`.
    attention_dropout: float, default = 0.0
        Attention dropout.
    context_parallel: bool, default = `False`
        Whether context parallelism is used or not.
    deterministic: bool, default = `False`
        Whether to run `DotProductAttention` with determinism or not.
    is_training: bool, default = `True`
        Whether in training mode (`True`) or inference mode (`False`).
    fp8: bool, default = `False`
        Whether `DotProductAttention` is in an `fp8_autocast` region.
    fp8_meta: Optional[Dict[str, Any]], default = `None`
        The FP8 metadata of `DotProductAttention`. When `fp8` is set,
        `fp8_meta["recipe"]` is consulted (its `fp8_dpa` / `fp8_mha` flags).
    """

    qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
    qkv_dtype: torch.dtype = torch.bfloat16
    qkv_layout: str = "sbh3d"
    batch_size: int = 1
    num_heads: int = 16
    num_gqa_groups: int = 16
    max_seqlen_q: int = 128
    max_seqlen_kv: int = 128
    head_dim_qk: int = 64
    head_dim_v: int = 64
    attn_mask_type: str = "no_mask"
    window_size: Union[Tuple[int, int], None] = None
    alibi_slopes_shape: Union[torch.Size, List, None] = None
    core_attention_bias_type: str = "no_bias"
    core_attention_bias_shape: str = "1hss"
    core_attention_bias_requires_grad: bool = True
    pad_between_seqs: bool = False
    attention_dropout: float = 0.0
    context_parallel: bool = False
    deterministic: bool = False
    is_training: bool = True
    fp8: bool = False
    fp8_meta: Union[Dict[str, Any], None] = None


# Module-level ALiBi cache. Its usage lives outside this view. Judging from the key
# names, it holds the ALiBi slopes/bias tensors, the head count and sequence lengths
# they were built for, the diagonal alignment, and *_require_update flags — confirm
# against the ALiBi helpers.
_alibi_cache = dict(
    _num_heads=None,
    _alibi_slopes=None,
    _max_seqlen_q=None,
    _max_seqlen_kv=None,
    _bottom_right_alignment=True,
    _alibi_bias=None,
    _alibi_slopes_require_update=False,
    _alibi_bias_require_update=False,
)


# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "DotProductAttention",
    "InferenceParams",
    "MultiheadAttention",
]


def get_attention_backend(
    attention_params: AttentionParams = None,
):
    """
    Select the appropriate attention backend/sub-backend based on user input and runtime environment.

    Parameters
    ----------
    See `AttentionParams`.
233
234
235
236
237
238
239

    Returns
    ----------
    use_flash_attention: bool
        Whether the `FlashAttention` backend has been selected.
    use_fused_attention: bool
        Whether the `FusedAttention` backend has been selected.
240
241
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend
        If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
242
243
244
245
246
247
    use_unfused_attention: bool
        Whether the `UnfusedDotProductAttention` backend has been selected.
    available_backends: List[bool]
        All available backends that could support the provided input. A list of Booleans
        in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
    """
248
249
250
251
252
253
254
255
    qkv_type = attention_params.qkv_type
    qkv_dtype = attention_params.qkv_dtype
    qkv_layout = attention_params.qkv_layout
    batch_size = attention_params.batch_size
    num_heads = attention_params.num_heads
    num_gqa_groups = attention_params.num_gqa_groups
    max_seqlen_q = attention_params.max_seqlen_q
    max_seqlen_kv = attention_params.max_seqlen_kv
256
257
    head_dim_qk = attention_params.head_dim_qk
    head_dim_v = attention_params.head_dim_v
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
    attn_mask_type = attention_params.attn_mask_type
    window_size = attention_params.window_size
    alibi_slopes_shape = attention_params.alibi_slopes_shape
    core_attention_bias_type = attention_params.core_attention_bias_type
    core_attention_bias_shape = attention_params.core_attention_bias_shape
    core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
    pad_between_seqs = attention_params.pad_between_seqs
    attention_dropout = attention_params.attention_dropout
    context_parallel = attention_params.context_parallel
    deterministic = attention_params.deterministic
    is_training = attention_params.is_training
    fp8 = attention_params.fp8
    fp8_meta = attention_params.fp8_meta

    # Run config
273
    logger = logging.getLogger("DotProductAttention")
274
275
276
    logger.setLevel(_log_level)
    if not logger.hasHandlers():
        logger.addHandler(_stream_handler)
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
    device_compute_capability = get_device_compute_capability()
    cudnn_version = get_cudnn_version()
    run_config = {
        "transformer_engine_version": te.__version__,
        "compute_capability": "sm"
        + str(
            (lambda x, y: x * 10 + y)(device_compute_capability[0], device_compute_capability[1])
        ),
        "flash_attn_version": _flash_attn_version,
        "cudnn_version": ".".join([str(i) for i in cudnn_version]),
    }
    attention_params_dict = {
        field.name: getattr(attention_params, field.name) for field in fields(attention_params)
    }
    run_config.update(attention_params_dict)
    if fp8:
        run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
    logger.debug("Running with config=%s", run_config)
295
296

    # Filter: Environment variables
297
298
299
300
301
302
303
    global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN
    _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
    _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
    _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
    use_flash_attention = _NVTE_FLASH_ATTN
    use_fused_attention = _NVTE_FUSED_ATTN
    use_unfused_attention = _NVTE_UNFUSED_ATTN
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
    if not use_flash_attention:
        logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
    if not use_fused_attention:
        logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
    if not use_unfused_attention:
        logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")

    # Filter: ONNX mode
    if is_in_onnx_export_mode():
        if use_flash_attention:
            logger.debug("Disabling FlashAttention due to ONNX mode")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention due to ONNX mode")
        use_fused_attention = False

    # Filter: Compute capability
    if device_compute_capability < (8, 0):
        if use_flash_attention:
            logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
            use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
            use_fused_attention = False

    # Filter: Data type
    if use_flash_attention and (
        qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor
    ):
        logger.debug(
            "Disabling FlashAttention due to unsupported QKV data type. "
            "Supported: qkv_type = torch.Tensor, qkv_dtype = {torch.bfloat16, torch.float16}. "
            "Found: qkv_type = %s, qkv_dtype = %s.",
            qkv_type,
            qkv_dtype,
        )
        use_flash_attention = False
    if use_fused_attention and (qkv_dtype not in [torch.bfloat16, torch.float16]):
        logger.debug(
            "Disabling FusedAttention due to unsupported QKV data type. "
            "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
            "Found: qkv_dtype = %s.",
            qkv_dtype,
        )
        use_fused_attention = False

    # Filter: Execution type
    if fp8 and fp8_meta["recipe"].fp8_dpa:
        if use_flash_attention:
            logger.debug("Disabling FlashAttention as it does not support FP8")
            use_flash_attention = False
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
            use_unfused_attention = False

    # Filter: Head dimension
360
361
362
    if use_flash_attention and head_dim_qk != head_dim_v:
        logger.debug("Disabling FlashAttention as it does not support MLA.")
        use_flash_attention = False
363
    if use_flash_attention and (
364
365
366
        head_dim_qk > 256
        or head_dim_qk % 8 != 0
        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
367
368
    ):
        logger.debug(
369
370
371
372
373
374
            "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
            "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
            "head_dim_qk <= 256 (>192 requires sm80/90). "
            "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
            head_dim_qk,
            head_dim_v,
375
376
377
            ".".join([str(i) for i in device_compute_capability]),
        )
        use_flash_attention = False
378
379
380
381
382
383
384
    qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
    if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
        logger.debug(
            "Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
            qkv_layout,
        )
        use_fused_attention = False
385
386
387
388
389
390
391
392
393
394
395
396
397
398

    # Filter: QKV layout
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
            use_unfused_attention = False
        if use_flash_attention and pad_between_seqs:
            logger.debug(
                "Disabling FlashAttention for qkv_format = thd when there is "
                "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
            )
            use_flash_attention = False

399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
    # Filter: Context parallelism
    # qkv_format | attn_mask_type              | attn_bias_type           | supported backends
    # ----------------------------------------------------------------------------------------------------
    # bshd, sbhd | self-attention:             | no_bias, post_scale_bias | FlashAttention, FusedAttention
    #            |     no_mask, causal         |                          |
    #            | cross-attention:            |                          |
    #            |     no_mask                 |                          |
    # thd        | self-attention:             | no_bias                  | FlashAttention, FusedAttention
    #            |     padding, padding_causal |                          | if no padding between sequences,
    #            | cross-attention:            |                          | FusedAttention
    #            |     padding                 |                          | if there is padding between sequences
    # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v.
    if context_parallel and use_unfused_attention:
        logger.debug(
            "Disabling UnfusedDotProductAttention as it does not support context parallelism"
        )
        use_unfused_attention = False
    if context_parallel and use_flash_attention:
        if "bottom_right" in attn_mask_type:
            logger.debug(
                "Disabling FlashAttention as it does not support context parallelism with"
                " causal_bottom_right masking"
            )
            use_flash_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
            logger.debug(
                "Disabling FlashAttention as it does not support context parallelism with causal"
                " masking for cross-attention"
            )
            use_flash_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
            logger.debug(
                "Disabling FlashAttention as it does not support context parallelism with bias type"
                " of %s",
                core_attention_bias_type,
            )
            use_flash_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
            logger.debug(
                "Disabling FlashAttention as it does not support context parallelism with attention"
                " bias for THD format"
            )
            use_flash_attention = False
    if context_parallel and use_fused_attention:
        if "bottom_right" in attn_mask_type:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with"
                " causal_bottom_right masking"
            )
            use_fused_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with causal"
                " masking for cross-attention"
            )
            use_fused_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with bias type"
                " of %s",
                core_attention_bias_type,
            )
            use_fused_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with attention"
                " bias for THD format"
            )
            use_fused_attention = False
        elif head_dim_qk != head_dim_v:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with MLA"
            )
            use_fused_attention = False

474
    # Filter: Attention mask
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
    # attn_mask_type              | attention_mask                       | supported backends
    # ----------------------------------------------------------------------------------------
    # no_mask                     | None                                 | All
    # padding                     |                                      | All
    #     self-attention          | One tensor in shape [b, 1, 1, sq]    |
    #     cross-attention         | Tuple of two tensors in shapes       |
    #                             | [b, 1, 1, sq] and [b, 1, 1, skv]     |
    # causal                      | None                                 |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # padding_causal              | Same as "padding"                    |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # causal_bottom_right         | None                                 | All
    # padding_causal_bottom_right | Same as "padding"                    |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FlashAttention, UnfusedDotProductAttention
    # arbitrary                   | One tensor in shape broadcastable to | UnfusedDotProductAttention
    #                             | [b, h, sq, skv]                      |
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
    if attn_mask_type == "arbitrary":
        if use_flash_attention:
            logger.debug("Disabling FlashAttention for arbitrary mask")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention for arbitrary mask")
        use_fused_attention = False
    if (
        use_flash_attention
        and _flash_attn_2_1_plus
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention as it only supports bottom-right-diagonal "
            "causal mask since flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        use_flash_attention = False
    if (
        use_flash_attention
        and not _flash_attn_2_1_plus
        and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention as it only supports top-left-diagonal "
            "causal mask before flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        use_flash_attention = False

    # Filter: Sliding window attention
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
    #    backend                 |      window_size       | diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | (-1, -1) or (>=0, >=0) | bottom right
    # FusedAttention             | (-1,  0) or (>=0, 0)   | top left
    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
    #                            |                        | converts window_size to an 'arbitrary' mask
    if window_size is None:
        window_size = check_set_window_size(attn_mask_type, window_size)
    else:
        if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention"
                    " for FP8"
                )
                use_fused_attention = False
            elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
                logger.debug(
                    "Disabling FusedAttention as it only supports sliding window attention "
                    "with causal mask, no dropout, and qkv_format = bshd/sbhd"
                )
                use_fused_attention = False
            elif context_parallel:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with context parallelism"
                )
                use_fused_attention = False
            elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
                "no_mask",
                "padding",
                "causal_bottom_right",
                "padding_causal_bottom_right",
            ]:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with attn_mask_type = %s for cross-attention",
                    attn_mask_type,
                )
                use_fused_attention = False
            elif "padding" in attn_mask_type:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with attn_mask_type = %s",
                    attn_mask_type,
                )
                use_fused_attention = False
        if (
            use_flash_attention
            and (window_size[0] != -1 or window_size[1] not in [-1, 0])
577
            and not _flash_attn_2_3_plus
578
        ):
579
            logger.debug(
580
                "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
581
582
583
584
            )
            use_flash_attention = False

    # Filter: Attention bias
585
586
587
588
589
590
591
592
    #    backend                 |      bias types              | ALiBi diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | no_bias, alibi/alibi_slopes  | bottom right
    # FusedAttention             | no_bias, post_scale_bias     |
    #                            | alibi/alibi_slopes           | top left,
    #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
    # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
    #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
593
594
595
596
597
598
599
600
601
602
603
604
605
    if use_flash_attention and (
        core_attention_bias_type not in ["no_bias", "alibi"]
        or core_attention_bias_shape is not None
    ):
        logger.debug("Disabling FlashAttention for pre/post_scale_bias")
        use_flash_attention = False

    fu_core_attention_bias_type = core_attention_bias_type
    fu_core_attention_bias_shape = core_attention_bias_shape
    fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
    if (
        use_fused_attention
        and core_attention_bias_type == "alibi"
606
        and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
607
608
609
    ):
        fu_core_attention_bias_type = "post_scale_bias"
        fu_core_attention_bias_requires_grad = False
610
611
612
613
614
        if alibi_slopes_shape is None:
            fu_core_attention_bias_shape = "1hss"
        elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
            fu_core_attention_bias_shape = "1hss"
        elif (
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
            len(alibi_slopes_shape) == 2
            and alibi_slopes_shape[0] == batch_size
            and alibi_slopes_shape[1] == num_heads
        ):
            fu_core_attention_bias_shape = "bhss"

    if (
        use_fused_attention
        and fu_core_attention_bias_type == "post_scale_bias"
        and fu_core_attention_bias_shape != "1hss"
    ):
        if fu_core_attention_bias_requires_grad:
            # remove this line when cuDNN adds bwd support for
            # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
            logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
            use_fused_attention = False
        else:
            # max512 backend will only support [1, h, s, s]
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

    # Filter: cuDNN support
    fused_attention_backend = None
    if use_fused_attention:
        q_type = TE_DType[qkv_dtype]
        kv_type = q_type
        if fp8 and fp8_meta["recipe"].fp8_dpa:
            q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            kv_type = q_type
        fused_attention_backend = tex.get_fused_attn_backend(
            q_type,
            kv_type,
            QKVLayout[qkv_layout],
            AttnBiasType[fu_core_attention_bias_type],
            AttnMaskType[attn_mask_type],
            attention_dropout,
            num_heads,
            num_gqa_groups,
            max_seqlen_q,
            max_seqlen_kv,
654
655
            head_dim_qk,
            head_dim_v,
656
657
            window_size[0],
            window_size[1],
658
        )
659
        if fused_attention_backend == FusedAttnBackend["No_Backend"]:
660
661
            logger.debug("Disabling FusedAttention as no backend supports the provided input")
            use_fused_attention = False
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
            fused_attention_backend = None
        if (
            use_fused_attention
            and window_size is not None
            and window_size[0] != -1
            and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
        ):
            logger.debug(
                "Disabling FusedAttention as only sub-backend %s does not support "
                "slidng window attention",
                int(fused_attention_backend),
            )
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
679
680
681
682
683
684
685
686
            and fu_core_attention_bias_type == "post_scale_bias"
            and fu_core_attention_bias_shape != "1hss"
        ):
            logger.debug(
                "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
                " [1, H, S, S] shape"
            )
            use_fused_attention = False
687
            fused_attention_backend = None
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707

    # Filter: Determinism
    # backend                      | deterministic
    # ---------------------------------------------
    # FlashAttention               |
    #     flash-attn >=2.0, <2.4.1 | no
    #     flash-attn >=2.4.1       | yes
    # FusedAttention               |
    #     sub-backend 0            | yes
    #     sub-backend 1            | workspace optimization path and sm90+: yes;
    #                              | otherwise: no
    #     sub-backend 2            | no
    # UnfusedDotProductAttention   | yes
    if use_flash_attention and deterministic and not _flash_attn_2_4_1_plus:
        logger.warning(
            "Disabling FlashAttention as version <2.4.1 does not support deterministic "
            "execution. To use FlashAttention with deterministic behavior, "
            "please install flash-attn >= 2.4.1."
        )
        use_flash_attention = False
708
709
710
711
712
713
714
715
716
717
718
    if use_fused_attention and deterministic:
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
            and (
                device_compute_capability < (9, 0)
                or core_attention_bias_requires_grad
                or cudnn_version < (8, 9, 5)
719
            )
720
721
722
        ):
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
723
724
725

    # All available backends
    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
726
727
728
729
730
731
732
733
734
735
736
737
    logger.debug(
        "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
        " UnfusedDotProductAttention=%s}",
        bool(available_backends[0]),
        bool(available_backends[1]),
        (
            f" (sub-backend {int(fused_attention_backend)})"
            if fused_attention_backend is not None
            else ""
        ),
        bool(available_backends[2]),
    )
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757

    # Select FusedAttention for performance
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    ):
        if device_compute_capability == (9, 0):
            logger.debug(
                "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                "for performance reasons"
            )
            use_flash_attention = False

    # Selected backend
    if use_flash_attention:
        use_fused_attention = False
        use_unfused_attention = False
    elif use_fused_attention:
        use_unfused_attention = False
758
    selected_backend = "NoBackend"
759
760
761
762
763
764
    if use_flash_attention:
        selected_backend = "FlashAttention"
    elif use_fused_attention:
        selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
    elif use_unfused_attention:
        selected_backend = "UnfusedDotProductAttention"
765
    logger.debug("Selected backend = %s", selected_backend)
766

767
768
769
770
771
772
    global _attention_backends
    _attention_backends["use_flash_attention"] = use_flash_attention
    _attention_backends["use_fused_attention"] = use_fused_attention
    _attention_backends["fused_attention_backend"] = fused_attention_backend
    _attention_backends["use_unfused_attention"] = use_unfused_attention
    _attention_backends["backend_selection_requires_update"] = False
773
774
775
776

    return (
        use_flash_attention,
        use_fused_attention,
777
        fused_attention_backend,
778
779
780
781
782
        use_unfused_attention,
        available_backends,
    )


class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Both offsets start at 0; nothing in this class advances them
        # (presumably updated by the caller during decoding).
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # layer_number -> (inference_key_memory, inference_value_memory)
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty.
        """
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict in empty")

        for layer_number, inference_memory in self.key_value_memory_dict.items():
            inference_key_memory, inference_value_memory = inference_memory
            # Dim 1 of the cached key/value tensors is the batch dimension.
            assert (
                len(batch_indices) == inference_key_memory.shape[1]
            )  # make sure batch size is the same
            new_inference_key_memory = inference_key_memory[:, batch_indices]
            new_inference_value_memory = inference_value_memory[:, batch_indices]
            # Replacing values of existing keys is safe while iterating items().
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )

@torch.no_grad()
def get_swa_mask(
    window_size: Tuple[int, int],
    max_seqlen_q: int,
    max_seqlen_kv: int,
    attn_mask_type: str = "no_mask",
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Tuple[str, torch.Tensor]:
    """
    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
    For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
    and for other mask types, the bottom right corner.

    Parameters
    ----------
    window_size: Tuple[int, int]
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
        `attn_mask_type`.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
        "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
        default = `None`
        Boolean tensor(s) used to mask out attention softmax input (`True` = masked out).
        It is merged with the sliding window mask by logical OR, so a position stays
        masked if either mask excludes it. Only a single tensor (broadcastable with
        [max_seqlen_q, max_seqlen_kv]) can be merged here; a tuple of masks must be
        combined into one tensor by the caller first.
    device: Optional[Union[str, torch.device]], default = `None`
        Device of the generated mask. If `None`, the device of `attention_mask` is used
        when it is a tensor, and "`cuda`" otherwise.

    Returns
    ----------
    attn_mask_type: str
        Always "`arbitrary`".
    attention_mask: torch.Tensor
        Combined `attention_mask` (input) and sliding window attention mask, with `True`
        marking positions to mask out. The shape is [max_seqlen_q, max_seqlen_kv] when
        input `attention_mask` is None; else, the broadcast of input `attention_mask`
        with [max_seqlen_q, max_seqlen_kv].
    """
    if device is None:
        device = attention_mask.device if isinstance(attention_mask, torch.Tensor) else "cuda"
    mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device=device)
    if attn_mask_type in ["causal"]:
        # Top-left aligned window: diagonal offsets are relative to position (0, 0).
        left = window_size[0] if window_size[0] != -1 else max_seqlen_q
        right = window_size[1] if window_size[1] != -1 else max_seqlen_q
        mask_upper = torch.triu(mask, diagonal=-left)
        mask_lower = torch.tril(mask_upper, diagonal=right)
    else:
        # Bottom-right aligned window: shift both diagonals by (s_kv - s_q).
        left = window_size[0] if window_size[0] != -1 else max_seqlen_kv
        right = window_size[1] if window_size[1] != -1 else max_seqlen_kv
        mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left)
        mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right)
    attn_mask_type = "arbitrary"
    # mask_lower is True inside the window; invert so True means "masked out".
    mask = mask_lower.logical_not()
    if attention_mask is not None:
        # Both masks use True = masked out, so a position must be masked if EITHER
        # mask excludes it (logical_and would unmask everything outside the window).
        mask = torch.logical_or(attention_mask, mask)
    return attn_mask_type, mask


@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    actual_seqlens_q: Optional[torch.Tensor] = None,
    actual_seqlens_kv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
    bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get ALiBi slopes and bias, (re)computing them only when the module-level
    `_alibi_cache` flags request an update.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    actual_seqlens_q: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for queries, in shape [batch_size].
    actual_seqlens_kv: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for keys and values, in shape [batch_size].
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the ALiBi bias to the bottom right corner of
        the matrix (`True`) or top left (`False`).

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. Its shape is
        (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
        and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
        (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
        [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
        `actual_seqlens_q` and `actual_seqlens_kv` are not `None`.

    Notes
    -----
    Slopes are recomputed only when `_alibi_cache["_alibi_slopes_require_update"]` is set,
    and the bias only when `_alibi_cache["_alibi_bias_require_update"]` is set. This
    function clears both flags but never sets them; otherwise the cached values are
    returned unchanged, even if the arguments differ from the cached ones.
    """
    global _alibi_cache
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            # Default slopes: m_0**k (k = 1..n) with m_0 = 2**(-8/n), where n is the
            # largest power of two <= num_heads; any remaining heads take the odd
            # powers of 2**(-4/n).
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        slopes = _alibi_cache["_alibi_slopes"]
        assert slopes is not None, "ALiBi slopes can not be None!"
        # Reject unexpected ranks explicitly; otherwise `slopes_shape` would be unbound.
        assert slopes.dim() in (
            1,
            2,
        ), "ALiBi slopes must be in [num_heads] or [batch_size, num_heads] shape!"
        if slopes.dim() == 1:
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        else:
            slopes_shape = torch.Size([*slopes.shape, 1, 1])
        # bias[..., i, j] = i - j, i.e. query position minus key position.
        bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
            1, 1, max_seqlen_q, 1
        ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
            1, 1, 1, max_seqlen_kv
        )
        if actual_seqlens_q is None and actual_seqlens_kv is None:
            if bottom_right_alignment:
                # Shift so the zero-distance diagonal ends at the bottom-right corner.
                bias = bias + max_seqlen_kv - max_seqlen_q
        elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
            batch_size = actual_seqlens_q.shape[0]
            bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
            if bottom_right_alignment:
                # Per-sample shift based on the actual (unpadded) lengths.
                bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
        else:
            assert (
                False
            ), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
        # ALiBi penalty: -|distance| scaled by each head's slope.
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    In `mask`, `True` marks padding positions; each sample's length is the number of
    `False` entries in its row. The result lives on the same device as `mask`.
    """
    mask = mask.squeeze(1).squeeze(1)
    # Count the un-masked (valid) tokens of each sample.
    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Leading 0 is created on mask's device so the concatenation never mixes devices.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    return cu_seqlens

def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the flattened indices of the valid tokens, followed by padding entries
    equal to batch_size * max_seqlen.

    In `mask`, `True` marks padding positions. Both outputs live on `mask`'s device.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    # Count the un-masked (valid) tokens of each sample.
    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    mask = mask.reshape(-1)
    # Flat positions (b * max_seqlen + s) of valid tokens, ascending; nonzero -> int64.
    indices = mask.logical_not().nonzero()
    indices = indices.unsqueeze(-1)

    num_nonzeros = indices.shape[0]
    pad_amount = bs * seqlen - num_nonzeros
    # Pad dim 0 up to bs * seqlen with the index one past the last flat position.
    indices = F.pad(
        input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen)
    )

    return cu_seqlens, indices


def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch, followed by padding entries equal to
    batch_size * max_seqlen. The result lives on `cu_seqlens`'s device.

    Assumes every per-sample sequence length is <= `max_seqlen`.
    """
    bs = len(cu_seqlens) - 1
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    # Token (b, s) is valid iff s < seqlens[b]; flattening row-major gives the flat
    # index b * max_seqlen + s. Tensor ops keep this exact in int64 (a float32
    # `torch.Tensor` loses precision above 2**24) and avoid a Python-level loop.
    positions = torch.arange(max_seqlen, dtype=torch.int64, device=cu_seqlens.device)
    valid = positions.unsqueeze(0) < seqlens.to(torch.int64).unsqueeze(1)
    indices = valid.reshape(-1).nonzero().unsqueeze(-1)

    num_nonzeros = indices.shape[0]
    pad_amount = bs * max_seqlen - num_nonzeros
    # Pad dim 0 up to bs * max_seqlen with the index one past the last flat position.
    indices = F.pad(
        input=indices,
        pad=(0, 0, 0, 0, 0, pad_amount),
        mode="constant",
        value=float(bs * max_seqlen),
    )

    return indices

1048

1049
_cu_seqlens_cache = {}
1050
1051


1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length.

    """
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
    global _cu_seqlens_cache
    if (batch_size, max_seqlen) not in _cu_seqlens_cache:
        _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[(batch_size, max_seqlen)]
1072
1073


@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Packs the given tensor using the `indices`.

    A zero row is appended to `tensor` along dim 0 before gathering, so any index
    equal to the original `tensor.shape[0]` gathers a row of zeros. Expects a 3-D
    `tensor` and `indices` of shape [N, 1, 1]; returns shape [N, *tensor.shape[1:]].
    """
    padding_indice = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
    )
    tensor = torch.cat((tensor, padding_indice), dim=0)

    # Broadcast each row index across the trailing two dims for torch.gather.
    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    packed = torch.gather(tensor, 0, indices)
    return packed


@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack two tensors with the same `indices`, one `pack_tensor` call per tensor.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pack three tensors with the same `indices`, one `pack_tensor` call per tensor.
    """
    return (
        pack_tensor(indices, t1),
        pack_tensor(indices, t2),
        pack_tensor(indices, t3),
    )


@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`.

    Scatters the rows of `tensor` to positions `indices` of a zero tensor with
    `dim0` rows. A scratch row at position `dim0` absorbs entries whose index equals
    `dim0` and is dropped before returning.
    """
    # Broadcast each row index across the trailing two dims for scatter_.
    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    unpacked = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
    )
    unpacked.scatter_(0, indices, tensor)
    # Drop the scratch row.
    unpacked = unpacked[0:-1, :, :]
    return unpacked


@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`: unpack each tensor back to `dim0` rows.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`: unpack each tensor back to `dim0` rows.
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    Forward packs 1-3 tensors with the same `indices`; backward unpacks each
    incoming gradient back to the original number of rows.
    """

    @staticmethod
    def forward(
        ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        """Pack `tensors` (1 to 3 of them) using `indices`."""
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        ctx.save_for_backward(indices)
        # Row count of the unpacked inputs, needed to rebuild gradients in backward.
        ctx.dim0 = tensors[0].shape[0]
        if len(tensors) == 1:
            return pack_tensor(indices, *tensors)
        if len(tensors) == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_3_tensors(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        """Unpack each gradient; `indices` receives no gradient."""
        (indices,) = ctx.saved_tensors
        if len(grad_outputs) == 1:
            return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
        if len(grad_outputs) == 2:
            return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs)
        return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs)


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    Forward scatters a packed tensor back to `dim0` rows via `unpack_tensor`;
    backward packs the incoming gradient with the same `indices` via `pack_tensor`.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Unpack `tensor` to `dim0` rows using `indices`."""
        ctx.save_for_backward(indices)
        return unpack_tensor(indices, dim0, tensor)

    @staticmethod
    def backward(ctx, grad_output):
        """Pack `grad_output`; `indices` and `dim0` receive no gradient."""
        (indices,) = ctx.saved_tensors
        return None, None, pack_tensor(indices, grad_output)


def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Posts one send of `send_tensor` to `send_dst` and one receive into `recv_tensor`
    from `recv_src` within `cp_group`, and returns the resulting request handles.
    With `batch_p2p_comm`, both operations are issued through
    `torch.distributed.batch_isend_irecv`; otherwise via separate `isend`/`irecv` calls.
    """
    send_recv_ops = []

    # Even ranks post the send first and odd ranks the receive first, so neighbours in
    # the ring issue complementary orders (presumably to avoid deadlock).
    if batch_p2p_comm:
        if rank % 2 == 0:
            send_op = torch.distributed.P2POp(
                torch.distributed.isend, send_tensor, send_dst, cp_group
            )
            recv_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_tensor, recv_src, cp_group
            )
            send_recv_ops.append(send_op)
            send_recv_ops.append(recv_op)
        else:
            recv_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_tensor, recv_src, cp_group
            )
            send_op = torch.distributed.P2POp(
                torch.distributed.isend, send_tensor, send_dst, cp_group
            )
            send_recv_ops.append(recv_op)
            send_recv_ops.append(send_op)
        send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops)
    else:
        if rank % 2 == 0:
            send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
            recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
            send_recv_ops.append(send_op)
            send_recv_ops.append(recv_op)
        else:
            recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
            send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
            send_recv_ops.append(recv_op)
            send_recv_ops.append(send_op)
        send_recv_reqs = send_recv_ops

    return send_recv_reqs


@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Attention with context parallelism

    In place: out += out_per_step * exp(softmax_lse_per_step - softmax_lse). The
    rescaling factor has its dim 2 moved to `seq_dim` and a trailing axis added so it
    broadcasts against `out_per_step` (assumes the LSE's sequence axis is dim 2 —
    TODO confirm against callers).
    """
    softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
    softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
    out_corrected = out_per_step * softmax_lse_corrected_exp
    out.add_(out_corrected)


@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism

    In place: softmax_lse = log(exp(softmax_lse) + exp(softmax_lse_per_step)),
    evaluated stably as max + log(1 + exp(min - max)).
    """
    max_scale = torch.max(softmax_lse, softmax_lse_per_step)
    min_scale = torch.min(softmax_lse, softmax_lse_per_step)
    new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
    softmax_lse.copy_(new_scale)


@jit_fuser
def get_cu_seqlens_on_cp_rank(
    cu_seqlens, cu_seqlens_padded_on_cp_rank, cp_size, cp_rank, first_half, second_half
):
    """Compute cu_seqlens of a context parallelism rank

    Each sequence's padded length on this rank is split into two equal chunks. The
    first covers global chunk `cp_rank` of the sequence and the second covers global
    chunk `2 * cp_size - cp_rank - 1`. For each selected half, the number of valid
    (unpadded) tokens in that chunk is clamped to [0, chunk length] and summed. The
    result is returned as a cumulative-length tensor starting at 0.
    """
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    # Length of one chunk (half of this rank's padded share) per sequence.
    seqlens_padded = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
    zeros = torch.zeros_like(seqlens)
    cu_seqlens_on_cp_rank = torch.zeros_like(cu_seqlens)
    if first_half:
        # Valid tokens remaining after skipping the preceding cp_rank chunks.
        seqlens_1 = seqlens - cp_rank * seqlens_padded
        seqlens_1 = seqlens_1.clamp(zeros, seqlens_padded)
        cu_seqlens_on_cp_rank[1:].add_(seqlens_1)
    if second_half:
        # Valid tokens remaining after skipping the preceding (2 * cp_size - cp_rank - 1) chunks.
        seqlens_2 = seqlens - (2 * cp_size - cp_rank - 1) * seqlens_padded
        seqlens_2 = seqlens_2.clamp(zeros, seqlens_padded)
        cu_seqlens_on_cp_rank[1:].add_(seqlens_2)
    cu_seqlens_on_cp_rank.cumsum_(dim=0)
    return cu_seqlens_on_cp_rank


1302
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
1303
    """
1304
1305
1306
    Attention implementation with context parallelism. Exchange KV between CP ranks
    with P2P in ring topology. Split attention compute into multiple steps, and overlap
    current-step compute with next-step communication.
1307
1308
1309
    """

    @staticmethod
1310
1311
1312
1313
1314
1315
1316
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
1317
        cu_seqlens_kv,
1318
        max_seqlen_q,
1319
        max_seqlen_kv,
1320
1321
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
        dropout_p,
        cp_group,
        cp_global_ranks,
        cp_stream,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
1333
1334
        fp8,
        fp8_meta,
1335
    ):
1336
1337
1338
1339
1340
1341
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
        send_dst = cp_global_ranks[(rank + 1) % cp_size]
1342
        recv_src = cp_global_ranks[(rank - 1) % cp_size]
1343
1344
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1345
1346
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
1347

1348
        if qkv_format in ["bshd", "sbhd"]:
1349
            seq_dim = qkv_format.index("s")
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        else:
            qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
        pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
        max_seqlen_q = max_seqlen_q // cp_size
        max_seqlen_kv = max_seqlen_kv // cp_size
        cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
        cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
        cu_seqlens_q_per_step = [None for _ in range(cp_size)]
        cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
1362

1363
1364
1365
        assert qkv_format == "thd" or (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"
1366
        if causal:
1367
1368
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
1369
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
1370
1371
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
1372
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
1373
1374
1375
        total_tokens_kv = None if qkv_format != "thd" else k.shape[0]
        # remove padded tokens at the end
        k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]]
1376
        if attn_bias is not None:
1377
            assert len(attn_bias.shape) == 4, (
1378
1379
1380
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
1381
1382
1383
            assert (
                attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0
            ), "Sequence length does not meet divisible requirements!"
1384
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
1385
1386
1387
1388
1389
1390
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
1391
1392
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
1393
1394
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
1395
            )
1396
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
1397
1398
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_3_plus:
1399
            fa_optional_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
1400
1401
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None
1402
1403
        if _flash_attn_2_5_7_plus:
            fa_optional_forward_kwargs["block_table"] = None
1404

1405
1406
1407
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
1408
        attn_bias_inputs = [None, None]
1409
1410
1411
1412
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
1413
        attn_biases = [None for _ in range(cp_size)]
1414
1415
1416
1417
1418
1419

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
        if fp8:
            if use_fused_attention:
                fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                fused_attn_qkv_dtype = fp8_dtype_forward
                fused_attn_backend = FusedAttnBackend["FP8"]
                if fp8_meta["recipe"].fp8_mha:
                    assert (
                        isinstance(q, Float8Tensor)
                        and isinstance(k, Float8Tensor)
                        and isinstance(v, Float8Tensor)
                    ), "q/k/v must be Float8Tensors for FP8 MHA!"
                    fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                else:
                    q_f16, k_f16, v_f16 = q, k, v
                    q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                    if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                        k, v = [
                            cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                            for x in [k_f16, v_f16]
                        ]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV]
                fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S]
                fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S]
                fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP]
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            q_f16 = q
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

1457
        p2p_comm_buffers = [None for _ in range(cp_size)]
1458
1459
1460
1461
        if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
        else:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
1462
1463
        send_recv_reqs = [[], []]

1464
        for i in range(cp_size + 1):
1465
            if i < cp_size:
1466
                with torch.cuda.stream(flash_attn_streams[i % 2]):
1467
                    # wait until KV is received
1468
                    for req in send_recv_reqs[(i + 1) % 2]:
1469
1470
                        req.wait()

1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
                    if (
                        not fp8
                        or fp8_meta["recipe"].fp8_mha
                        or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
                    ):
                        kv_inputs[i % 2] = p2p_comm_buffers[i]
                    else:
                        # KV exchange is in BF16/FP16, cast received KV in each step
                        kv_inputs[i % 2] = cast_to_fp8(
                            p2p_comm_buffers[i],
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                        )
                    if fp8 and use_fused_attention:
                        fp8_meta_kwargs["amax_s"] = amax_per_step[0][i]
                        fp8_meta_kwargs["amax_o"] = amax_per_step[1][i]
1500
1501
                    if causal:
                        if i == 0:
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1514
                            if use_fused_attention:
1515
1516
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
1517
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
1518
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
1519
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
1520
                                        k.shape[0], -1, 2, *k.shape[-2:]
1521
                                    )
1522
1523
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
1524
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
1525
1526
1527
1528
                                    # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        -1, k.shape[2], 2, *k.shape[-2:]
                                    )
1529
                                elif qkv_format == "thd":
1530
                                    q_inputs[i % 2] = q
1531
1532
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1533
1534
1535
1536
1537
1538
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
1539
                                    ).contiguous()
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
1568
                                )
1569
1570
1571
1572
1573
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
1574
1575
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
1576
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1577
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
1592
1593
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1594
                                    max_seqlen_q,
1595
                                    max_seqlen_kv,
1596
1597
1598
1599
1600
                                    dropout_p,
                                    softmax_scale,
                                    causal=True,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1601
                                )
1602
                        elif i <= rank:
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    False,
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
1620
                            if use_fused_attention:
1621
1622
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
1623
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
1624
1625
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous()
1626
1627
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
1628
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
1629
1630
                                    # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous()
1631
                                elif qkv_format == "thd":
1632
                                    q_inputs[i % 2] = q
1633
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
1634
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
1635
                                        kv_inputs[i % 2], cu_seqlens_kv_padded, 0
1636
                                    )
1637
1638
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1639
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv // 2,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=(
                                        None
                                        if cu_seqlens_kv_padded is None
                                        else cu_seqlens_kv_padded // 2
                                    ),
                                    **fp8_meta_kwargs,
1672
                                )
1673
1674
1675
1676
1677
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
1678
1679
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
1680
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1681
1682
                                if qkv_format == "thd":
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
1683
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
1684
                                        kv_inputs[i % 2], cu_seqlens_kv_padded, 0
1685
                                    )
1686
1687
                                else:
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
1688
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
1689
                                # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
1690
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
1691
                                if _flash_attn_2_3_plus:
1692
                                    fa_optional_forward_kwargs["window_size"] = (-1, -1)
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
1706
1707
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1708
                                    max_seqlen_q,
1709
                                    max_seqlen_kv // 2,
1710
1711
1712
1713
1714
                                    dropout_p,
                                    softmax_scale,
                                    causal=False,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1715
1716
                                )
                        else:
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    True,
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1734
                            if use_fused_attention:
1735
1736
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
1737
                                    q_inputs[i % 2] = q[:, 1, ...].contiguous()
1738
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
1739
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
1740
                                        k.shape[0], -1, 2, *k.shape[-2:]
1741
                                    )
1742
1743
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
1744
                                    q_inputs[i % 2] = q[1].contiguous()
1745
1746
1747
1748
                                    # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        -1, k.shape[2], 2, *k.shape[-2:]
                                    )
1749
1750
                                elif qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
1751
1752
1753
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(
                                        q, cu_seqlens_q_padded, 1
                                    )
1754
1755
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1756
1757
1758
1759
1760
1761
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
1762
                                    ).contiguous()
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q // 2,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=(
                                        None
                                        if cu_seqlens_q_padded is None
                                        else cu_seqlens_q_padded // 2
                                    ),
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
1795
                                )
1796
1797
1798
1799
1800
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
1801
                            else:
1802
1803
                                if qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
1804
1805
1806
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(
                                        q, cu_seqlens_q_padded, 1
                                    )
1807
1808
                                else:
                                    # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
1809
                                    q_inputs[i % 2] = (
1810
                                        q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
1811
                                    )
1812
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
1813
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
1814
                                if _flash_attn_2_3_plus:
1815
                                    fa_optional_forward_kwargs["window_size"] = (-1, -1)
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
1829
1830
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1831
                                    max_seqlen_q // 2,
1832
                                    max_seqlen_kv,
1833
1834
1835
1836
1837
                                    dropout_p,
                                    softmax_scale,
                                    causal=False,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1838
1839
                                )
                    else:
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
                        if pad_between_seqs_q:
                            cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                            )
                        else:
                            cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                        if pad_between_seqs_kv:
                            cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_kv,
                                cu_seqlens_kv_padded,
                                cp_size,
                                (rank - i) % cp_size,
                                True,
                                True,
                            )
                        else:
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1857
                        if use_fused_attention:
1858
1859
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
1860
1861
1862
1863
1864
1865
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
1866
                                ).contiguous()
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
                            out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                is_training,
                                max_seqlen_q,
                                max_seqlen_kv,
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
                                q,
                                (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                ),
                                (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                ),
                                fused_attn_qkv_dtype,
                                fused_attn_backend,
                                attn_scale=softmax_scale,
                                dropout=dropout_p,
                                qkv_layout=qkv_layout,
                                attn_mask_type=attn_mask_type,
                                attn_bias_type=attn_bias_type,
                                attn_bias=attn_bias_inputs[i % 2],
                                cu_seqlens_q_padded=cu_seqlens_q_padded,
                                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                **fp8_meta_kwargs,
1895
                            )
1896
1897
1898
1899
1900
                            if fp8:
                                softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                            else:
                                softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                attn_biases[i] = rest[0] if len(rest) > 0 else None
1901
                        else:
1902
                            # [b, sq, np, hn] -> [b*sq, np, hn]
1903
                            q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1904
                            # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
                            kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                            (
                                _,
                                _,
                                _,
                                _,
                                out_per_step[i],
                                softmax_lse_per_step[i],
                                _,
                                rng_states[i],
                            ) = _flash_attn_forward(
                                q_inputs[i % 2],
                                kv_inputs[i % 2][0],
                                kv_inputs[i % 2][1],
1919
1920
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
1921
                                max_seqlen_q,
1922
                                max_seqlen_kv,
1923
1924
1925
1926
1927
                                dropout_p,
                                softmax_scale,
                                causal=False,
                                return_softmax=False,
                                **fa_optional_forward_kwargs,
1928
                            )
1929
1930
1931
1932

            if i > 0:
                # wait until fwd restuls correction of last step is done
                if i > 1:
1933
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
1934

1935
1936
                if use_fused_attention:
                    # [b, np, sq, 1] -> [b, np, sq]
1937
                    softmax_lse_per_step[i - 1].squeeze_(-1)
1938

1939
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
1940
1941
1942
1943
1944
1945
1946
1947
                    if fp8:
                        out_per_step[i - 1] = cast_from_fp8(
                            out_per_step[i - 1],
                            fp8_meta["scaling_fwd"],
                            META_O_CP,
                            fp8_dtype_forward,
                            TE_DType[torch.float32],
                        )
1948
                    if i == 1:
1949
                        out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
1950
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
1951
                        if causal and qkv_format != "thd":
1952
1953
                            # [b, np, sq] -> [b, np, 2, sq//2]
                            softmax_lse_ = softmax_lse.view(
1954
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
1955
                            )
1956
1957
1958
1959
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
1960
                    else:
1961
                        if qkv_format == "thd":
1962
                            tex.thd_second_half_lse_correction(
1963
1964
1965
1966
                                softmax_lse,
                                softmax_lse_per_step[i - 1],
                                cu_seqlens_q_padded,
                                max_seqlen_q,
1967
                            )
1968
                        else:
1969
1970
1971
                            flash_attn_fwd_softmax_lse_correction(
                                softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1]
                            )
1972
1973

                if i < cp_size:
1974
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
1975
1976
1977
1978
1979

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        softmax_lse = softmax_lse.to(torch.float)
        for i in range(cp_size):
1980
1981
1982
1983
1984
1985
            if qkv_format == "bshd":
                out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
                out_ = out[:, 1, ...]
            elif qkv_format == "sbhd":
                out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
                out_ = out[1]
1986

1987
            if i <= rank or not causal:
1988
                if qkv_format in ["bshd", "sbhd"]:
1989
1990
1991
1992
1993
1994
1995
                    flash_attn_fwd_out_correction(
                        out.view(*out_per_step[i].shape),
                        out_per_step[i],
                        seq_dim,
                        softmax_lse,
                        softmax_lse_per_step[i],
                    )
1996
                elif qkv_format == "thd":
1997
1998
1999
2000
2001
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2002
                        cu_seqlens_q_padded,
2003
2004
                        False,
                    )
2005
            else:
2006
                if qkv_format in ["bshd", "sbhd"]:
2007
2008
2009
2010
2011
2012
2013
                    flash_attn_fwd_out_correction(
                        out_,
                        out_per_step[i],
                        seq_dim,
                        softmax_lse_[..., 1, :],
                        softmax_lse_per_step[i],
                    )
2014
                elif qkv_format == "thd":
2015
2016
2017
2018
2019
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2020
                        cu_seqlens_q_padded,
2021
2022
                        True,
                    )
2023
2024

        kv = p2p_comm_buffers[-1]
2025
        if use_fused_attention:
2026
2027
2028
2029
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
2030
2031
        else:
            out = out.view(-1, *out.shape[-2:])
2032

2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
        if fp8 and use_fused_attention:
            amax_cp_fwd = amax_per_step.amax(dim=1)
            fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0]
            fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]

        out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype)
        if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
            out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)

        if fp8 and fp8_meta["recipe"].fp8_mha:
            out_ret = Float8Tensor(
                data=out_fp8,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_O,
                fp8_dtype=fp8_dtype_forward,
                dtype=q_fp8.dtype,
            )
        else:
            out_ret = out_f16

        if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
            q_save, kv_save, out_save = q, kv, out_fp8
            fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
            fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
        elif fp8 and fp8_meta["recipe"].fp8_mha:
            kv_fp8 = Float8Tensor(
                data=kv,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_QKV,
                fp8_dtype=fp8_dtype_forward,
                dtype=k_fp8.dtype,
            )
            q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None
        else:
            q_save, kv_save, out_save = q_f16, kv, out_f16
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None

2073
        ctx.save_for_backward(
2074
2075
2076
            q_save,
            kv_save,
            out_save,
2077
            softmax_lse,
2078
2079
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
2080
2081
            fp8_fwd_scales,
            fp8_fwd_scale_invs,
2082
2083
            *cu_seqlens_q_per_step,
            *cu_seqlens_kv_per_step,
2084
2085
            *rng_states,
            *attn_biases,
2086
        )
2087
2088
2089
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
        ctx.dropout_p = dropout_p
2090
        ctx.total_tokens_kv = total_tokens_kv
2091
        ctx.max_seqlen_q = max_seqlen_q
2092
        ctx.max_seqlen_kv = max_seqlen_kv
2093
        ctx.softmax_scale = softmax_scale
2094
        ctx.qkv_format = qkv_format
2095
        ctx.attn_mask_type = attn_mask_type
2096
2097
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
2098
        ctx.deterministic = deterministic
2099
        ctx.use_fused_attention = use_fused_attention
2100
2101
2102
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
        return out_ret
2103
2104
2105
2106
2107

    @staticmethod
    def backward(ctx, dout):
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
2108
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
2109
2110
2111
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

2112
        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
2113
2114
2115
2116
2117
        (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
        cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
        cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
        rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
        attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
2118

2119
2120
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
2121
2122
2123
2124
        if ctx.qkv_format in ["bshd", "sbhd"]:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
        else:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
2125

2126
        if attn_biases[0] is not None:
2127
2128
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
2129
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
2130
2131
2132
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
2133
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
2134
2135
2136
2137
            )
        else:
            attn_dbias = None

2138
        if causal:
2139
            if ctx.qkv_format == "thd":
2140
2141
2142
                softmax_lse_ = tex.thd_read_second_half_lse(
                    softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q
                )
2143
2144
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
2145
2146
2147
                softmax_lse_ = softmax_lse.view(
                    *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
                )
2148
2149
2150
2151
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
                if ctx.use_fused_attention:
                    # [b, np, sq//2] -> [b, np, sq//2, 1]
                    softmax_lse_.unsqueeze_(-1)
2152
2153
2154
        if ctx.use_fused_attention:
            # [b, np, sq] -> [b, np, sq, 1]
            softmax_lse.unsqueeze_(-1)
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203

        if ctx.fp8:
            if ctx.use_fused_attention:
                fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
                fused_attn_qkv_dtype = fp8_dtype_backward
                fused_attn_dqkv_dtype = fp8_dtype_backward
                fused_attn_backend = FusedAttnBackend["FP8"]
                dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
                dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
                dkv_fp8_ = torch.empty_like(dkv_fp8)
                dout_dtype = dout.dtype
                if ctx.fp8_meta["recipe"].fp8_mha:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
                    dout = dout._data
                else:
                    dout = cast_to_fp8(
                        dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                    )
                p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
                fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
                fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
                fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
                fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
                fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
                fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
                fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP]
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
                q, kv, dout = [x.from_float8(x.dtype) for x in [q, kv, dout]]
            dq = torch.empty_like(q)
            if ctx.qkv_format == "thd" and causal:
                dq[cu_seqlens_q_padded[-1] :].fill_(0)
            p2p_comm_buffers = [
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            ]
            p2p_comm_buffers[0][0].copy_(kv)
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
                fused_attn_dqkv_dtype = TE_DType[q.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

2204
2205
2206
2207
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        send_recv_reqs = []

2208
2209
2210
2211
2212
2213
        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

2214
2215
2216
2217
2218
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

2219
2220
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
            if ctx.fp8:
                if i < cp_size - 1:
                    send_recv_reqs = flash_attn_p2p_communicate(
                        rank,
                        send_tensor[0],
                        send_dst,
                        recv_tensor[0],
                        recv_src,
                        ctx.cp_group,
                        batch_p2p_comm,
                    )
                else:
                    dkv_a2a_req = torch.distributed.all_to_all_single(
                        dkv_fp8,
                        dkv_fp8_,
                        group=ctx.cp_group,
                        async_op=True,
                    )
                    send_recv_reqs = [dkv_a2a_req]
            else:
                if i == 0:
                    send_tensor = send_tensor[0]
                    recv_tensor = recv_tensor[0]
                if i == (cp_size - 1):
                    send_tensor = send_tensor[1]
                    recv_tensor = recv_tensor[1]
                send_recv_reqs = flash_attn_p2p_communicate(
                    rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
                )
2250

2251
            kv = p2p_comm_buffers[i % 2][0]
2252
2253
2254
            if ctx.fp8 and ctx.use_fused_attention:
                fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i]
                fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i]
2255
            # In reversed order of fwd
2256
            if causal:
2257
                if i == (cp_size - 1):
2258
                    if ctx.use_fused_attention:
2259
2260
2261
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
2262
2263
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
2264
2265
2266
2267
2268
2269
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
2270
2271
                            # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                            kv_ = kv.view(-1, *kv.shape[-4:])
2272
2273
2274
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
2275
2276
                        elif ctx.qkv_format == "thd":
                            q_, kv_, out_, dout_ = q, kv, out, dout
2277
2278
2279
2280
2281
2282
2283
2284
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2285
                        if attn_dbias is not None:
2286
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2287
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2288
                            ctx.max_seqlen_q,
2289
2290
2291
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2292
                            q_,
2293
2294
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2295
2296
                            out_,
                            dout_,
2297
2298
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2299
                            aux_ctx_tensors,
2300
                            fused_attn_backend,
2301
2302
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2303
2304
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2305
                            qkv_layout=qkv_layout,
2306
                            attn_mask_type=ctx.attn_mask_type,
2307
                            attn_bias_type=ctx.attn_bias_type,
2308
2309
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2310
2311
2312
2313
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
2314
                        dq_ = torch.zeros_like(q_)
2315
2316
2317
2318
2319
2320
2321
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
2322
                            fa_optional_backward_kwargs["window_size"] = (-1, 0)
2323
                        _flash_attn_backward(
2324
2325
2326
2327
2328
2329
2330
2331
2332
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2333
2334
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2335
                            ctx.max_seqlen_q,
2336
                            ctx.max_seqlen_kv,
2337
2338
2339
2340
2341
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            True,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
2342
                        )
2343
                elif i >= (cp_size - rank - 1):
2344
                    if ctx.use_fused_attention:
2345
2346
2347
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
2348
2349
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                            kv_ = kv[:, 0, ...].contiguous()
2350
2351
2352
2353
2354
2355
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
2356
2357
                            # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                            kv_ = kv[0].contiguous()
2358
2359
2360
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
2361
2362
2363
                        elif ctx.qkv_format == "thd":
                            q_, out_, dout_ = q, out, dout
                            # [2, t, np, hn] -> [2, t/2, np, hn]
2364
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
2365
2366
2367
2368
2369
2370
2371
2372
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2373
                        if attn_dbias is not None:
2374
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2375
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2376
                            ctx.max_seqlen_q,
2377
2378
2379
                            ctx.max_seqlen_kv // 2,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2380
                            q_,
2381
2382
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2383
2384
                            out_,
                            dout_,
2385
2386
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2387
                            aux_ctx_tensors,
2388
                            fused_attn_backend,
2389
2390
2391
2392
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=(
                                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
                            ),
2393
2394
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2395
                            qkv_layout=qkv_layout,
2396
                            attn_mask_type="padding" if padding else "no_mask",
2397
                            attn_bias_type=ctx.attn_bias_type,
2398
2399
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2400
2401
2402
2403
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
2404
                        dq_ = torch.zeros_like(q_)
2405
2406
                        if ctx.qkv_format == "thd":
                            # [2, t, np, hn] -> [2, t/2, np, hn]
2407
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
2408
2409
2410
                        else:
                            # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
2411
2412
2413
2414
2415
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
2416
                            fa_optional_backward_kwargs["window_size"] = (-1, -1)
2417
                        _flash_attn_backward(
2418
2419
2420
2421
2422
2423
2424
2425
2426
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2427
2428
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2429
                            ctx.max_seqlen_q,
2430
                            ctx.max_seqlen_kv // 2,
2431
2432
2433
2434
2435
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            False,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
2436
2437
2438
                        )
                else:
                    if ctx.use_fused_attention:
2439
2440
2441
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous()
2442
2443
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
2444
2445
2446
2447
2448
2449
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous()
                            dout_ = dout[:, 1, ...].contiguous()
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            q_ = q[1].contiguous()
2450
2451
                            # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                            kv_ = kv.view(-1, *kv.shape[-4:])
2452
2453
2454
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            out_ = out[1].contiguous()
                            dout_ = dout[1].contiguous()
2455
2456
                        elif ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
2457
2458
2459
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
2460
                            kv_ = kv
2461
2462
2463
2464
2465
2466
2467
2468
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse_,
                                softmax_lse_,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
2469
                        if attn_dbias is not None:
2470
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2471
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2472
                            ctx.max_seqlen_q // 2,
2473
2474
2475
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2476
                            q_,
2477
2478
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2479
2480
                            out_,
                            dout_,
2481
2482
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2483
                            aux_ctx_tensors,
2484
                            fused_attn_backend,
2485
2486
2487
2488
                            cu_seqlens_q_padded=(
                                None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
                            ),
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2489
2490
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2491
                            qkv_layout=qkv_layout,
2492
                            attn_mask_type="padding" if padding else "no_mask",
2493
                            attn_bias_type=ctx.attn_bias_type,
2494
2495
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2496
2497
                        )
                    else:
2498
2499
                        if ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
2500
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
2501
2502
2503
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
2504
                        dq_ = torch.zeros_like(q_)
2505
2506
2507
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
2508
                        if ctx.qkv_format == "thd":
2509
2510
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
2511
2512
2513
2514
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                            dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
2515
                        if _flash_attn_2_3_plus:
2516
                            fa_optional_backward_kwargs["window_size"] = (-1, -1)
2517
                        _flash_attn_backward(
2518
2519
2520
2521
2522
2523
2524
2525
2526
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse_,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2527
2528
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2529
                            ctx.max_seqlen_q // 2,
2530
                            ctx.max_seqlen_kv,
2531
2532
2533
2534
2535
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            False,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
2536
2537
2538
                        )
            else:
                if ctx.use_fused_attention:
2539
2540
2541
2542
                    if ctx.fp8:
                        aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]]
                    else:
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2543
                    if attn_dbias is not None:
2544
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2545
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2546
                        ctx.max_seqlen_q,
2547
2548
2549
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
2550
                        q,
2551
2552
                        kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
2553
2554
                        out,
                        dout,
2555
2556
                        fused_attn_qkv_dtype,
                        fused_attn_dqkv_dtype,
2557
                        aux_ctx_tensors,
2558
                        fused_attn_backend,
2559
2560
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2561
2562
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
2563
                        qkv_layout=qkv_layout,
2564
                        attn_mask_type=ctx.attn_mask_type,
2565
                        attn_bias_type=ctx.attn_bias_type,
2566
2567
                        deterministic=ctx.deterministic,
                        **fp8_meta_kwargs,
2568
2569
2570
                    )
                else:
                    # [b, sq, np, hn] -> [b*sq, np, hn]
2571
                    q_ = q.view(-1, *q.shape[-2:])
2572
                    dq_ = torch.zeros_like(q_)
2573
                    # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
2574
2575
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
2576
                    # [b, sq, np, hn] -> [b*sq, np, hn]
2577
2578
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
2579
                    if _flash_attn_2_3_plus:
2580
                        fa_optional_backward_kwargs["window_size"] = (-1, -1)
2581
                    _flash_attn_backward(
2582
2583
2584
2585
2586
2587
2588
2589
2590
                        dout_,
                        q_,
                        kv_[0],
                        kv_[1],
                        out_,
                        softmax_lse,
                        dq_,
                        dkv_[0],
                        dkv_[1],
2591
2592
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
2593
                        ctx.max_seqlen_q,
2594
                        ctx.max_seqlen_kv,
2595
2596
2597
                        ctx.dropout_p,
                        ctx.softmax_scale,
                        False,
2598
                        rng_state=rng_states[cp_size - i - 1],
2599
                        **fa_optional_backward_kwargs,
2600
2601
                    )

2602
2603
            if ctx.fp8:
                dq = dq_fp8[(rank + i + 1) % cp_size]
2604
            if i >= (cp_size - rank - 1) or not causal:
2605
2606
2607
2608
                # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                dq_ = dq_.view(*dq.shape)
            else:
2609
2610
2611
2612
2613
2614
                if ctx.qkv_format == "bshd":
                    # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                    dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
                elif ctx.qkv_format == "sbhd":
                    # [b*sq//2, np, hn] -> [sq//2, b, np, hn]
                    dq_ = dq_.view(-1, *dq.shape[-3:])
2615

2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
            if ctx.fp8:
                if i >= (cp_size - rank - 1) or not causal:
                    dq.copy_(dq_)
                else:
                    if ctx.qkv_format == "bshd":
                        dq[:, 0, ...].fill_(0)
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[0].fill_(0)
                        dq[1].copy_(dq_)
            elif causal:
2627
                if i > (cp_size - rank - 1):
2628
                    dq.add_(dq_)
2629
2630
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
2631
2632
                        dq.copy_(dq_)
                    else:
2633
2634
2635
2636
2637
2638
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
2639
                        elif ctx.qkv_format == "thd":
2640
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
2641
                elif i > 0:
2642
2643
2644
2645
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
2646
                    elif ctx.qkv_format == "thd":
2647
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
2648
                else:
2649
2650
2651
2652
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
2653
                    elif ctx.qkv_format == "thd":
2654
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
2655
2656
2657
2658
2659
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
2660

2661
            if attn_dbias is not None:
2662
                idx = (rank + i + 1) % cp_size
2663
                if i == (cp_size - 1) or not causal:
2664
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
2665
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2666
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
2667
2668
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
2669
2670
2671
2672
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
2673
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2674
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
2675
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
2676

2677
2678
2679
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
2680

2681
2682
2683
2684
2685
2686
2687
            if ctx.fp8:
                if i < cp_size - 1:
                    dkv = dkv_fp8_[(rank + i + 1) % cp_size]
                else:
                    dkv = dkv_fp8[(rank + i + 1) % cp_size]
            else:
                dkv = p2p_comm_buffers[(i + 1) % 2][1]
2688
2689
            if ctx.use_fused_attention:
                dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
2690
2691
2692
2693
                if ctx.qkv_format in ["bshd", "sbhd"]:
                    # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
                    # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
                    dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
2694
            if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
2695
2696
2697
2698
2699
2700
                if ctx.qkv_format == "bshd":
                    # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                    dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
                elif ctx.qkv_format == "sbhd":
                    # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:])
2701
2702
2703
2704
            else:
                # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
                # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                dkv_ = dkv_.view(*dkv.shape)
2705

2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
            if ctx.fp8:
                if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
                    if ctx.qkv_format == "bshd":
                        dkv[:, :, 0, ...].copy_(dkv_)
                        dkv[:, :, 1, ...].fill_(0)
                    elif ctx.qkv_format == "sbhd":
                        dkv[:, 0, ...].copy_(dkv_)
                        dkv[:, 1, ...].fill_(0)
                else:
                    dkv.copy_(dkv_)
            elif causal:
2717
                if i == (cp_size - 1):
2718
                    if rank == 0:
2719
2720
2721
2722
2723
2724
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
2725
                        elif ctx.qkv_format == "thd":
2726
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
2727
2728
                    else:
                        dkv.add_(dkv_)
2729
2730
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
2731
2732
2733
2734
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
2735
                        elif ctx.qkv_format == "thd":
2736
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
2737
                    else:
2738
2739
2740
2741
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
2742
                        elif ctx.qkv_format == "thd":
2743
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
2744
2745
2746
2747
2748
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
2749
2750
2751
2752
2753
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
        if ctx.fp8 and ctx.use_fused_attention:
            amax_cp_bwd = amax_per_step.amax(dim=1)
            ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0]
            ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1]
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
                # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
                dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
            dq, dkv = [
                cast_from_fp8(
                    x,
                    ctx.fp8_meta["scaling_bwd"],
                    META_DQKV_CP,
                    fp8_dtype_backward,
                    TE_DType[torch.float32],
                )
                for x in [dq_fp8, dkv_fp8]
            ]
            dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]

2774
        if causal:
2775
2776
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
2777
                dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
2778
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
2779
                dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
2780
2781
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
2782
                dq = dq.view(-1, *dq.shape[-3:])
2783
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
2784
2785
2786
2787
2788
2789
2790
2791
2792
                dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])

        if ctx.qkv_format == "thd":
            dkv_ = torch.empty(
                2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device
            )
            dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv)
            dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
            dkv = dkv_
2793

2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
        if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
            dq, dkv = [
                cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward)
                for x in [dq, dkv]
            ]
            dq, dk, dv = [
                Float8Tensor(
                    data=x,
                    fp8_meta=ctx.fp8_meta,
                    fp8_meta_forward=False,
                    fp8_meta_index=META_DQKV,
                    fp8_dtype=fp8_dtype_backward,
                    dtype=dout_dtype,
                )
                for x in [dq, dkv[0], dkv[1]]
            ]
        else:
            dk, dv = dkv[0], dkv[1]

2813
2814
2815
2816
        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)

2817
2818
2819
        return (
            None,
            dq,
2820
2821
            dk,
            dv,
2822
2823
2824
2825
2826
2827
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            attn_dbias,
            None,
            None,
2839
2840
            None,
            None,
2841
        )
2842
2843


2844
@torch.compile
def get_seq_chunk_ids_to_all_gathered_kv(
    local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device
):
    """Compute the all-gathered KV slots that a local query chunk attends to.

    The sequence is split into ``2 * cp_size`` chunks, and rank ``r`` owns sequence
    chunks ``r`` and ``2 * cp_size - r - 1``. After the KV all-gather (assumed to
    concatenate in rank order), slot ``2 * r`` of the gathered buffer holds sequence
    chunk ``r``. Slot ``2 * r + 1`` holds sequence chunk ``2 * cp_size - r - 1``.

    Parameters
    ----------
    local_chunk_id : int
        Sequence-order index of the query chunk, expected in ``[0, 2 * cp_size)``.
    cp_size : int
        Context-parallel world size.
    max_seqlen_q : int
        Query length of a single chunk.
    max_seqlen_kv : int
        KV length of a single chunk.
    window_size_left : int
        Left sliding-window size. A negative value (e.g. ``-1``, the same
        "unbounded" convention the caller checks with ``window_size[0] == -1``)
        means the window extends back to the start of the sequence.
    device : torch.device
        Device on which the returned index tensor is allocated.

    Returns
    -------
    torch.Tensor
        ``int32`` tensor of slot indices into the all-gathered KV. The slots are
        ordered by ascending sequence chunk and end with the slot of chunk
        ``local_chunk_id``.
    """
    # Exclusive end (in KV positions) of the keys visible to this query chunk:
    # the end of sequence chunk `local_chunk_id`.
    seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv
    if window_size_left < 0:
        # Unbounded window: keys reach back to the beginning of the sequence.
        seq_start_idx = 0
    else:
        # Earliest key visible to the first query row of this chunk, clamped at 0.
        seq_start_idx = max(0, seq_end_idx - max_seqlen_q - window_size_left)
    seqlen = seq_end_idx - seq_start_idx
    # seq_end_idx is chunk-aligned, so ceiling division counts every chunk touched.
    num_chunks = (seqlen + max_seqlen_kv - 1) // max_seqlen_kv
    chunk_ids = torch.arange(
        local_chunk_id - num_chunks + 1,
        local_chunk_id + 1,
        dtype=torch.int32,
        device=device,
    )
    # Map sequence chunk ids to their slot in the all-gathered KV buffer.
    chunk_ids_to_all_gathered_kv = torch.where(
        chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1
    )
    return chunk_ids_to_all_gathered_kv


class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    """
    Attention implementation with context parallelism.
    KV all-gather between CP ranks is exposed.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        cp_group,
        cp_stream,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert causal and not padding, f"{attn_mask_type} mask type is not supported!"
        if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
            attn_mask_type = attn_mask_type + "_bottom_right"

        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            use_fused_attention or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        max_seqlen_q = max_seqlen_q // (2 * cp_size)
        max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
        cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
        cu_seqlens_kv = cu_seqlens_kv // (2 * cp_size)
        cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
        cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size)

        if causal:
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
                q = q.view(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:])
                # [b, s, np, hn] -> [s, b, np, hn]
                k, v = [x.transpose(0, 1).contiguous() for x in [k, v]]
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
                q = q.view(2, q.shape[0] // 2, *q.shape[1:])

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]

        k_ag, _ = gather_along_first_dim(k, cp_group)
        v_ag, _ = gather_along_first_dim(v, cp_group)
        cp_stream.wait_stream(torch.cuda.current_stream())
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
        chunk_ids_to_kv_ag_per_step = [None, None]
        out_per_step = [None, None]
        softmax_lse_per_step = [None, None]
        rng_states = [None, None]
        out = torch.empty_like(q)

        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    chunk_ids_to_kv_ag = get_seq_chunk_ids_to_all_gathered_kv(
                        local_seq_chunk_ids[i],
                        cp_size,
                        max_seqlen_q,
                        max_seqlen_kv,
                        (
                            max_seqlen_kv * cp_size * 2
                            if (window_size is None or window_size[0] == -1)
                            else window_size[0]
                        ),
2971
                        k.device,
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
3030
3031
3032
3033
3034
3035
3036
3037
3038
3039
3040
3041
3042
3043
3044
3045
3046
3047
3048
3049
3050
3051
3052
3053
3054
3055
3056
3057
3058
3059
3060
3061
3062
3063
3064
3065
3066
3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
3087
3088
3089
3090
3091
3092
3093
3094
3095
3096
3097
3098
3099
3100
3101
3102
3103
3104
3105
3106
3107
3108
3109
3110
3111
3112
3113
3114
3115
3116
3117
3118
3119
3120
3121
3122
3123
3124
3125
3126
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
3138
3139
3140
3141
3142
3143
3144
3145
3146
3147
3148
3149
3150
3151
3152
3153
3154
3155
3156
3157
3158
3159
3160
3161
3162
3163
3164
3165
3166
3167
3168
3169
3170
3171
3172
3173
3174
3175
3176
3177
3178
3179
3180
3181
3182
3183
3184
3185
3186
3187
3188
3189
3190
3191
3192
3193
3194
3195
3196
3197
3198
3199
3200
3201
3202
3203
3204
3205
3206
3207
3208
3209
3210
3211
3212
3213
3214
3215
3216
3217
3218
3219
3220
3221
3222
3223
3224
3225
3226
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263
3264
3265
3266
3267
3268
3269
3270
3271
3272
3273
3274
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284
3285
3286
3287
3288
3289
3290
3291
3292
3293
                    )
                    chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag
                    num_kv_chunks = chunk_ids_to_kv_ag.numel()
                    if qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_ = q[:, i].contiguous()
                        # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn]
                        k_ = (
                            torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag)
                            .movedim(2, 0)
                            .contiguous()
                            .view(k.shape[1], -1, *k.shape[-2:])
                        )
                        v_ = (
                            torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag)
                            .movedim(2, 0)
                            .contiguous()
                            .view(v.shape[1], -1, *v.shape[-2:])
                        )
                    elif qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                        q_ = q[i].contiguous()
                        # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn]
                        k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view(
                            -1, *k.shape[-3:]
                        )
                        v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view(
                            -1, *v.shape[-3:]
                        )
                    if use_fused_attention:
                        out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
                            is_training,
                            max_seqlen_q,
                            max_seqlen_kv * num_kv_chunks,
                            cu_seqlens_q,
                            cu_seqlens_kv * num_kv_chunks,
                            q_,
                            k_,
                            v_,
                            TE_DType[q.dtype],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=softmax_scale,
                            dropout=dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=attn_mask_type,
                            attn_bias_type=attn_bias_type,
                            attn_bias=attn_bias,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks,
                            window_size=window_size,
                        )
                    else:
                        q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
                        _, _, _, _, out_per_step[i], softmax_lse_per_step[i], _, rng_states[i] = (
                            _flash_attn_forward(
                                q_,
                                k_,
                                v_,
                                cu_seqlens_q,
                                cu_seqlens_kv * num_kv_chunks,
                                max_seqlen_q,
                                max_seqlen_kv * num_kv_chunks,
                                dropout_p,
                                softmax_scale,
                                causal=True,
                                return_softmax=False,
                                window_size=window_size,
                                **fa_optional_forward_kwargs,
                            )
                        )

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if qkv_format == "bshd":
                        out[:, i - 1].copy_(out_per_step[i - 1].view_as(out[:, i - 1]))
                    elif qkv_format == "sbhd":
                        out[i - 1].copy_(out_per_step[i - 1].view_as(out[i - 1]))

        torch.cuda.current_stream().wait_stream(cp_stream)

        if use_fused_attention:
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
        else:
            out = out.view(-1, *out.shape[-2:])

        ctx.save_for_backward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *chunk_ids_to_kv_ag_per_step,
            *out_per_step,
            *softmax_lse_per_step,
            *rng_states,
        )
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_mask_type = attn_mask_type
        ctx.attn_bias_type = attn_bias_type
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
        ctx.window_size = window_size
        return out

    @staticmethod
    def backward(ctx, dout):
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)

        (q, k, v, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded) = (
            ctx.saved_tensors[:7]
        )
        chunk_ids_to_kv_ag_per_step = ctx.saved_tensors[7:9]
        out_per_step = ctx.saved_tensors[9:11]
        softmax_lse_per_step = ctx.saved_tensors[11:13]
        rng_states = ctx.saved_tensors[13:15]

        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

        dout = dout.view_as(q)
        dq = torch.empty_like(q)
        dk = torch.zeros(
            (2 * cp_size, k.shape[0] // 2, *k.shape[1:]), dtype=k.dtype, device=k.device
        )
        dv = torch.zeros_like(dk)
        dq_per_step = [None, None]
        dk_per_step = [None, None]
        dv_per_step = [None, None]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream]
        # synchronize dkv update across steps
        dkv_update_done = torch.cuda.Event()

        k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
        v_ag, _ = gather_along_first_dim(v, ctx.cp_group)
        ctx.cp_stream.wait_stream(torch.cuda.current_stream())
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]

        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i]
                    num_kv_chunks = chunk_ids_to_kv_ag.numel()
                    out_ = out_per_step[i]
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_ = q[:, i].contiguous()
                        # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn]
                        k_ = (
                            torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag)
                            .movedim(2, 0)
                            .contiguous()
                            .view(k.shape[1], -1, *k.shape[-2:])
                        )
                        v_ = (
                            torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag)
                            .movedim(2, 0)
                            .contiguous()
                            .view(v.shape[1], -1, *v.shape[-2:])
                        )
                        dout_ = dout[:, i].contiguous().view_as(out_)
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                        q_ = q[i].contiguous()
                        # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn]
                        k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view(
                            -1, *k.shape[-3:]
                        )
                        v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view(
                            -1, *v.shape[-3:]
                        )
                        dout_ = dout[i].contiguous().view_as(out_)
                    if ctx.use_fused_attention:
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
                        aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_kv * num_kv_chunks,
                            cu_seqlens_q,
                            cu_seqlens_kv * num_kv_chunks,
                            q_,
                            k_,
                            v_,
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[k.dtype],
                            aux_ctx_tensors,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=ctx.attn_mask_type,
                            attn_bias_type=ctx.attn_bias_type,
                            window_size=ctx.window_size,
                        )
                    else:
                        q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
                        _flash_attn_backward(
                            dout_,
                            q_,
                            k_,
                            v_,
                            out_,
                            softmax_lse_per_step[i],
                            dq_per_step[i],
                            dk_per_step[i],
                            dv_per_step[i],
                            cu_seqlens_q,
                            cu_seqlens_kv * num_kv_chunks,
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_kv * num_kv_chunks,
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            True,
                            window_size=ctx.window_size,
                            rng_state=rng_states[i],
                            **fa_optional_backward_kwargs,
                        )

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i - 1]
                    num_kv_chunks = chunk_ids_to_kv_ag.numel()
                    if ctx.qkv_format == "bshd":
                        dq[:, i - 1].copy_(dq_per_step[i - 1].view_as(dq[:, i - 1]))
                        dk_per_step[i - 1] = (
                            dk_per_step[i - 1]
                            .view(k.shape[1], num_kv_chunks, -1, *k.shape[-2:])
                            .movedim(0, 2)
                            .contiguous()
                        )
                        dv_per_step[i - 1] = (
                            dv_per_step[i - 1]
                            .view(v.shape[1], num_kv_chunks, -1, *v.shape[-2:])
                            .movedim(0, 2)
                            .contiguous()
                        )
                    elif ctx.qkv_format == "sbhd":
                        dq[i - 1].copy_(dq_per_step[i - 1].view_as(dq[i - 1]))
                        dk_per_step[i - 1] = dk_per_step[i - 1].view(
                            num_kv_chunks, -1, *k.shape[-3:]
                        )
                        dv_per_step[i - 1] = dv_per_step[i - 1].view(
                            num_kv_chunks, -1, *v.shape[-3:]
                        )

                    # wait until dkv update of last step is done
                    if i > 1:
                        flash_attn_streams[i - 1].wait_event(dkv_update_done)
                    dk.index_add_(0, chunk_ids_to_kv_ag, dk_per_step[i - 1])
                    dv.index_add_(0, chunk_ids_to_kv_ag, dv_per_step[i - 1])
                    if i < len(local_seq_chunk_ids):
                        flash_attn_streams[i - 1].record_event(dkv_update_done)

        torch.cuda.current_stream().wait_stream(ctx.cp_stream)

        dk = dk.view(-1, *dk.shape[-3:])
        dv = dv.view(-1, *dv.shape[-3:])
        dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
        dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)

        if ctx.qkv_format == "bshd":
            dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
            dk = dk.transpose(0, 1).contiguous()
            dv = dv.transpose(0, 1).contiguous()
        elif ctx.qkv_format == "sbhd":
            dq = dq.view(-1, *dq.shape[-3:])

        return (
            None,
            dq,
            dk,
            dv,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_kv,
    max_seqlen_q,
    max_seqlen_kv,
    cu_seqlens_q_padded,
    cu_seqlens_kv_padded,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    cp_comm_type,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
    window_size=None,
    fp8=False,
    fp8_meta=None,
) -> torch.Tensor:
    """
    Attention implementation with context parallelism.

    Validates the format/mask/bias combination, then dispatches to one of two
    implementations:

    * ``AttnFuncWithCPAndKVAllGather``, used when ``cp_comm_type == "all_gather"``
      or when sliding-window attention is requested (``window_size`` is not None,
      ``(-1, 0)`` or ``(-1, -1)``).
    * ``AttnFuncWithCPAndKVP2P``, used when ``cp_comm_type == "p2p"``.

    ``cp_global_ranks``, ``fp8`` and ``fp8_meta`` are only forwarded to the p2p
    implementation; the all-gather path does not receive them.

    Raises
    ------
    AssertionError
        For an unsupported ``qkv_format`` / backend / mask / bias combination, or if
        ``cu_seqlens_q_padded`` / ``cu_seqlens_kv_padded`` is None.
    ValueError
        If ``cp_comm_type`` is neither ``"p2p"`` nor ``"all_gather"`` (and
        sliding-window attention is not in use).
    """
    assert qkv_format in [
        "bshd",
        "sbhd",
        "thd",
    ], f"QKV format of {qkv_format} is not supported with context parallelism!"
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "FlashAttention does not support sbhd format!"
    assert (
        qkv_format != "thd"
        or not use_fused_attention
        or attn_mask_type in ["padding", "padding_causal"]
    ), (
        f"Context parallelism is not supported for {attn_mask_type} mask type and "
        f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!"
    )
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        """Attention bias is only supported with FusedAttention and "causal" """
        """or "no_mask" mask types!"""
    )
    assert (
        cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
    ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"

    # (-1, 0) and (-1, -1) mean "no window", so they do not count as sliding window.
    sliding_window_attn = (
        window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
    )

    # Sliding-window attention is always routed to the all-gather implementation,
    # regardless of the requested cp_comm_type.
    if sliding_window_attn or cp_comm_type == "all_gather":
        out = AttnFuncWithCPAndKVAllGather.apply(
            is_training,
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            dropout_p,
            cp_group,
            cp_stream,
            softmax_scale,
            qkv_format,
            attn_mask_type,
            attn_bias_type,
            attn_bias,
            deterministic,
            use_fused_attention,
            window_size,
        )
    elif cp_comm_type == "p2p":
        out = AttnFuncWithCPAndKVP2P.apply(
            is_training,
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            dropout_p,
            cp_group,
            cp_global_ranks,
            cp_stream,
            softmax_scale,
            qkv_format,
            attn_mask_type,
            attn_bias_type,
            attn_bias,
            deterministic,
            use_fused_attention,
            fp8,
            fp8_meta,
        )
    else:
        raise ValueError(f"Unsupported communication type: {cp_comm_type}!")

    return out


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float
            Percent of rotary dimension to use for rotary position embeddings.
        seq_len_interpolation_factor: int
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int
            pre-trained max_position_embeddings before position interpolation
        rotary_base: float, default = 10000.0
            Base of the inverse-frequency progression,
            ``inv_freq[i] = 1 / rotary_base ** (2 * i / dim)``.
        """
        super().__init__()

        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
        # One inverse frequency per pair of channels, created on the current CUDA device.
        inv_freq = 1.0 / (
            rotary_base
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            Tensor of shape ``[max_seq_len, 1, 1, 2 * len(inv_freq)]`` holding
            ``position * inv_freq``, with the frequency block repeated twice along the
            last dimension.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        # Position interpolation is only applied when both settings are provided.
        if (
            self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None
        ):
            if (
                max_seq_len
                > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
            ):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        # Outer product: freqs[p, j] = seq[p] * inv_freq[j]
        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        # first part even vector components, second part odd vector components,
        #  2 * dim in dimension size
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))


class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Apply RoPE to ``t`` with the fused ``tex`` kernels.

        Parameters
        ----------
        t: torch.Tensor
            Input in ``tensor_format`` layout.
        freqs: torch.Tensor
            RoPE frequencies; cast to float32 before the kernel call.
        tensor_format: {'sbhd', 'bshd', 'thd'}
            Layout of ``t``. A ``bshd`` input is transposed to sequence-first, passed to
            the same kernel as ``sbhd`` with its flag set to True, and transposed back.
        cu_seqlens: torch.Tensor, optional
            Cumulative sequence lengths; only used for ``thd``.

        Raises
        ------
        ValueError
            If ``tensor_format`` is not one of the supported values.
        """
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        # freqs is saved after the float32 cast so backward uses the same values.
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Return the gradient w.r.t. ``t``; the other three inputs get ``None``."""
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # One entry per forward input: t, freqs, tensor_format, cu_seqlens.
        return grad_input, None, None, None


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


3544
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.

    Returns
    -------
    torch.Tensor
        Tensor with the same shape as `t`. Only the first `d2` channels are rotated;
        the remaining channels are passed through unchanged.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)


class _SplitAlongDim(torch.autograd.Function):
3608
3609
3610
    """"""

    @staticmethod
3611
3612
3613
3614
3615
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
3616
    ) -> Tuple[torch.Tensor, ...]:
cyanguwa's avatar
cyanguwa committed
3617
3618
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
3619
        if isinstance(mixed_x_layer, Float8Tensor):
3620
3621
3622
3623
3624
3625
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x,
                )
                for x in torch.split(
3626
3627
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
3628
3629
3630
3631
                    dim=split_dim,
                )
            )
        return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)
3632
3633

    @staticmethod
3634
    def backward(ctx, *grad_outputs):
3635
3636
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

cyanguwa's avatar
cyanguwa committed
3637
3638
        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
3639
3640
3641
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
cyanguwa's avatar
cyanguwa committed
3642
3643
3644
3645
3646
        if isinstance(ctx.split_size_or_sections, int):
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + dims) % dims

3647
3648
3649
3650
3651
3652
3653
3654
        if isinstance(grad_outputs[0], Float8Tensor):
            noop_ok = True
            strides = grad_outputs[0].stride()
            data_ptr = grad_outputs[0]._data.untyped_storage().data_ptr()
            shape = list(grad_outputs[0].shape)
            for i, tensor in enumerate(grad_outputs):
                shape_i = shape
                shape_i[split_dim] = split_sizes[i]
3655
3656
3657
3658
3659
3660
3661
                offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
                if (
                    tensor.stride() != strides
                    or list(tensor.shape) != shape_i
                    or tensor._data.untyped_storage().data_ptr() != data_ptr
                    or tensor.storage_offset() != offset_size
                ):
3662
3663
3664
                    noop_ok = False
                    break
            if noop_ok:
3665
3666
3667
                ret = torch.Tensor().to(
                    device=grad_outputs[0].device, dtype=grad_outputs[0]._data.dtype
                )
3668
3669
                new_shape = list(shape)
                new_shape[split_dim] = sum(split_sizes)
3670
3671
3672
3673
3674
                ret.set_(
                    grad_outputs[0]._data.untyped_storage(),
                    grad_outputs[0]._data.storage_offset(),
                    new_shape,
                    strides,
3675
3676
3677
3678
                )
                return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None

            grad_outputs_data = [x._data for x in grad_outputs]
3679
3680
3681
3682
3683
3684
3685
            return (
                Float8Tensor.make_like(
                    grad_outputs[0], data=torch.cat(grad_outputs_data, dim=split_dim)
                ),
                None,
                None,
            )
3686
3687
        noop_ok = True
        strides = grad_outputs[0].stride()
3688
        data_ptr = grad_outputs[0].untyped_storage().data_ptr()
cyanguwa's avatar
cyanguwa committed
3689
        shape = list(grad_outputs[0].shape)
3690
        for i, tensor in enumerate(grad_outputs):
cyanguwa's avatar
cyanguwa committed
3691
3692
            shape_i = shape
            shape_i[split_dim] = split_sizes[i]
3693
3694
3695
3696
3697
3698
3699
            offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
            if (
                tensor.stride() != strides
                or list(tensor.shape) != shape_i
                or tensor.untyped_storage().data_ptr() != data_ptr
                or tensor.storage_offset() != offset_size
            ):
3700
3701
3702
                noop_ok = False
                break
        if noop_ok:
3703
            ret = torch.Tensor().to(device=grad_outputs[0].device, dtype=grad_outputs[0].dtype)
3704
            new_shape = list(shape)
cyanguwa's avatar
cyanguwa committed
3705
            new_shape[split_dim] = sum(split_sizes)
3706
3707
3708
3709
3710
            ret.set_(
                grad_outputs[0].untyped_storage(),
                grad_outputs[0].storage_offset(),
                new_shape,
                strides,
3711
            )
cyanguwa's avatar
cyanguwa committed
3712
            return ret, None, None
3713

3714
        return torch.cat(grad_outputs, dim=split_dim), None, None
3715
3716
3717
3718
3719
3720
3721
3722
3723


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2

    Reference (non-fused) scaled dot-product attention built from
    `torch.baddbmm`/`torch.bmm`, `FusedScaleMaskSoftmax` and dropout.

    Parameters
    ----------
    softmax_scale: float
        Scale applied to the raw Q K^T scores.
    attention_type: str, default = `self`
        `self` expects a single padding mask; any other value expects a
        (query mask, key/value mask) tuple when a padding mask type is used.
    attention_dropout: float, default = 0.0
        Dropout probability applied to the attention probabilities.
    attention_dropout_ctx: Callable, default = `nullcontext`
        Factory for the context manager entered around the dropout call.
    layer_number: int, optional
        When `NVTE_APPLY_QK_LAYER_SCALING=1` and the inputs are FP16, the
        scores are divided by this number before the softmax and the softmax
        re-applies it as its scale.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_type: str = "self",
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_type = attention_type
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            Inputs in the `sbhd` or `bshd` format implied by `qkv_layout`.
            Key/value may have fewer heads than query (GQA); they are repeated
            to match the number of query heads.
        qkv_layout: str, default = `sbh3d`
            Must be in `QKVLayouts` and map to the `sbhd` or `bshd` format.
        cu_seqlens_q, cu_seqlens_kv: torch.Tensor, optional
            Unused; accepted for interface compatibility with other backends.
        attn_mask_type: str, default = `causal`
            Mask type. For types containing `padding`, `attention_mask` is
            required and extra causal masking is built for `padding_causal`
            and `padding_causal_bottom_right`.
        attention_mask: torch.Tensor or tuple of torch.Tensor, optional
            Boolean mask where `True` marks positions to exclude. For padding
            masks: `[b, 1, 1, sq]` for self-attention, otherwise a pair of
            `[b, 1, 1, sq]` and `[b, 1, 1, skv]` tensors.
        core_attention_bias_type: str, default = `no_bias`
            One of `no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`.
        core_attention_bias: torch.Tensor, optional
            Bias added to the scores; required for `pre_scale_bias` and
            `post_scale_bias`.
        alibi_slopes: torch.Tensor, optional
            Forwarded to `get_alibi` when `core_attention_bias_type = alibi`.

        Returns
        -------
        torch.Tensor
            `[sq, b, h * d]` for `sbhd` inputs or `[b, sq, h * d]` for `bshd`
            inputs, where `h` is the number of query heads.
        """

        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
        # Only the padded formats are implemented below; `thd` would otherwise be
        # silently mis-interpreted as [s, b, h, d].
        assert qkv_format in (
            "sbhd",
            "bshd",
        ), f"UnfusedDotProductAttention does not support qkv_format = {qkv_format}!"
        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]
        batch_size, max_seqlen_q, max_seqlen_kv = (
            query_layer.shape[1],
            query_layer.shape[0],
            key_layer.shape[0],
        )
        if "padding" in attn_mask_type:
            assert attention_mask is not None, "Please provide attention_mask for padding!"
            if self.attention_type == "self":
                assert attention_mask.shape == (
                    batch_size,
                    1,
                    1,
                    max_seqlen_q,
                ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!"
                # [b, 1, sq, 1] | [b, 1, 1, sq] -> [b, 1, sq, sq]
                attention_mask = torch.logical_or(
                    attention_mask.squeeze(1).unsqueeze(3), attention_mask
                )
            else:
                assert (
                    len(attention_mask) == 2
                    and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q)
                    and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv)
                ), (
                    "attention_mask should be a tuple of two tensors with shapes "
                    "[b, 1, 1, sq] and [b, 1, 1, skv]!"
                )
                # [b, 1, sq, 1] | [b, 1, 1, skv] -> [b, 1, sq, skv]
                attention_mask = torch.logical_or(
                    attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
                )
            # Count unmasked rows/columns to recover the per-sequence lengths.
            mask = attention_mask.squeeze(1).logical_not()
            actual_seqlens_q = mask[:, :, 0].sum(dim=1)
            actual_seqlens_kv = mask[:, 0, :].sum(dim=1)
            # mask[..., i, j] = i - j; build it on the mask's own device rather
            # than on the implicit current CUDA device.
            mask_device = attention_mask.device
            mask = torch.arange(max_seqlen_q, dtype=torch.int32, device=mask_device).view(
                1, 1, max_seqlen_q, 1
            ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device=mask_device).view(
                1, 1, 1, max_seqlen_kv
            )
            if attn_mask_type == "padding_causal":
                # top-left aligned causal mask: exclude j > i
                attention_mask = torch.logical_or(
                    torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0),
                    attention_mask,
                )
            if attn_mask_type == "padding_causal_bottom_right":
                # bottom-right aligned causal mask, shifted per sequence by skv - sq
                attention_mask = torch.logical_or(
                    torch.where(
                        mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
                        + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
                        < 0,
                        1,
                        0,
                    ),
                    attention_mask,
                )

        seqlen = max_seqlen_q
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                query_layer.shape[2] // key_layer.shape[2], dim=2
            )
            value_layer = value_layer.repeat_interleave(
                query_layer.shape[2] // value_layer.shape[2], dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            device=query_layer.device,
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            ).view(*output_size)

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            # bias is added before the scale is applied
            matmul_result = matmul_result.view(*output_size) + core_attention_bias
            matmul_result *= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = get_alibi(
                    output_size[1],
                    output_size[2],
                    output_size[3],
                    actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
                    actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
                    alibi_slopes=alibi_slopes,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            # bias is added after the scale is applied
            matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
                dtype=query_layer.dtype
            )

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            matmul_result, attention_mask, attn_mask_type, softmax_scale
        )

        # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
        # the columns (pad tokens from k) are already zeroed out during softmax
        if "padding" in attn_mask_type:
            attention_probs = attention_probs.masked_fill(attention_mask, 0)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)

        if qkv_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """This class converts QKV from interleaved (s, b, ...) layout
    to separate contiguous q, k, v tensors in (b, s, ...) layout.

    Forward passes `query_layer` to `tex.fa_prepare_fwd` and splits the result
    into three along dim 0. Backward packs the three gradients with
    `tex.fa_prepare_bwd` and splits them along the last dim.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Repack q, k, v.

        Only `query_layer` is read: it is assumed to be a view into the single
        packed QKV buffer that `key_layer` and `value_layer` also alias (NOTE(review):
        the kernel's exact layout requirements live in `tex` — confirm there).
        """
        # All inputs received are non-contiguous tensors.
        # The `query_layer` tensor is used to access the
        # full memory region of the QKV tensor.
        qkv = tex.fa_prepare_fwd(query_layer)
        q, k, v = split_tensor_along_dim(qkv, 0, 3)
        query_layer = torch.squeeze(q, 0)
        key_layer = torch.squeeze(k, 0)
        value_layer = torch.squeeze(v, 0)
        return query_layer, key_layer, value_layer

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Pack the three gradients with `tex.fa_prepare_bwd` and split them back."""
        dqkv = tex.fa_prepare_bwd(dq, dk, dv)
        dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
        return dq, dk, dv


def get_qkv_layout(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qkv_format: str = "sbhd",
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of tokens in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    ----------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
        The input query tensor, or a contiguous copy of it if the original memory
        layout was not recognized.
    k: torch.Tensor
        The input key tensor, or a contiguous copy (same rule as `q`).
    v: torch.Tensor
        The input value tensor, or a contiguous copy (same rule as `q`).

    Raises
    ----------
    AssertionError
        If any of `q`, `k`, `v` does not have stride 1 in its last dimension.
    ValueError
        If no supported layout is found even after making the tensors contiguous.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def run_iteratively(q, k, v):
        # Do q, k, v (resp. k, v) all live in one storage?
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        # k and v may differ in head size, so their strides (all but the last) must
        # agree once normalized by the head size: k_stride / d_k == v_stride / d_v.
        # Compare by exact integer cross-multiplication instead of float division.
        k_stride, v_stride = k.stride(), v.stride()
        k_head_dim, v_head_dim = k.shape[-1], v.shape[-1]
        check_strides_kv = len(k_stride) == len(v_stride) and all(
            ks * v_head_dim == vs * k_head_dim for ks, vs in zip(k_stride[:-1], v_stride[:-1])
        )

        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        check_shapes_kv = shape[:-1] == v.shape[:-1]

        # Tensors interleaved in the head-dim (`...h3d`): offsets are i * d.
        last_dim_size = q.shape[-1]
        check_last_dim_offsets_qkv = all(
            i * last_dim_size == x.storage_offset() for i, x in enumerate([q, k, v])
        )
        last_dim_size = k.shape[-1]
        check_last_dim_offsets_kv = all(
            i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v])
        )

        # Tensors interleaved before the head dim (`...3hd`): offsets are i * h * d.
        last_two_dims_size = q.shape[-1] * q.shape[-2]
        check_last_two_dims_offsets_qkv = all(
            i * last_two_dims_size == x.storage_offset() for i, x in enumerate([q, k, v])
        )
        last_two_dims_size = k.shape[-1] * k.shape[-2]
        check_last_two_dims_offsets_kv = all(
            i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v])
        )

        if (
            check_ptrs_qkv
            and check_strides_qkv
            and check_shapes_qkv
            and check_last_two_dims_offsets_qkv
            and not check_last_dim_offsets_qkv
        ):
            # sb3hd, bs3hd, t3hd
            qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
        elif (
            check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv
        ):
            # sbh3d, bsh3d, th3d
            qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
        elif (
            check_ptrs_kv
            and check_strides_kv
            and check_shapes_kv
            and check_last_two_dims_offsets_kv
            and not check_last_dim_offsets_kv
        ):
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_last_dim_offsets_kv:
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
        elif check_strides_kv and check_shapes_kv:
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            qkv_layout = "_".join([qkv_format] * 3)
        else:
            qkv_layout = "not_supported"

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        raise ValueError("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v


def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Check if sliding window size is compliant with attention mask type.
    If not, set it to the appropriate size.

         attn_mask_type                              |   window_size
    -------------------------------------------------------------------------
    no_mask, padding, arbitrary                      | (-1, -1) or (>=0, >=0)
    causal, padding_causal                           | (-1,  0) or (>=0, 0)
    causal_bottom_right, padding_causal_bottom_right | (-1,  0) or (>=0, 0)

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type; any type containing `causal`, or one of
        `no_mask`, `padding`, `arbitrary`.
    window_size: Tuple[int, int], default = `None`
        Requested (left, right) window. `None` selects the default for the mask
        type. Any 2-element sequence (e.g. a list) is accepted and normalized
        to a tuple.

    Returns
    ----------
    window_size: Tuple[int, int]
        The (possibly corrected) window size. Correctable mismatches emit a
        warning.

    Raises
    ----------
    AssertionError
        If `window_size` is invalid for the mask type, or `attn_mask_type` is
        unknown. Raised explicitly so validation also runs under `python -O`.
    """
    # Normalize so that lists compare equal to the tuple literals below.
    orig_window_size = None if window_size is None else tuple(window_size)
    window_size = orig_window_size
    if "causal" in attn_mask_type:
        if orig_window_size is None:
            window_size = (-1, 0)
        elif orig_window_size == (-1, -1) or (
            orig_window_size[0] >= 0 and orig_window_size[1] != 0
        ):
            # causal masks can never look right: clamp the right window to 0
            window_size = (orig_window_size[0], 0)
            warnings.warn(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
        elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
            raise AssertionError(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
    elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
        if orig_window_size is None:
            window_size = (-1, -1)
        elif orig_window_size == (-1, 0):
            # a causal-style window on a non-causal mask: widen to unlimited
            window_size = (-1, -1)
            warnings.warn(
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
        elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
            raise AssertionError(
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
    else:
        raise AssertionError("Invalid attn_mask_type: " + attn_mask_type)
    return window_size


class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention

    Parameters
    ----------
    softmax_scale: float
        Forwarded as `softmax_scale` to the flash-attn (or context-parallel) call.
    attention_dropout: float, default = 0.0
        Dropout probability; only applied while `self.training` is True.
    attention_dropout_ctx: Callable, default = `nullcontext`
        Factory for the context manager entered around the attention kernel.
    attention_type: str, default = `self`
        With padding masks, `self` shares one mask / `cu_seqlens` between q and
        kv; any other value expects separate q and kv masks / `cu_seqlens`.
    layer_number: int, optional
        Stored (defaults to 1); not used by `forward`.
    deterministic: bool, default = False
        Forwarded to flash-attn >= 2.4.1 and to the context-parallel path.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
        assert (
            _flash_attn_version <= _flash_attn_max_version
        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
    ) -> torch.Tensor:
        """flash-attn fprop

        Inputs must be FP16/BF16 CUDA tensors in a layout from `QKVLayouts`.
        `sbhd` and `bshd` inputs are flattened to `[b * s, h, d]`, or packed to
        remove padding when `attn_mask_type` contains `padding`, before calling
        flash-attn. `thd` inputs require `cu_seqlens_q` and `cu_seqlens_kv`.
        With a context-parallel group of size > 1 the computation is delegated
        to `attn_forward_func_with_cp`.

        Returns
        -------
        torch.Tensor
            `[s, b, h * d]` for `sbhd`, `[b, s, h * d]` for `bshd` and
            `[t, h * d]` for `thd`, where `s` is the local (per-rank)
            query sequence length.
        """

        assert (
            query_layer.dtype in [torch.float16, torch.bfloat16]
            and key_layer.dtype in [torch.float16, torch.bfloat16]
            and value_layer.dtype in [torch.float16, torch.bfloat16]
        ), "FlashAttention currently only supports FP16 and BF16."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group)
        context_parallel = cp_size > 1

        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        # Bring q, k, v into contiguous batch-first (or token-first) layout.
        if qkv_format == "sbhd":
            # For now just 128, will make it more general in the future
            if (
                query_layer.shape[-1] == 128
                and query_layer.shape[0] * query_layer.shape[1] >= 512
                and qkv_layout == "sbh3d"
            ):
                query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                    query_layer, key_layer, value_layer
                )
            else:
                query_layer, key_layer, value_layer = [
                    x.transpose(0, 1).contiguous() for x in (query_layer, key_layer, value_layer)
                ]
        elif qkv_format in ["bshd", "thd"]:
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]

        batch_size = query_layer.shape[0]

        if qkv_format in ["sbhd", "bshd"]:
            # local sequence lengths scaled up to the global (all-rank) lengths
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size

            if not context_parallel:
                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.view(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer)
            else:
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel:
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    # the same cu_seqlens are passed again for the padded-offset arguments
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic,
                    window_size=window_size,
                )
        else:
            # Imported at call time so the current value of the module-level flag is read.
            from .cpu_offload import CPUOffloadEnabled

            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            with self.attention_dropout_ctx():
                # Only pass keyword arguments the installed flash-attn version accepts.
                fa_optional_forward_kwargs = {}
                if _flash_attn_2_3_plus and window_size is not None:
                    # flash-attn indexes window_size directly; omitting it when None
                    # lets flash-attn fall back to its own default window.
                    fa_optional_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                if _flash_attn_2_5_7_plus:
                    fa_optional_forward_kwargs["block_table"] = None
                output = flash_attn_forward_func(
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    softmax_scale=self.softmax_scale,
                    causal="causal" in attn_mask_type,
                    **fa_optional_forward_kwargs,
                )

        if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
            # scatter packed tokens back into the padded [b * s, ...] buffer
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            output = (
                output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1).contiguous()
            )
        elif qkv_format == "bshd":
            # (bs)hd -> bs(hd)
            output = output.view(batch_size, max_seqlen_q // cp_size, -1).contiguous()
        elif qkv_format == "thd":
            # thd -> t(hd)
            output = output.view(output.shape[0], -1).contiguous()

        return output


def _combine_tensors(
4402
4403
4404
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
4405
4406
4407
4408
4409
4410
    """Combine tensors along a particular dimension"""

    num_tensors = len(tensors)
    new_shape = list(tensors[0].shape)
    new_shape.insert(dim, num_tensors)
    new_stride = list(tensors[0].stride())
4411
    new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
4412
    if isinstance(tensors[0], Float8Tensor):
4413
        combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype)
4414
4415
4416
        combined_tensor.set_(
            tensors[0]._data.untyped_storage(),
            tensors[0]._data.storage_offset(),
4417
4418
4419
4420
            new_shape,
            new_stride,
        )
        combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor)
4421
    else:
4422
        combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype)
4423
        combined_tensor.set_(
4424
4425
            tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride
        )
4426
4427

    return combined_tensor
4428

4429

4430
4431
4432
4433
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input.

    ``forward`` calls ``fused_attn_fwd_qkvpacked`` on a single packed QKV
    tensor, optionally in FP8. ``backward`` calls either ``flash_attn_cuda_bwd``
    (when ``use_FAv2_bwd`` is set) or ``fused_attn_bwd_qkvpacked``.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen,
        cu_seqlens,
        cu_seqlens_padded,
        qkv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run fused attention forward on a packed QKV tensor.

        When ``fp8`` is set, the backend is forced to ``FusedAttnBackend["FP8"]``.
        With ``fp8_meta["recipe"].fp8_mha``, ``qkv`` must already be a
        ``Float8Tensor`` and the output is returned as a ``Float8Tensor``.
        Otherwise ``qkv`` is cast to FP8 here and the output is cast back to
        ``qkv_dtype``.

        Backward runs in FP8 only if ``fp8`` is set and the environment variable
        ``NVTE_FP8_DPA_BWD`` (default ``"1"``) is non-zero. In that case only the
        FP8 tensors are saved for backward; otherwise the high-precision
        ``qkv``/output are saved.
        """
        if fp8:
            if fp8_meta["recipe"].fp8_mha:
                assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
                # Use the input tensor's own scale_inv for dequantizing QKV.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.split("_"))
            assert qkv_group == 1, (
                "qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found"
                f" {qkv_layout}."
            )
            if fp8_meta["recipe"].fp8_mha:
                qkv_fp8 = qkv._data
            else:
                # Flatten the trailing three dims to 2D for the cast, then restore.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv_fp8 = cast_to_fp8(
                    qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(qkv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=qkv.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # Backward will run in high precision: save dequantized copies.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv = cast_from_fp8(
                    qkv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV,
                    fp8_dtype_forward,
                    TE_DType[qkv.dtype],
                ).view(qkv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            # Snapshot scales now; the live fp8_meta buffers may be updated
            # before backward runs.
            fp8_tensors = (
                qkv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            fp8_tensors = (None, None, None, None)
            out_save = out_ret

        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # In FP8 backward only the FP8 copies (inside fp8_tensors) are needed.
        qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
        ctx.save_for_backward(
            *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        # A non-FP8 backward always uses the F16 arbitrary-seqlen backend,
        # even if forward ran in FP8.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for ``qkv`` (and ``attn_bias`` when a bias is used).

        Returns one entry per ``forward`` input (``ctx`` excluded): ``dqkv`` at
        position 4 (``qkv``), the bias gradient at position 6 (``attn_bias``),
        and ``None`` everywhere else.
        """
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            qkv,
            out,
            cu_seqlens,
            cu_seqlens_padded,
            qkv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dqkv = torch.empty_like(qkv)

            def maybe_contiguous(x):
                # Copy only when the last dimension is not unit-stride.
                return x.contiguous() if x.stride(-1) != 1 else x

            # NOTE(review): q/k/v are taken from qkv[:, i], i.e. the packed axis is
            # assumed to be dim 1 (e.g. a t3hd layout) — confirm with callers.
            d_out, q, k, v, out = [
                maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out)
            ]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dqkv[:, 0],
                dqkv[:, 1],
                dqkv[:, 2],
                cu_seqlens,
                cu_seqlens,
                ctx.max_seqlen,
                ctx.max_seqlen,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Keep only the first d_out.shape[-1] entries of the last dim.
            dqkv = dqkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dqkv = Float8Tensor(
                            data=dqkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        dqkv_c_fp8 = dqkv_fp8.view(
                            -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                        )
                        dqkv = cast_from_fp8(
                            dqkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dqkv_fp8.shape)
                else:
                    # FP8 MHA forward with NVTE_FP8_DPA_BWD=0: dequantize d_out.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(qkv.dtype)
                    dqkv, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # No bias gradient for no_bias/alibi; otherwise it is rest[0]
        # (rest is only defined on the fused-attention path).
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        # Exactly one gradient per forward() input (20 inputs, ctx excluded):
        # index 4 -> qkv, index 6 -> attn_bias.
        return (None, None, None, None, dqkv, None, dbias) + (None,) * 13


class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input.

    ``forward`` calls ``fused_attn_fwd_kvpacked`` with a separate ``q`` and a
    packed ``kv`` tensor, optionally in FP8. ``backward`` calls either
    ``flash_attn_cuda_bwd`` (when ``use_FAv2_bwd`` is set) or
    ``fused_attn_bwd_kvpacked``.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        q,
        kv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run fused attention forward with separate Q and packed KV tensors.

        When ``fp8`` is set, the backend is forced to ``FusedAttnBackend["FP8"]``.
        With ``fp8_meta["recipe"].fp8_mha``, ``q`` and ``kv`` must already be
        ``Float8Tensor`` (sharing ``q``'s scale_inv) and the output is returned
        as a ``Float8Tensor``. Otherwise the inputs are cast to FP8 here and the
        output is cast back to ``qkv_dtype``.

        Backward runs in FP8 only if ``fp8`` is set and the environment variable
        ``NVTE_FP8_DPA_BWD`` (default ``"1"``) is non-zero. In that case only the
        FP8 tensors are saved for backward.
        """
        if fp8:
            if fp8_meta["recipe"].fp8_mha:
                assert isinstance(q, Float8Tensor) and isinstance(
                    kv, Float8Tensor
                ), "q/kv must be Float8Tensors for FP8 MHA."
                # q's scale_inv is used for dequantizing both q and kv.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
                q_fp8, kv_fp8 = q._data, kv._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                assert qkv_group == 2, (
                    "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
                    f"but found {qkv_layout}."
                )
                q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view(
                    q.shape
                )
                # Flatten kv's trailing three dims to 2D for the cast, then restore.
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv_fp8 = cast_to_fp8(
                    kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(kv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                kv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # Backward will run in high precision: save dequantized copies.
                q = cast_from_fp8(
                    q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype]
                ).view(q.shape)
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv = cast_from_fp8(
                    kv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV,
                    fp8_dtype_forward,
                    TE_DType[kv.dtype],
                ).view(kv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            # Snapshot scales now; the live fp8_meta buffers may be updated
            # before backward runs.
            fp8_tensors = (
                q_fp8,
                kv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                kv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None)

        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # In FP8 backward only the FP8 copies (inside fp8_tensors) are needed.
        qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        # A non-FP8 backward always uses the F16 arbitrary-seqlen backend,
        # even if forward ran in FP8.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for ``q``/``kv`` (and ``attn_bias`` when a bias is used).

        Returns one entry per ``forward`` input (``ctx`` excluded): ``dq`` at
        position 7 (``q``), ``dkv`` at position 8 (``kv``), the bias gradient at
        position 10 (``attn_bias``), and ``None`` everywhere else.
        """
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            q,
            kv,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            kv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)

            def maybe_contiguous(x):
                # Copy only when the last dimension is not unit-stride.
                return x.contiguous() if x.stride(-1) != 1 else x

            # NOTE(review): k/v are taken from kv[:, i], i.e. the packed axis is
            # assumed to be dim 1 — confirm with callers.
            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dkv[:, 0],
                dkv[:, 1],
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Keep only the first d_out.shape[-1] entries of the last dim.
            dq = dq[..., : d_out.shape[-1]]
            dkv = dkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        kv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                        dkv = Float8Tensor(
                            data=dkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        dq = cast_from_fp8(
                            dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dq_fp8.shape)
                        dkv_c_fp8 = dkv_fp8.view(
                            -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                        )
                        dkv = cast_from_fp8(
                            dkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dkv_fp8.shape)
                else:
                    # FP8 MHA forward with NVTE_FP8_DPA_BWD=0: dequantize d_out.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dkv, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        kv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # No bias gradient for no_bias/alibi; otherwise it is rest[0]
        # (rest is only defined on the fused-attention path).
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        # Exactly one gradient per forward() input (24 inputs, ctx excluded):
        # index 7 -> q, index 8 -> kv, index 10 -> attn_bias.
        return (None,) * 7 + (dq, dkv, None, dbias) + (None,) * 13


class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        q,
        k,
        v,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run the fused attention forward pass on separate q, k, v tensors.

        FP8 path (``fp8`` truthy): the backend is forced to ``FusedAttnBackend["FP8"]``.
        With ``fp8_meta["recipe"].fp8_mha``, q/k/v must already be ``Float8Tensor``s and
        their raw data is used directly. Otherwise q/k/v are cast to FP8 here, grouped
        according to ``qkv_layout``. The output is returned as a ``Float8Tensor``
        (fp8_mha) or cast back to ``qkv_dtype``.

        Non-FP8 path: ``fused_attn_fwd`` is called with ``None`` for all FP8
        scaling arguments.

        The environment variable ``NVTE_FP8_DPA_BWD`` (default "1") controls whether
        the backward pass also runs in FP8. When it is "0", high-precision q/k/v/out
        are saved and backward uses ``FusedAttnBackend["F16_arbitrary_seqlen"]``.

        Returns:
            The attention output (``out_ret``).
        """
        if fp8:
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
                assert (
                    isinstance(q, Float8Tensor)
                    and isinstance(k, Float8Tensor)
                    and isinstance(v, Float8Tensor)
                ), "q/k/v must be Float8Tensors for FP8 MHA."
                # Inputs are already FP8: reuse q's scale_inv for the QKV slot.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                if qkv_group == 1:
                    # Cast the packed qkv buffer once, then split it back into q/k/v views.
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv_fp8 = cast_to_fp8(
                        qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(qkv.shape)
                    q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1])
                    q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
                if qkv_group == 2:
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    # Cast the packed kv buffer once, then split it back into k/v views.
                    dim = qkv_layout.split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_fp8 = cast_to_fp8(
                        kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(kv.shape)
                    k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1])
                    k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
                if qkv_group == 3:
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    k_fp8 = cast_to_fp8(
                        k, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(k.shape)
                    v_fp8 = cast_to_fp8(
                        v, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(v.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if fp8_meta["recipe"].fp8_mha:
                # Keep the output in FP8, wrapped so downstream FP8 modules can consume it.
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret

            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # Backward will run in high precision: de-quantize the Float8Tensor
                # inputs and the output so they can be saved for the F16 backward.
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                if qkv_group == 1:
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv_no_fp8 = cast_from_fp8(
                        qkv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[qkv.dtype],
                    ).view(qkv.shape)
                    q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1])
                    q, k, v = [x.squeeze(dim) for x in [q, k, v]]
                if qkv_group == 2:
                    q = cast_from_fp8(
                        q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[q.dtype],
                    ).view(q.shape)
                    dim = qkv_layout.split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_no_fp8 = cast_from_fp8(
                        kv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[kv.dtype],
                    ).view(kv.shape)
                    k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1])
                    k, v = [x.squeeze(dim) for x in [k, v]]
                if qkv_group == 3:
                    q = cast_from_fp8(
                        q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[q.dtype],
                    ).view(q.shape)
                    k = cast_from_fp8(
                        k._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[k.dtype],
                    ).view(k.shape)
                    v = cast_from_fp8(
                        v._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[v.dtype],
                    ).view(v.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)

            # Snapshot the forward scales: fp8_meta may be updated before backward runs.
            fp8_tensors = (
                q_fp8,
                k_fp8,
                v_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                # No FP8 scaling factors / amax buffers in the high-precision path.
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None, None)

        from .cpu_offload import CPUOffloadEnabled

        if CPUOffloadEnabled:
            tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
            # NOTE(review): the layout recorded for backward is overridden whenever
            # CPU offload is enabled; presumably because offloaded tensors no longer
            # share packed storage — confirm the sbhd choice for bshd inputs.
            qkv_layout = "sbhd_sbhd_sbhd"
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

        # Backward runs in FP8 only when forward did and NVTE_FP8_DPA_BWD is not "0".
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # FP8 backward only needs the FP8 copies (in fp8_tensors); skip the others.
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
        # The unpacking order in backward() must match this exactly.
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        # A high-precision backward (including FP8 forward with NVTE_FP8_DPA_BWD=0)
        # always uses the arbitrary-seqlen F16 backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for q, k, v (and attn_bias when one is used).

        Uses the flash-attn v2 backward kernel when ``ctx.use_FAv2_bwd`` is set.
        Otherwise it calls ``fused_attn_bwd``, in FP8 when ``ctx.fp8`` is truthy and
        in high precision otherwise.

        Returns:
            A tuple with exactly one entry per ``forward`` input (25 entries). The
            entries are ``dq``/``dk``/``dv`` in the q/k/v slots, the bias gradient
            in the ``attn_bias`` slot (``None`` for no_bias/alibi), and ``None``
            everywhere else.
        """
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)

            def maybe_contiguous(x):
                # The flash-attn kernel requires a unit stride on the last dimension.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Trim the last dimension of each gradient to d_out's head size.
            dq = dq[..., : d_out.shape[-1]]
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        # d_out already holds raw FP8 data; reuse its scale_inv.
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

                    if ctx.fp8_meta["recipe"].fp8_mha:
                        # Hand FP8 gradients back as Float8Tensors.
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                        dk = Float8Tensor(
                            data=dk_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                        dv = Float8Tensor(
                            data=dv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        # De-quantize the gradients to ctx.qkv_dtype, grouped like the
                        # forward layout (1: qkv packed, 2: kv packed, 3: qkv separate).
                        qkv_group = len(ctx.qkv_layout.split("_"))
                        if qkv_group == 1:
                            dim = ctx.qkv_layout.find("3")
                            dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim)
                            dqkv_c_fp8 = dqkv_fp8.view(
                                -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                            )
                            dqkv = cast_from_fp8(
                                dqkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dqkv_fp8.shape)
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1])
                            dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
                        if qkv_group == 2:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dq_fp8.shape)
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(
                                -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                            )
                            dkv = cast_from_fp8(
                                dkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dkv_fp8.shape)
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1])
                            dk, dv = [x.squeeze(dim) for x in [dk, dv]]
                        if qkv_group == 3:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dq_fp8.shape)
                            dk = cast_from_fp8(
                                dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dk_fp8.shape)
                            dv = cast_from_fp8(
                                dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dv_fp8.shape)
                else:
                    # FP8 MHA forward with a high-precision backward: d_out is raw
                    # uint8 data, so de-quantize it through the original Float8Tensor.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        # No FP8 scaling factors / amax buffers in the high-precision path.
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # if no_bias or alibi, there is no bias gradient; else rest[0] is dbias
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            dbias = None
        else:
            dbias = rest[0]
        # Exactly one gradient slot per forward() input, in forward's argument order.
        return (
            None,  # is_training
            None,  # max_seqlen_q
            None,  # max_seqlen_kv
            None,  # cu_seqlens_q
            None,  # cu_seqlens_kv
            None,  # cu_seqlens_q_padded
            None,  # cu_seqlens_kv_padded
            dq,  # q
            dk,  # k
            dv,  # v
            None,  # qkv_dtype
            dbias,  # attn_bias
            None,  # attn_scale
            None,  # dropout_p
            None,  # fast_zero_fill
            None,  # qkv_layout
            None,  # attn_bias_type
            None,  # attn_mask_type
            None,  # window_size
            None,  # rng_gen
            None,  # fused_attention_backend
            None,  # use_FAv2_bwd
            None,  # fp8
            None,  # fp8_meta
            None,  # deterministic
        )
5797

5798

5799
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |

    When ``fp8=True`` is passed to :meth:`forward`, the ``NVTE_FP8`` sub-backend
    and ``fp8_meta`` are required.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Store the attention configuration.

        Parameters
        ----------
        softmax_scale : float
            Scale applied to the attention scores, forwarded to the kernels.
        attention_dropout : float, default = 0.0
            Dropout probability; only used while ``self.training`` is True.
        attention_dropout_ctx : Callable, default = nullcontext
            Context-manager factory that wraps every kernel call (e.g. an RNG
            tracker's ``fork``).
        attention_type : str, default = "self"
            "self" derives q and kv sequence lengths from a single attention mask;
            any other value expects a (q_mask, kv_mask) tuple.
        layer_number : Optional[int], default = None
            Stored as ``self.layer_number``; ``None`` becomes 1.
        deterministic : bool, default = False
            Forwarded to the attention implementations.
        """
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # flash-attn's backward may replace the cuDNN backward, but only when
        # explicitly requested via the env var AND on compute capability 9.0.
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading TransformerEngine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in TransformerEngine 2.0.
            """
            # Iterate over snapshots: removing from a list while iterating it
            # skips the element that follows each removed one, which would leave
            # consecutive matching keys behind. The lists themselves must still
            # be mutated in place so the caller sees the change.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """fused attention fprop

        Validates the inputs, derives ``qkv_format`` from ``qkv_layout``, fills in
        missing sequence-length metadata, then dispatches to
        ``attn_forward_func_with_cp`` when ``cp_group`` spans more than one rank,
        or to ``FusedAttnFunc.apply`` otherwise.

        Parameters
        ----------
        query_layer, key_layer, value_layer : torch.Tensor
            CUDA tensors of dtype float16, bfloat16 or uint8 (uint8 presumably
            carries FP8 data -- TODO confirm), laid out per ``qkv_layout``.
        qkv_layout : str, default = "sbh3d"
            Must be in ``QKVLayouts``. The alphabetic characters of its first
            "_"-separated component give the format ("sbhd", "bshd" or "thd").
        cu_seqlens_q, cu_seqlens_kv : Optional[torch.Tensor]
            Cumulative sequence lengths. Required for "thd". For "sbhd"/"bshd"
            they are derived from ``attention_mask`` (padding masks) or built
            with ``_get_full_cu_seqlens`` (non-padding masks) when not given.
        cu_seqlens_q_padded, cu_seqlens_kv_padded : Optional[torch.Tensor]
            If either is None, both are replaced by the unpadded cu_seqlens.
        max_seqlen_q, max_seqlen_kv : Optional[int]
            Required for "thd". For "sbhd"/"bshd" they are overwritten from the
            tensor shapes, multiplied by the context-parallel world size.
        attn_mask_type : str, default = "causal"
            Mask type; any value containing "padding" triggers the padding path.
        attention_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]
            Only used to derive cu_seqlens for padding masks: a single tensor when
            ``attention_type == "self"``, otherwise a (q, kv) tuple.
        window_size : Optional[Tuple[int, int]]
            Sliding-window size, forwarded to the kernels.
        fused_attention_backend : tex.NVTE_Fused_Attn_Backend
            Must not be ``NVTE_No_Backend``; must be ``NVTE_FP8`` when ``fp8``.
        core_attention_bias_type, core_attention_bias
            Bias type and tensor; "alibi" is rejected with context parallelism.
        fast_zero_fill : bool, default = True
            Forwarded to ``FusedAttnFunc``.
        cp_group, cp_global_ranks, cp_stream, cp_comm_type
            Context-parallel configuration, forwarded to ``attn_forward_func_with_cp``.
        fp8 : bool, default = False
            Run FP8 attention; requires ``fp8_meta``.
        fp8_meta : Optional[Dict[str, Any]]
            FP8 metadata; with context parallelism, ``fp8_meta["recipe"].reduce_amax``
            must be set.

        Returns
        -------
        torch.Tensor
            Attention output with its last two dimensions merged (``...hd -> ...(hd)``).

        Raises
        ------
        AssertionError
            For unsupported backend/dtype/device/layout combinations.
        RuntimeError
            If a padding mask is requested without cu_seqlens or attention_mask.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        assert (
            (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
        ), "FusedAttention only supports FP16 and BF16 data types."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group)
        context_parallel = cp_size > 1

        # Keep only the letters of the first layout component,
        # e.g. "sbhd_sb2hd" -> "sbhd", "bs3hd" -> "bshd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format in ["sbhd", "bshd"]:
            # Fixed-length formats: read batch size and sequence lengths from the
            # tensor shapes (any values passed in are overwritten).
            if qkv_format == "sbhd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1],
                    query_layer.shape[0],
                    key_layer.shape[0],
                )
            if qkv_format == "bshd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0],
                    query_layer.shape[1],
                    key_layer.shape[1],
                )
            # Scale the per-rank sequence length by the context-parallel world size.
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size
            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                # No padding: build cu_seqlens from batch size and max length
                # (presumably every sequence is full-length).
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )

        if qkv_format == "thd":
            assert (
                max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
            ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"

        # Use the unpadded offsets unless BOTH padded variants were supplied.
        if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
            cu_seqlens_q_padded = cu_seqlens_q
            cu_seqlens_kv_padded = cu_seqlens_kv

        qkv_dtype = TE_DType[query_layer.dtype]

        # flash-attn backward is only eligible without bias and on the
        # arbitrary-seqlen F16 sub-backend.
        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if fp8:
            assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                " is required for FP8 attention!"
            )
            assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
            assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
                "Amax reduction across TP+CP group is necessary when using context parallelism with"
                " FP8!"
            )

        if context_parallel:
            assert (
                fp8
                or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    deterministic=self.deterministic,
                    use_fused_attention=True,
                    window_size=window_size,
                    fp8=fp8,
                    fp8_meta=fp8_meta,
                )
        else:
            with self.attention_dropout_ctx():
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_dtype,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                    self.deterministic,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
6061
6062


6063
class DotProductAttention(TransformerEngineBaseModule):
6064
6065
6066
6067
6068
6069
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

6070
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
6072
6073
6074

    .. warning::

6075
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
6076
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
6077
6078
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
6079
6080
6081
6082
6083

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
6084
6085
6086
    kv_channels : Union[int, Tuple[int, int]]
                the head size in key and value tensors. If the same, :attr:`kv_channels` can be
                an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
6087
6088
6089
6090
6091
6092
6093
6094
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
6095
6096
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
6097
    attn_mask_type: str, default = `causal`
6098
                   type of attention mask passed into softmax operation, options are "`no_mask`",
6099
6100
6101
6102
6103
6104
6105
6106
6107
6108
6109
6110
6111
6112
6113
6114
6115
6116
6117
6118
6119
6120
6121
6122
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
                   "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
                   "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
                   are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
                   `forward` method. It is useful for cases involving compilation/tracing, e.g.
                   ONNX export, and the forward arg is useful for dynamically changing mask types,
                   e.g. a different mask for training and inference.
                   1. For "`no_mask`", no attention mask is applied.
                   2. For "`causal`", "`causal_bottom_right`", or the causal mask in
                   "`padding_causal`" and "`padding_causal_bottom_right`", TransformerEngine
                   calculates and applies an upper triangular mask to the softmax input.
                   No user input is needed. Causal masks without the "`bottom_right`" appendix align
                   the diagonal line to the top left corner of the softmax matrix. With
                   "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
                   often used in inference/KV caching.
                   3. For "`padding`", or the padding mask in "`padding_causal`" and
                   "`padding_causal_bottom_right`", users need to provide the locations of padded
                   tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
                   [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
                   in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
                   cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
                   [batch_size, 1, 1, max_seqlen_kv]).
                   4. For "`arbitrary`", users need to provide a mask that is broadcastable to
                   the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
6123
6124
6125
6126
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
6127
6128
6129
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
6130
                be overridden by :attr:`window_size` in `forward` as well.
6131
6132
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
6133
6134
6135
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
6136
6137
6138
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
6139
               `h` the number of heads, `d` head size, and `t` the total number of tokens
6140
6141
6142
6143
6144
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
6145
               For that, please use `get_qkv_layout` to gain the layout information.
6146
6147
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
6148
                `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
6149
6150
6151
6152
6153
6154
6155
6156
6157

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
6158
6159
6160
6161
6162
6163
6164
6165
6166
    cp_group : ProcessGroup, default = `None`
              context parallel process group.
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
6167
6168
6169
    cp_comm_type : str
                  inter-gpu communication type for context parallelism.
                  Can be "p2p" or "all_gather".
6170
6171
6172
6173
6174
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: Union[int, Tuple[int, int]],
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
    ) -> None:
        """Configure the module and build the flash, fused and unfused backends.

        Side effects: may set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and
        CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT in ``os.environ`` for
        8.9.5 <= cuDNN < 9.0.0, and registers a load-state-dict post hook.
        """
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.logger.setLevel(_log_level)
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)
        self.qkv_format = qkv_format
        # Normalize spellings: "padding,causal" -> "padding_causal",
        # "causal,padding"/"causal_padding" -> "padding_causal".
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
        self.cp_comm_type = cp_comm_type

        # kv_channels is either one head size shared by K and V, or a (k, v) pair.
        self.hidden_size_per_attention_head_k = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[0]
        )
        self.hidden_size_per_attention_head_v = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[1]
        )

        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"

        # Dropout runs under the tracker's forked RNG state unless sequence
        # parallelism is enabled or no tracker was provided.
        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        if softmax_scale is None:
            # Default 1/sqrt(d) uses the key head size.
            softmax_scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head_k)

        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )
        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = get_cudnn_version()
        if (8, 9, 5) <= cudnn_version < (9, 0, 0):
            if self.deterministic:
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        self.flash_attention = FlashAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.fused_attention = FusedAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale,
            attention_type=attention_type,
            **attn_kwargs,
            layer_number=layer_number,
        )

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
            when loading older TransformerEngine checkpoints. Will phase out
            this hook in TransformerEngine 2.0.
            """
            # Iterate over a snapshot: removing from the list being iterated
            # skips the element after each removal, so consecutive matching keys
            # would survive. The list itself is still mutated in place.
            for key in list(incompatible_keys.missing_keys):
                if "core_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)

        self.register_load_state_dict_post_hook(remove_extra_states_check)

6318
6319
6320
6321
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Forward method with activation checkpointing.

        Wraps ``attention_func`` and hands it to ``checkpoint`` together with the
        given positional and keyword arguments. Saved activations are not
        distributed, and this module's RNG state tracker and TP group are passed
        along.
        """

        def _run_attention(*args, **kwargs):
            return attention_func(*args, **kwargs)

        return checkpoint(
            _run_attention,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

6340
6341
6342
6343
6344
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Configure context parallelism for this module; call before the forward pass.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str
                      inter-gpu communication type for context parallelism.
                      Can be "p2p" or "all_gather".
        """
        # Attributes are assigned left to right, in the same order as before.
        (
            self.cp_group,
            self.cp_global_ranks,
            self.cp_stream,
            self.cp_comm_type,
        ) = (cp_group, cp_global_ranks, cp_stream, cp_comm_type)
6367

6368
    @no_torch_dynamo(recursive=False)
6369
6370
6371
6372
6373
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
6374
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
6375
6376
6377
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
6378
6379
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
6380
6381
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
6382
        attn_mask_type: Optional[str] = None,
6383
        window_size: Optional[Tuple[int, int]] = None,
6384
        checkpoint_core_attention: bool = False,
6385
6386
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
6387
        alibi_slopes: Optional[torch.Tensor] = None,
6388
        fast_zero_fill: bool = True,
6389
        inference_params: Optional[InferenceParams] = None,
6390
        is_first_microbatch: Optional[bool] = None,
6391
6392
6393
6394
6395
6396
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

6397
6398
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
6399

6400
6401
        .. note::

6402
6403
6404
6405
6406
6407
6408
6409
6410
6411
6412
6413
6414
6415
6416
6417
6418
6419
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://github.com/Dao-AILab/flash-attention>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
            and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
6420
6421
6422
6423
6424
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
            optimizations in FusedAttention. When unset, TransformerEngine determines the code path
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
6425

6426
6427
6428
6429
6430
6431
6432
6433
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
6434
6435
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
6436
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
6437
6438
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
6439
6440
6441
6442
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
6443
6444
6445
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
6446
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
6447
6448
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
6449
6450
6451
6452
6453
6454
6455
6456
6457
6458
6459
6460
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for
                   `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_q_padded = cu_seqlens_q`.
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_kv_padded = cu_seqlens_kv`.
6461
6462
6463
6464
6465
6466
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
6467
6468
6469
6470
6471
6472
6473
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
                       'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
                       'arbitrary'}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
                       are equivalent. By default, causal masks are aligned to the top left corner
                       of the softmax matrix. When "`bottom_right`" is specified in the mask type,
                       causal masks are aligned to the bottom right corner.
6474
        window_size: Optional[Tuple[int, int]], default = `None`
6475
                    Sliding window size for local attention.
6476
6477
6478
6479
6480
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
6481
        core_attention_bias_type: str, default = `no_bias`
6482
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
6483
        core_attention_bias: Optional[torch.Tensor], default = `None`
6484
6485
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
6486
6487
6488
6489
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
6490
        fast_zero_fill: bool, default = `True`
6491
                    Whether to use the fast path to set output tensors to 0 or not.
6492
6493
6494
6495
6496
6497
6498
6499
6500
6501
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
6502
6503
6504
6505
6506
6507
6508
6509
6510
6511
6512
6513
6514
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
6515
        """
6516
6517
6518
6519
6520
6521
6522
6523
6524
6525
6526
        with self.prepare_forward(
            query_layer,
            is_first_microbatch,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:

            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
6527
                        self.logger.warning(
6528
6529
6530
                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                            """fp8_meta["recipe"].fp8_mha=True"""
                        )
6531
6532
6533
6534
6535
6536
6537
6538
6539
6540
6541

            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
6542

6543
6544
6545
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
6546
6547
6548
            assert (
                query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
6549
6550
6551
            assert (
                key_layer.shape[:-1] == value_layer.shape[:-1]
            ), "Keys and values must have the same batch size, sequence length and number of heads!"
6552
6553
6554
6555
6556
6557
6558
6559
            assert (
                key_layer.shape[-1] == self.hidden_size_per_attention_head_k
            ), f"Keys have head_dim = {key_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_k}!"
            assert (
                value_layer.shape[-1] == self.hidden_size_per_attention_head_v
            ), f"Values have head_dim = {value_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_v}!"
6560

6561
6562
6563
6564
6565
6566
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
6567
            assert (
6568
6569
6570
6571
6572
6573
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
6574

6575
6576
6577
6578
            if window_size is None:
                window_size = self.window_size
            window_size = check_set_window_size(attn_mask_type, window_size)

6579
6580
6581
6582
6583
6584
6585
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
6586

6587
6588
            if qkv_format is None:
                qkv_format = self.qkv_format
6589

6590
6591
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
6592

6593
6594
6595
6596
6597
                # convert causal to causal_bottom_right in inference when KV-caching is in use
                # so users can run with the same attn_mask_type for training and inference
                if attn_mask_type in ["causal", "padding_causal"]:
                    attn_mask_type = attn_mask_type + "_bottom_right"

6598
6599
6600
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
6601

6602
6603
6604
6605
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]
6606

6607
6608
6609
                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)
6610

6611
6612
6613
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)
6614

6615
6616
6617
6618
6619
6620
6621
6622
6623
                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
6624

6625
6626
6627
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
6628

6629
6630
                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()
6631
6632

            assert (
6633
6634
6635
6636
6637
6638
6639
6640
6641
6642
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
6643
                assert all(
6644
6645
6646
6647
6648
6649
6650
6651
6652
6653
6654
6655
6656
6657
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
                if max_seqlen_q is None:
6658
6659
6660
6661
                    if cu_seqlens_q_padded is not None:
                        seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
                    else:
                        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
6662
                    max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
6663
                if max_seqlen_kv is None:
6664
6665
6666
6667
                    if cu_seqlens_kv_padded is not None:
                        seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
                    else:
                        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
6668
                    max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
6669
                batch_size = len(cu_seqlens_q) - 1
6670

6671
6672
6673
            cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group)
            context_parallel = cp_size > 1

6674
            if qkv_format in ["sbhd", "bshd"]:
6675
                assert all(
6676
6677
6678
6679
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
6680
                    batch_size = query_layer.shape[1]
6681
6682
                if qkv_format == "bshd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
6683
                    batch_size = query_layer.shape[0]
6684
6685
                max_seqlen_q *= cp_size
                max_seqlen_kv *= cp_size
6686
6687
6688
6689
6690
6691
6692
6693
6694
6695
6696
6697
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                        the sequence dimention in 'query_layer'!"""
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                        the sequence dimention in 'key_layer' and 'value_layer'!"""
6698
6699
6700
6701
6702
                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
6703
                        if self.attention_type == "self":
6704
6705
6706
6707
6708
6709
6710
6711
6712
6713
6714
6715
6716
6717
6718
6719
                            cu_seqlens_q = get_cu_seqlens(attention_mask)
                            cu_seqlens_kv = cu_seqlens_q
                        else:
                            cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                            cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
                    else:
                        cu_seqlens_q = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
                        cu_seqlens_kv = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
6720

6721
6722
6723
6724
6725
            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
6726
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
6727
6728
6729
                    query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                )
            else:
6730
                qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
6731
6732
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )
6733

6734
6735
6736
6737
6738
6739
6740
6741
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
6742
            bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
6743
6744
6745
6746
6747
6748
6749
6750
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
6751
                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
6752
6753
6754
6755
6756
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

6757
6758
            core_attention_bias_shape = None
            if core_attention_bias is not None:
6759
                if (
6760
6761
                    core_attention_bias.shape[0] == batch_size
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
6762
                ):
6763
6764
6765
6766
6767
6768
6769
6770
6771
6772
6773
6774
6775
6776
6777
6778
6779
6780
6781
6782
6783
6784
6785
6786
                    core_attention_bias_shape = "bhss"
                elif (
                    core_attention_bias.shape[0] == 1
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
                ):
                    core_attention_bias_shape = "1hss"
                elif (
                    core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
                ):
                    core_attention_bias_shape = "b1ss"
                elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
                    core_attention_bias_shape = "11ss"
                else:
                    assert (
                        False
                    ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

            pad_between_seqs = (
                cu_seqlens_q_padded is not None
                and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
            ) or (
                cu_seqlens_kv_padded is not None
                and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
            )
6787

6788
            attention_params = AttentionParams(
6789
6790
6791
6792
6793
6794
6795
6796
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
                num_heads=query_layer.shape[-2],
                num_gqa_groups=key_layer.shape[-2],
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
6797
6798
                head_dim_qk=query_layer.shape[-1],
                head_dim_v=value_layer.shape[-1],
6799
6800
6801
6802
6803
6804
6805
6806
6807
6808
6809
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias_shape=core_attention_bias_shape,
                core_attention_bias_requires_grad=(
                    core_attention_bias.requires_grad if core_attention_bias is not None else False
                ),
                pad_between_seqs=pad_between_seqs,
                attention_dropout=self.attention_dropout,
                context_parallel=context_parallel,
6810
6811
                deterministic=self.deterministic,
                is_training=self.training,
6812
6813
6814
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
            )
6815
6816
6817
6818
6819
6820
6821
6822
6823
6824
6825
6826
6827
6828
6829
6830
6831
6832
6833
6834
6835
            global _attention_backends
            if (
                _attention_backends["attention_params"] is None
                or attention_params != _attention_backends["attention_params"]
            ):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
                (
                    use_flash_attention,
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
                    _,
                ) = get_attention_backend(attention_params)
                if use_flash_attention:
                    self.logger.info("Running with FlashAttention backend")
                elif use_fused_attention:
                    self.logger.info(
                        "Running with FusedAttention backend (sub-backend %s)",
                        int(fused_attention_backend),
6836
                    )
6837
6838
6839
6840
6841
6842
6843
                elif use_unfused_attention:
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]
6844

6845
6846
6847
6848
6849
6850
6851
6852
6853
6854
6855
6856
6857
6858
6859
6860
6861
6862
6863
6864
6865
6866
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
6867
                    cp_comm_type=self.cp_comm_type,
6868
6869
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
6870
                )
6871

6872
            if use_fused_attention:
6873
6874
                fu_core_attention_bias_type = core_attention_bias_type
                fu_core_attention_bias = core_attention_bias
6875
6876
6877
                if core_attention_bias_type == "alibi" and (
                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
                ):
6878
6879
6880
6881
6882
6883
6884
                    fu_core_attention_bias_type = "post_scale_bias"
                    _, fu_core_attention_bias = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                        bias_dtype=query_layer.dtype,
6885
                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
6886
                    )
6887
6888
6889
6890
6891
6892
6893
6894
6895
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
6896
6897
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
6898
6899
6900
6901
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
6902
                        window_size=window_size,
6903
6904
6905
6906
6907
6908
6909
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
6910
                        cp_comm_type=self.cp_comm_type,
6911
6912
6913
6914
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
                    )
                return self.fused_attention(
6915
6916
6917
6918
6919
6920
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
6921
6922
                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
6923
6924
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
6925
6926
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
6927
                    window_size=window_size,
6928
                    fused_attention_backend=fused_attention_backend,
6929
6930
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
6931
6932
6933
6934
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
6935
                    cp_comm_type=self.cp_comm_type,
6936
6937
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
6938
                )
6939

6940
            from .cpu_offload import CPUOffloadEnabled
6941

6942
6943
6944
6945
6946
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
6947

6948
            if use_unfused_attention:
6949
6950
6951
6952
6953
6954
                if window_size is not None and (
                    window_size[0] != -1 or window_size[1] not in [-1, 0]
                ):
                    attn_mask_type, attention_mask = get_swa_mask(
                        window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
                    )
6955
6956
6957
6958
6959
6960
6961
6962
6963
6964
6965
6966
6967
6968
6969
6970
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
6971
6972
6973
                    query_layer,
                    key_layer,
                    value_layer,
6974
6975
6976
6977
6978
6979
6980
6981
6982
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )
6983

6984
            raise Exception("No dot product attention support for the provided inputs!")
6985
6986


6987
6988
6989
6990
6991
6992
6993
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
6996

6997
6998
6999
7000
7001
7002
7003
7004
7005
7006
7007
7008
7009
7010
7011
7012
7013
7014
7015
7016
7017
7018
7019
7020
7021
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
7022
7023
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                   'padding_causal_bottom_right','arbitrary'},
                   default = `causal`
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
7030
7031
7032
7033
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
                be overridden by :attr:`window_size` in `forward` as well.
7038
7039
7040
7041
7042
7043
7044
7045
7046
7047
7048
7049
7050
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
7051
7052
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
7053
7054
7055
7056
7057
7058
7059
7060
7061
7062
7063
7064
7065
7066
7067
7068
7069
7070
7071
7072
7073
7074
7075
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
7076
7077
7078
7079
7080
7081
7082
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
            For that, please use `get_qkv_layout` to gain the layout information.
7084
7085
7086
7087
7088
7089
7090
7091
7092
7093
7094
7095
7096
7097
7098
7099
7100
7101
7102
7103
7104
7105
7106
7107
7108
7109
7110
7111
7112
7113
7114
7115
7116
7117
7118
7119
7120
7121
7122
7123

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
7124
7125
7126
7127
7128
7129
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """Build the QKV projection(s), the core attention module and the output projection.

        The meaning of every argument is documented in the class docstring.
        For ``attention_type == "self"`` a single projection produces Q, K and V
        (``layernorm_qkv`` or ``qkv``). For ``attention_type == "cross"`` the query
        projection (``layernorm_query`` or ``query_layer``) is separate from the
        ``key_value`` projection.

        Raises
        ------
        AssertionError
            if ``attention_type`` is not in ``AttnTypes``, ``layer_number`` is not
            positive, or the head / GQA-group / TP-size divisibility checks fail.
        """
        super().__init__()

        # `attn_mask_type` and `window_size` are the defaults that `forward`
        # falls back on when it is called with `None` for them.
        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

        # Per-head size defaults to an even split of the hidden size across heads.
        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Interleaving is a property of a single fused QKV weight, so it is
        # disabled whenever the Q/K/V parameters are kept separate.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # When a process group is supplied, its world size overrides `tp_size`.
        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
        # Without an explicit GQA group count this degenerates to plain MHA.
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)

        self.hidden_size_per_attention_head = kv_channels
        # Projection widths computed from the total (not per-partition) head counts.
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

        # Keyword arguments shared by every projection created below.
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            # One projection emits Q, K and V. With unfused parameters the
            # weight is exposed as separate query/key/value pieces.
            parameters_split = None
            if not fuse_qkv_params:
                parameters_split = collections.OrderedDict(
                    [
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    ]
                )
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=parameters_split,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=parameters_split,
                    **common_gemm_kwargs,
                )
        elif self.attention_type == "cross":
            # The query has its own projection; key and value share a separate
            # `key_value` projection.
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query",) if not fuse_qkv_params else None,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    self.hidden_size_q,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key", "value") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            self.hidden_size_per_attention_head,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # Output projection; row-parallel to pair with the column-parallel QKV.
        self.proj = Linear(
            self.hidden_size_q,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            ub_name="proj",
            **common_gemm_kwargs,
        )

    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Allocate one uninitialized key or value cache buffer for inference.

        `forward` calls this twice per layer (once for keys, once for values)
        and stores the pair in ``inference_params.key_value_memory_dict``.

        Parameters
        ----------
        inference_max_sequence_len : int
            size of the first (sequence) dimension of the buffer.
        batch_size : int
            size of the second (batch) dimension of the buffer.
        dtype : torch.dtype
            element type of the buffer.

        Returns
        -------
        torch.Tensor
            tensor of shape ``[inference_max_sequence_len, batch_size,
            num_gqa_groups_per_partition, hidden_size_per_attention_head]`` on the
            current CUDA device. Its contents are uninitialized (``torch.empty``).
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        The group is stored on this module and is also forwarded to every
        descendant submodule that defines ``set_tensor_parallel_group``. The
        projection layers built in ``__init__`` receive ``tp_group`` at
        construction time, so updating only this module's attribute would leave
        them with the old group.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        # `self.modules()` yields this module first; skip it to avoid infinite
        # recursion, then visit every descendant.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        This module keeps no context-parallel state of its own. The arguments
        are forwarded unchanged to every descendant submodule that defines
        ``set_context_parallel_group``.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = "p2p"
                      inter-gpu communication type for context parallelism.
                      Can be "p2p" or "all_gather".
        """
        # `self.modules()` yields this module first; skip it to avoid infinite
        # recursion, then visit every descendant.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)

    def forward(
        self,
        hidden_states: torch.Tensor,
7392
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
7393
        encoder_output: Optional[torch.Tensor] = None,
7394
        attn_mask_type: Optional[str] = None,
7395
        window_size: Optional[Tuple[int, int]] = None,
7396
7397
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
7398
        inference_params: Optional[InferenceParams] = None,
7399
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
7400
7401
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
7402
        alibi_slopes: Optional[torch.Tensor] = None,
7403
7404
7405
7406
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
7407
        fast_zero_fill: bool = True,
7408
    ) -> Tuple[Union[torch.Tensor, None], ...]:
7409
7410
7411
7412
7413
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

7414
7415
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
7416
7417
7418
7419
7420

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
7421
7422
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
7423
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
7424
7425
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
7426
7427
7428
7429
7430
7431
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                       'padding_causal_bottom_right','arbitrary'},
7432
                       default = `None`
7433
7434
7435
7436
                       type of attention mask passed into softmax operation. By default,
                       causal masks are aligned to the top left corner of the softmax matrix.
                       When "`bottom_right`" is specified in the mask type, causal masks are
                       aligned to the bottom right corner.
7437
7438
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
7439
7440
7441
7442
7443
7444
7445
7446
7447
7448
7449
7450
7451
7452
7453
7454
7455
7456
7457
7458
7459
7460
7461
7462
7463
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
7464
                    Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`, `alibi`}
7465
        core_attention_bias: Optional[torch.Tensor], default = `None`
7466
7467
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
7468
7469
7470
7471
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
7472
7473
7474
7475
7476
7477
7478
7479
7480
7481
7482
7483
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
7484
7485
7486
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        """
7487
7488
        # hidden_states: [sq, b, h]

7489
        if attn_mask_type is None:
7490
            attn_mask_type = self.attn_mask_type
7491
7492
        if window_size is None:
            window_size = self.window_size
7493
        window_size = check_set_window_size(attn_mask_type, window_size)
7494

7495
        if "padding" in attn_mask_type and attention_mask is not None:
7496
            for i, _ in enumerate(attention_mask):
7497
7498
7499
                assert (
                    attention_mask[i].dtype == torch.bool
                ), "Attention mask must be in boolean type!"
7500

7501
7502
7503
        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
7504

7505
        # =================================================
7506
        # Pre-allocate memory for key-values for inference
7507
7508
7509
        # =================================================

        if inference_params and self.layer_number is not None:
7510
7511
7512
            assert (
                self.qkv_format != "thd"
            ), "qkv_format == thd is not supported for an inference with KV-cache!"
7513
            if self.layer_number not in inference_params.key_value_memory_dict:
7514
                inf_max_seq_len = inference_params.max_sequence_length
7515
7516
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
7517
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
7518
7519
                )
                inference_value_memory = self._allocate_memory(
7520
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
7521
7522
7523
7524
7525
7526
7527
7528
7529
7530
7531
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

7532
        # ======================
7533
        # Query, Key, and Value
7534
        # ======================
7535

cyanguwa's avatar
cyanguwa committed
7536
7537
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
7538
7539
7540
7541
7542
7543
7544
7545
7546
7547
7548
7549
7550
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
7551
                    is_first_module_in_mha=True,  # specific to FP8 MHA
7552
7553
                )

7554
7555
7556
            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
7557
            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
7558
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
7559
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
cyanguwa's avatar
cyanguwa committed
7560
7561
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
7562
7563
7564
7565
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
cyanguwa's avatar
cyanguwa committed
7566
7567
7568
7569
7570
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
7571
                    self.hidden_size_per_attention_head,
cyanguwa's avatar
cyanguwa committed
7572
7573
7574
                )
                # split along third last dimension
                split_dim = -3
7575
7576
7577

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
7578
7579
7580
7581
7582
7583
7584
7585
7586
            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
7587
                )
7588
            else:
cyanguwa's avatar
cyanguwa committed
7589
                query_layer, key_layer, value_layer = torch.split(
7590
7591
7592
7593
                    mixed_x_layer,
                    (num_queries_per_key_value, 1, 1),
                    dim=split_dim,
                )
cyanguwa's avatar
cyanguwa committed
7594

7595
7596
7597
7598
7599
7600
7601
7602
7603
7604
7605
7606
            if self.qkv_format == "thd":
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
            else:
                # query: -> [sq, b, np, hn]
                # key, value: -> [sq, b, ng, hn]
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
cyanguwa's avatar
cyanguwa committed
7607
7608
        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
7609
            mixed_kv_layer = self.key_value(
cyanguwa's avatar
cyanguwa committed
7610
                encoder_output,
7611
                is_first_microbatch=is_first_microbatch,
7612
                is_first_module_in_mha=True,  # specific to FP8 MHA
7613
7614
7615
            )

            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
7616
                # [sq, b, (ng * 2 * hn)] --> [sq, b, ng, 2 * hn]
7617
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
7618
                    self.num_gqa_groups_per_partition,
7619
7620
7621
7622
7623
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
cyanguwa's avatar
cyanguwa committed
7624
                # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn]
7625
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
7626
                    2 * self.num_gqa_groups_per_partition,
7627
7628
7629
7630
7631
7632
7633
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
7634
7635
7636
            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
7637
7638
7639
                    mixed_kv_layer,
                    split_dim,
                    mixed_kv_layer.shape[split_dim] // 2,
cyanguwa's avatar
cyanguwa committed
7640
                )
7641
            else:
cyanguwa's avatar
cyanguwa committed
7642
                key_layer, value_layer = torch.split(
7643
7644
7645
                    mixed_kv_layer,
                    mixed_kv_layer.shape[split_dim] // 2,
                    dim=split_dim,
cyanguwa's avatar
cyanguwa committed
7646
                )
7647
7648
7649
7650
7651
7652
7653
7654
7655
            key_layer, value_layer = (
                x.reshape(
                    x.size(0),
                    x.size(1),
                    -1,
                    self.hidden_size_per_attention_head,
                )
                for x in (key_layer, value_layer)
            )
7656
7657
7658
7659
7660
7661
7662
7663
7664
7665
7666
7667
7668
7669
7670

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
7671
                    is_first_module_in_mha=True,  # specific to FP8 MHA
7672
7673
7674
7675
7676
7677
7678
7679
7680
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

7681
7682
7683
        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================
7684

7685
        if rotary_pos_emb is not None:
7686
7687
7688
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
7689
            # duplicate the pos_emb for self attention
7690
            if not isinstance(rotary_pos_emb, tuple):
7691
                rotary_pos_emb = (rotary_pos_emb,) * 2
7692
7693

            q_pos_emb, k_pos_emb = rotary_pos_emb
7694
7695
7696
7697
7698
7699
7700
7701
7702
7703
7704
7705
7706
7707

            # adjust key and value for inference
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

7708
7709
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
7710

7711
7712
7713
7714
        # ===========================
        # Core attention computation
        # ===========================

7715
7716
7717
7718
        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
7719
            qkv_format=self.qkv_format,
7720
7721
7722
7723
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
7724
7725
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
7726
            window_size=window_size,
7727
7728
7729
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
7730
            alibi_slopes=alibi_slopes,
7731
            fast_zero_fill=fast_zero_fill,
7732
            inference_params=inference_params,
7733
7734
        )

7735
        # ===================
7736
        # Output. [sq, b, h]
7737
        # ===================
7738

7739
        projection_output = self.proj(
7740
7741
            context_layer,
            is_first_microbatch=is_first_microbatch,
7742
7743
        )

7744
7745
7746
7747
7748
7749
7750
7751
        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
7752
        if self.input_layernorm and self.return_layernorm_output:
7753
7754
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]