attention.py 404 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
from importlib.metadata import PackageNotFoundError
10
import math
11
import os
12
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
import warnings
14
import logging
15
import functools
16

17
from dataclasses import dataclass, fields
cyanguwa's avatar
cyanguwa committed
18
import numpy as np
19
from packaging.version import Version as PkgVersion
20
21

import torch
22
import torch.nn.functional as F
23

24
import transformer_engine_torch as tex
25
26
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
27
28
29
30
from transformer_engine.pytorch.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
)
31
32
33
34
35
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
36
37
    fused_attn_fwd,
    fused_attn_bwd,
38
39
40
41
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
42
43
44
45
46
47
48
49
50
51
52
53
54
    META_QKV,
    META_DQKV,
    META_O,
    META_DO,
    META_S,
    META_DP,
    META_O_CP,
    META_DQKV_CP,
)
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    get_fp8_te_dtype,
    get_fp8_torch_dtype,
55
)
56
from transformer_engine.pytorch.float8_tensor import Float8Tensor
57
from transformer_engine.pytorch.module import LayerNormLinear, Linear
58
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
59
60
61
62
63
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
64
    get_default_init_method,
65
66
67
68
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
69
    AttnBiasTypes,
70
    QKVLayouts,
71
    dist_group_type,
72
    TE_DType,
73
74
75
76
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
77
    get_distributed_rank,
78
    checkpoint,
79
80
81
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
82
83
    gather_along_first_dim,
    reduce_scatter_along_first_dim,
84
85
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
86
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
87
88
from transformer_engine.pytorch.graph import is_graph_capturing

89

90
91
92
93
94
95
96
97
98
99
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))

# Effective verbosity is the product of the two knobs, so it is 0 (WARNING) unless
# both NVTE_DEBUG and NVTE_DEBUG_LEVEL are non-zero; any product other than 0/1/2
# falls back to DEBUG.
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels.get(_NVTE_DEBUG * _NVTE_DEBUG_LEVEL, logging.DEBUG)

# Handler/formatter shared with the per-backend loggers created later in this file.
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

# NOTE(review): logging.getLogger() with no name returns the *root* logger, so the
# level and handler configured here apply process-wide, not only to flash-attn
# messages. The handler is attached only if no handler is already reachable.
fa_logger = logging.getLogger()
fa_logger.setLevel(_log_level)
if not fa_logger.hasHandlers():
    fa_logger.addHandler(_stream_handler)


@functools.lru_cache(maxsize=None)
def _get_supported_versions(version_min, version_max):
    """Render an inclusive version range as ``">= <min>, <= <max>"`` for log messages.

    Both bounds are converted with ``str()``; results are memoized per argument pair.
    """
    return f">= {version_min!s}, <= {version_max!s}"

110

111
112
113
# Backend on/off switches, read from the environment once at import time
# (default "1" = enabled, "0" = disabled).
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))

# Detect flash-attn v2 in the environment
_flash_attn_is_installed = False
_flash_attn_version = PkgVersion("0")
# Inclusive range of flash-attn v2 versions whose kernels are imported below.
_flash_attn_version_required = PkgVersion("2.1.1")
_flash_attn_max_version = PkgVersion("2.6.3")
# Feature flags keyed on the installed flash-attn version; all stay False unless a
# supported version is found and imported.
_flash_attn_2_plus = False
_flash_attn_2_1_plus = False
_flash_attn_2_3_plus = False
_flash_attn_2_4_plus = False
_flash_attn_2_4_1_plus = False
_flash_attn_2_5_7_plus = False
_flash_attn_2_6_0_plus = False

# Placeholders so these names always exist, even when flash-attn v2 is unusable.
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
_flash_attn_fwd = None
_flash_attn_bwd = None
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None

try:
    _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
    # Only mention the missing package where flash-attn could actually be used (sm80+).
    if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
        fa_logger.debug(
            "flash-attn v2 is not installed. To use, please install it by"
            """ "pip install flash-attn".""",
        )
else:
    if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
        from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
        from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_forward as _flash_attn_varlen_fwd,
        )
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_backward as _flash_attn_varlen_bwd,
        )

        _flash_attn_is_installed = True
        _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
        _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
        _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
        _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
        _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
        _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
        _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
    elif (
        torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN
    ):
        # Installed but outside the supported range: leave flash-attn v2 disabled and say why.
        fa_logger.warning(
            "Supported flash-attn versions are %s. Found flash-attn %s.",
            _get_supported_versions(
                _flash_attn_version_required,
                _flash_attn_max_version,
            ),
            _flash_attn_version,
        )

# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
_flash_attn_3_is_installed = False
_flash_attn_3_version = PkgVersion("0")
_flash_attn_3_0_0_beta = False
_use_flash_attn_3 = False
_flash_attn_3_installation_steps = """\
(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
try:
    _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
except PackageNotFoundError:
    # Only mention the missing package where FA3 could actually be used (sm90+).
    if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN:
        fa_logger.debug(
            "flash-attn v3 is not installed. To use, please install it by \n%s",
            _flash_attn_3_installation_steps,
        )
else:
    from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flashattn_hopper.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )
    from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
    from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
    from flashattn_hopper.flash_attn_interface import (
        _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3,
    )
    from flashattn_hopper.flash_attn_interface import (
        _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
    )

    _flash_attn_3_is_installed = True
    # "3.0.0b" normalises to 3.0.0b0, so the lower bound must be inclusive for every
    # 3.0.0 beta (including b0) to be flagged; the final 3.0.0 release is excluded.
    _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") <= _flash_attn_3_version < PkgVersion("3.0.0")
    _use_flash_attn_3 = True
214

215
216
217
218
219
220
221
# Module-level record of the most recent attention-backend selection. The keys hold
# the AttentionParams that were evaluated and the selection outputs documented by
# get_attention_backend. Presumably callers elsewhere in the file consult
# "backend_selection_requires_update" to decide whether to re-run the selection
# (not visible here — confirm).
_attention_backends = {
    "attention_params": None,
    "use_flash_attention": None,
    "use_fused_attention": None,
    "fused_attention_backend": None,
    "use_unfused_attention": None,
    "backend_selection_requires_update": False,
}
223
224


225
226
@dataclass(eq=True)
class AttentionParams:
    """
    Attention parameters used to determine which backend to be used.

    Parameters
    ----------
    qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
        Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
    qkv_dtype: torch.dtype, default = `torch.bfloat16`
        Data type of query/key/value tensors.
    qkv_layout: str, default = "sbh3d"
        Query/key/value tensor memory layout.
    batch_size: int, default = 1
        Batch size.
    num_heads: int, default = 16
        Number of attention heads in the query tensor.
    num_gqa_groups: int, default = 16
        Number of attention heads in key and value tensors.
    max_seqlen_q: int, default = 128
        Maximum sequence length of the query tensor.
    max_seqlen_kv: int, default = 128
        Maximum sequence length of the key and value tensors.
    head_dim_qk: int, default = 64
        The size of each attention head in query and key tensors.
    head_dim_v: int, default = 64
        The size of each attention head in the value tensor.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
        `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
    window_size: Tuple[int, int], default = None
        Sliding window attention size.
    alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
        Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
    core_attention_bias_type: str, default = `no_bias`
        Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
    core_attention_bias_shape: str, default = `1hss`
        Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
    core_attention_bias_requires_grad: bool, default = `True`
        Whether attention bias requires gradient.
    pad_between_seqs: bool, default = `False`
        Whether there is padding between sequences in a batch.
        This only applies to `qkv_format=thd`.
    attention_dropout: float, default = 0.0
        Attention dropout.
    context_parallel: bool, default = `False`
        Whether context parallelism is used or not.
    deterministic: bool, default = `False`
        Whether to run `DotProductAttention` with determinism or not.
    is_training: bool, default = `True`
        Whether in training mode (`True`) or inference mode (`False`)
    fp8: bool, default = `False`
        Whether `DotProductAttention` is in an `fp8_autocast` region.
    fp8_meta: Optional[Dict[str, Any]], default = `None`
        The FP8 metadata of `DotProductAttention`. Only its "recipe" entry takes
        part in equality comparison.
    """

    # Field order is part of the positional constructor signature; do not reorder.
    qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
    qkv_dtype: torch.dtype = torch.bfloat16
    qkv_layout: str = "sbh3d"
    batch_size: int = 1
    num_heads: int = 16
    num_gqa_groups: int = 16
    max_seqlen_q: int = 128
    max_seqlen_kv: int = 128
    head_dim_qk: int = 64
    head_dim_v: int = 64
    attn_mask_type: str = "no_mask"
    window_size: Union[Tuple[int, int], None] = None
    alibi_slopes_shape: Union[torch.Size, List, None] = None
    core_attention_bias_type: str = "no_bias"
    core_attention_bias_shape: str = "1hss"
    core_attention_bias_requires_grad: bool = True
    pad_between_seqs: bool = False
    attention_dropout: float = 0.0
    context_parallel: bool = False
    deterministic: bool = False
    is_training: bool = True
    fp8: bool = False
    fp8_meta: Union[Dict[str, Any], None] = None

    def __eq__(self, other):
        """
        Override dataclass.__eq__ so that only fp8_meta["recipe"] is compared,
        since all other entries of fp8_meta are unused in get_attention_backend.

        A missing fp8_meta (the default ``None``) or a missing "recipe" entry is
        treated as a recipe of ``None``. This lets default-constructed instances
        compare equal instead of raising on ``None["recipe"]``.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        for field in fields(self):
            fname = field.name
            sf = getattr(self, fname)
            of = getattr(other, fname)
            if fname == "fp8_meta":
                # Reduce each fp8_meta to just its recipe before comparing.
                sf = None if sf is None else sf.get("recipe")
                of = None if of is None else of.get("recipe")
            if sf != of:
                return False
        return True

324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339

# Module-level holder for ALiBi state. Judging by the key names it stores the last
# computed slopes/bias, the parameters they were computed for, and two staleness
# flags. The code that reads and writes it is outside this view — confirm.
_alibi_cache = dict(
    _num_heads=None,
    _alibi_slopes=None,
    _max_seqlen_q=None,
    _max_seqlen_kv=None,
    _bottom_right_alignment=True,
    _alibi_bias=None,
    _alibi_slopes_require_update=False,
    _alibi_bias_require_update=False,
)


# Public API of this module (names exported by `from ... import *`).
__all__ = [
    "DotProductAttention",
    "InferenceParams",
    "MultiheadAttention",
]


340
341
342
343
344
def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    """Return `tensor` unchanged when its innermost dimension has unit stride.

    Otherwise return `tensor.contiguous()`. Only the last stride is inspected, so a
    tensor that is non-contiguous in its outer dimensions is passed through as-is.
    """
    innermost_stride = tensor.stride(-1)
    if innermost_stride == 1:
        return tensor
    return tensor.contiguous()


345
346
347
348
349
350
351
352
353
def get_attention_backend(
    attention_params: AttentionParams = None,
):
    """
    Select the appropriate attention backend/sub-backend based on user input and runtime environment.

    Parameters
    ----------
    See `AttentionParams`.
354
355
356
357
358
359
360

    Returns
    ----------
    use_flash_attention: bool
        Whether the `FlashAttention` backend has been selected.
    use_fused_attention: bool
        Whether the `FusedAttention` backend has been selected.
361
362
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend
        If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
363
364
365
366
367
368
    use_unfused_attention: bool
        Whether the `UnfusedDotProductAttention` backend has been selected.
    available_backends: List[bool]
        All available backends that could support the provided input. A list of Booleans
        in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
    """
369
370
371
372
373
374
375
376
    qkv_type = attention_params.qkv_type
    qkv_dtype = attention_params.qkv_dtype
    qkv_layout = attention_params.qkv_layout
    batch_size = attention_params.batch_size
    num_heads = attention_params.num_heads
    num_gqa_groups = attention_params.num_gqa_groups
    max_seqlen_q = attention_params.max_seqlen_q
    max_seqlen_kv = attention_params.max_seqlen_kv
377
378
    head_dim_qk = attention_params.head_dim_qk
    head_dim_v = attention_params.head_dim_v
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
    attn_mask_type = attention_params.attn_mask_type
    window_size = attention_params.window_size
    alibi_slopes_shape = attention_params.alibi_slopes_shape
    core_attention_bias_type = attention_params.core_attention_bias_type
    core_attention_bias_shape = attention_params.core_attention_bias_shape
    core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
    pad_between_seqs = attention_params.pad_between_seqs
    attention_dropout = attention_params.attention_dropout
    context_parallel = attention_params.context_parallel
    deterministic = attention_params.deterministic
    is_training = attention_params.is_training
    fp8 = attention_params.fp8
    fp8_meta = attention_params.fp8_meta

    # Run config
394
    logger = logging.getLogger("DotProductAttention")
395
396
397
    logger.setLevel(_log_level)
    if not logger.hasHandlers():
        logger.addHandler(_stream_handler)
398
399
400
401
402
    device_compute_capability = get_device_compute_capability()
    cudnn_version = get_cudnn_version()
    run_config = {
        "transformer_engine_version": te.__version__,
        "compute_capability": "sm"
403
        + str(10 * device_compute_capability[0] + device_compute_capability[1]),
404
405
406
407
408
409
        "flash_attn_version": (
            str(_flash_attn_version) if _flash_attn_is_installed else "not installed"
        ),
        "flash_attn_3_version": (
            str(_flash_attn_3_version) if _flash_attn_3_is_installed else "not installed"
        ),
410
411
412
413
414
415
416
417
418
        "cudnn_version": ".".join([str(i) for i in cudnn_version]),
    }
    attention_params_dict = {
        field.name: getattr(attention_params, field.name) for field in fields(attention_params)
    }
    run_config.update(attention_params_dict)
    if fp8:
        run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
    logger.debug("Running with config=%s", run_config)
419

420
421
422
423
424
425
    # The following sections check if `FlashAttention` supports the provided attention params,
    # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
    # necessary for performance/functionality, a warning will be issued to prompt users to
    # install an appropriate FA version.
    global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3

426
    # Filter: Environment variables
427
428
429
430
    use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
    use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
    use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
    if not use_flash_attention and _flash_attn_is_installed:
431
432
433
434
435
436
437
438
        logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
    if not use_fused_attention:
        logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
    if not use_unfused_attention:
        logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")

    # Filter: ONNX mode
    if is_in_onnx_export_mode():
439
        if use_flash_attention and _flash_attn_is_installed:
440
441
442
443
444
445
446
447
            logger.debug("Disabling FlashAttention due to ONNX mode")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention due to ONNX mode")
        use_fused_attention = False

    # Filter: Compute capability
    if device_compute_capability < (8, 0):
448
        if use_flash_attention and _flash_attn_is_installed:
449
            logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
450
        use_flash_attention = False
451
452
453
        if use_fused_attention:
            logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
            use_fused_attention = False
454
    if device_compute_capability < (9, 0):
455
        if use_flash_attention and _flash_attn_3_is_installed:
456
            logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
457
        _use_flash_attn_3 = False
458
459

    # Filter: Data type
460
461
462
463
    if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
        torch.Tensor,
        Float8Tensor,
    ]:
464
        if use_flash_attention and _flash_attn_is_installed:
465
466
467
468
469
470
            logger.debug(
                "Disabling FlashAttention due to unsupported QKV data type. "
                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                "Found: qkv_dtype = %s.",
                qkv_dtype,
            )
471
        use_flash_attention = False
472
473
474
475
476
477
478
479
        if use_fused_attention:
            logger.debug(
                "Disabling FusedAttention due to unsupported QKV data type. "
                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                "Found: qkv_dtype = %s.",
                qkv_dtype,
            )
            use_fused_attention = False
480
481
482

    # Filter: Execution type
    if fp8 and fp8_meta["recipe"].fp8_dpa:
483
        if use_flash_attention and not _use_flash_attn_3:
484
485
            if _flash_attn_is_installed:
                logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
486
487
488
489
490
            use_flash_attention = False
        if use_flash_attention and _use_flash_attn_3 and is_training:
            logger.debug(
                "Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
            )
491
492
493
494
495
496
            use_flash_attention = False
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
            use_unfused_attention = False

    # Filter: Head dimension
497
    if use_flash_attention and head_dim_qk != head_dim_v:
498
499
        if _flash_attn_is_installed:
            logger.debug("Disabling FlashAttention as it does not support MLA.")
500
        use_flash_attention = False
501
    if use_flash_attention and (
502
503
504
        head_dim_qk > 256
        or head_dim_qk % 8 != 0
        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
505
    ):
506
507
508
509
510
511
512
513
514
515
        if _flash_attn_is_installed:
            logger.debug(
                "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
                "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
                "head_dim_qk <= 256 (>192 requires sm80/90). "
                "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
                head_dim_qk,
                head_dim_v,
                ".".join([str(i) for i in device_compute_capability]),
            )
516
        use_flash_attention = False
517
518
519
520
521
522
523
    qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
    if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
        logger.debug(
            "Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
            qkv_layout,
        )
        use_fused_attention = False
524
525
526
527
528
529
530
531

    # Filter: QKV layout
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
            use_unfused_attention = False
        if use_flash_attention and pad_between_seqs:
532
533
534
535
536
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention for qkv_format = thd when there is "
                    "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
                )
537
538
            use_flash_attention = False

539
    # Filter: Dropout
540
541
542
    if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
        logger.debug("Disabling FlashAttention 3 for dropout")
        _use_flash_attn_3 = False
543

544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
    # Filter: Context parallelism
    # qkv_format | attn_mask_type              | attn_bias_type           | supported backends
    # ----------------------------------------------------------------------------------------------------
    # bshd, sbhd | self-attention:             | no_bias, post_scale_bias | FlashAttention, FusedAttention
    #            |     no_mask, causal         |                          |
    #            | cross-attention:            |                          |
    #            |     no_mask                 |                          |
    # thd        | self-attention:             | no_bias                  | FlashAttention, FusedAttention
    #            |     padding, padding_causal |                          | if no padding between sequences,
    #            | cross-attention:            |                          | FusedAttention
    #            |     padding                 |                          | if there is padding between sequences
    # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v.
    if context_parallel and use_unfused_attention:
        logger.debug(
            "Disabling UnfusedDotProductAttention as it does not support context parallelism"
        )
        use_unfused_attention = False
    if context_parallel and use_flash_attention:
562
        if fp8 and fp8_meta["recipe"].fp8_dpa:
563
564
565
566
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with FP8"
                )
567
            use_flash_attention = False
568
        if "bottom_right" in attn_mask_type:
569
570
571
572
573
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " causal_bottom_right masking"
                )
574
575
            use_flash_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
576
577
578
579
580
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " causal masking for cross-attention"
                )
581
582
            use_flash_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
583
584
585
586
587
588
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with bias"
                    " type of %s",
                    core_attention_bias_type,
                )
589
590
            use_flash_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
591
592
593
594
595
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " attention bias for THD format"
                )
596
            use_flash_attention = False
597

598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
    if context_parallel and use_fused_attention:
        if "bottom_right" in attn_mask_type:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with"
                " causal_bottom_right masking"
            )
            use_fused_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with causal"
                " masking for cross-attention"
            )
            use_fused_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with bias type"
                " of %s",
                core_attention_bias_type,
            )
            use_fused_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with attention"
                " bias for THD format"
            )
            use_fused_attention = False
        elif head_dim_qk != head_dim_v:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with MLA"
            )
            use_fused_attention = False

630
    # Filter: Attention mask
631
632
633
634
635
636
637
638
639
640
641
642
643
644
    # attn_mask_type              | attention_mask                       | supported backends
    # ----------------------------------------------------------------------------------------
    # no_mask                     | None                                 | All
    # padding                     |                                      | All
    #     self-attention          | One tensor in shape [b, 1, 1, sq]    |
    #     cross-attention         | Tuple of two tensors in shapes       |
    #                             | [b, 1, 1, sq] and [b, 1, 1, skv]     |
    # causal                      | None                                 |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # padding_causal              | Same as "padding"                    |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # causal_bottom_right         | None                                 | All
645
    # padding_causal_bottom_right | Same as "padding"                    | All
646
647
    # arbitrary                   | One tensor in shape broadcastable to | UnfusedDotProductAttention
    #                             | [b, h, sq, skv]                      |
648
    if attn_mask_type == "arbitrary":
649
        if use_flash_attention and _flash_attn_is_installed:
650
651
652
653
654
            logger.debug("Disabling FlashAttention for arbitrary mask")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention for arbitrary mask")
        use_fused_attention = False
655
656
    if (
        use_flash_attention
657
        and _use_flash_attn_3
658
659
660
661
662
663
664
665
666
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
            "causal mask since flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        _use_flash_attn_3 = False
667
668
669
670
671
    if (
        use_flash_attention
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
672
673
674
675
676
677
678
679
680
        if _flash_attn_2_1_plus:
            logger.warning(
                "Disabling FlashAttention as it only supports bottom-right-diagonal "
                "causal mask since flash-attn 2.1. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False
        if not _flash_attn_is_installed:
            _flash_attn_max_version = PkgVersion("2.1")
681
682
683
684
685
    if (
        use_flash_attention
        and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
        and max_seqlen_q != max_seqlen_kv
    ):
686
687
688
689
690
691
692
693
694
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.1")
        elif not _flash_attn_2_1_plus and not _use_flash_attn_3:
            logger.warning(
                "Disabling FlashAttention as it only supports top-left-diagonal "
                "causal mask before flash-attn 2.1. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False
695
696
697
698
699
700
701
702
703
    if (
        use_flash_attention
        and _use_flash_attn_3
        and fp8
        and fp8_meta["recipe"].fp8_dpa
        and "padding" in attn_mask_type
    ):
        logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
        _use_flash_attn_3 = False
704
705

    # Filter: Sliding window attention
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
    #    backend                 |      window_size       | diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | (-1, -1) or (>=0, >=0) | bottom right
    # FusedAttention             | (-1,  0) or (>=0, 0)   | top left
    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
    #                            |                        | converts window_size to an 'arbitrary' mask
    if window_size is None:
        window_size = check_set_window_size(attn_mask_type, window_size)
    else:
        if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention"
                    " for FP8"
                )
                use_fused_attention = False
722
            elif window_size[1] != 0 or attention_dropout != 0.0:
723
724
                logger.debug(
                    "Disabling FusedAttention as it only supports sliding window attention "
725
                    "with (left, 0) and no dropout"
726
727
                )
                use_fused_attention = False
728
            elif max_seqlen_q > max_seqlen_kv:
729
730
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
731
                    "with s_q > s_kv for cross-attention"
732
733
                )
                use_fused_attention = False
734
735
736
737
738
739
740
741
742
743
744
745
746
        if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if _use_flash_attn_3:
                logger.debug(
                    "Disabling FlashAttention 3 as it does not support sliding window attention"
                )
                _use_flash_attn_3 = False
            if not _flash_attn_is_installed:
                _flash_attn_version_required = PkgVersion("2.3")
            elif not _flash_attn_2_3_plus:
                logger.debug(
                    "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
                )
                use_flash_attention = False
747
748

    # Filter: Attention bias
749
750
751
752
753
754
755
756
    #    backend                 |      bias types              | ALiBi diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | no_bias, alibi/alibi_slopes  | bottom right
    # FusedAttention             | no_bias, post_scale_bias     |
    #                            | alibi/alibi_slopes           | top left,
    #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
    # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
    #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
757
    if use_flash_attention and core_attention_bias_type == "alibi":
758
        if _use_flash_attn_3:
759
760
            logger.debug("Disabling FlashAttention 3 for ALiBi")
            _use_flash_attn_3 = False
761
762
763
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.4")
        elif not _flash_attn_2_4_plus:
764
765
            logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
            use_flash_attention = False
766

767
768
769
770
    if use_flash_attention and (
        core_attention_bias_type not in ["no_bias", "alibi"]
        or core_attention_bias_shape is not None
    ):
771
772
        if _flash_attn_is_installed:
            logger.debug("Disabling FlashAttention for pre/post_scale_bias")
773
774
775
776
777
778
779
780
        use_flash_attention = False

    fu_core_attention_bias_type = core_attention_bias_type
    fu_core_attention_bias_shape = core_attention_bias_shape
    fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
    if (
        use_fused_attention
        and core_attention_bias_type == "alibi"
781
        and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
782
783
784
    ):
        fu_core_attention_bias_type = "post_scale_bias"
        fu_core_attention_bias_requires_grad = False
785
786
787
788
789
        if alibi_slopes_shape is None:
            fu_core_attention_bias_shape = "1hss"
        elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
            fu_core_attention_bias_shape = "1hss"
        elif (
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
            len(alibi_slopes_shape) == 2
            and alibi_slopes_shape[0] == batch_size
            and alibi_slopes_shape[1] == num_heads
        ):
            fu_core_attention_bias_shape = "bhss"

    if (
        use_fused_attention
        and fu_core_attention_bias_type == "post_scale_bias"
        and fu_core_attention_bias_shape != "1hss"
    ):
        if fu_core_attention_bias_requires_grad:
            # remove this line when cuDNN adds bwd support for
            # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
            logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
            use_fused_attention = False
        else:
            # max512 backend will only support [1, h, s, s]
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

    # Filter: cuDNN support
    fused_attention_backend = None
    if use_fused_attention:
        q_type = TE_DType[qkv_dtype]
        kv_type = q_type
        if fp8 and fp8_meta["recipe"].fp8_dpa:
            q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            kv_type = q_type
        fused_attention_backend = tex.get_fused_attn_backend(
            q_type,
            kv_type,
            QKVLayout[qkv_layout],
            AttnBiasType[fu_core_attention_bias_type],
            AttnMaskType[attn_mask_type],
            attention_dropout,
            num_heads,
            num_gqa_groups,
            max_seqlen_q,
            max_seqlen_kv,
829
830
            head_dim_qk,
            head_dim_v,
831
832
            window_size[0],
            window_size[1],
833
        )
834
        if fused_attention_backend == FusedAttnBackend["No_Backend"]:
835
836
            logger.debug("Disabling FusedAttention as no backend supports the provided input")
            use_fused_attention = False
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
            fused_attention_backend = None
        if (
            use_fused_attention
            and window_size is not None
            and window_size[0] != -1
            and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
        ):
            logger.debug(
                "Disabling FusedAttention as only sub-backend %s does not support "
                "slidng window attention",
                int(fused_attention_backend),
            )
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
854
855
856
857
858
859
860
861
            and fu_core_attention_bias_type == "post_scale_bias"
            and fu_core_attention_bias_shape != "1hss"
        ):
            logger.debug(
                "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
                " [1, H, S, S] shape"
            )
            use_fused_attention = False
862
            fused_attention_backend = None
863
864
865
866
867
868
869
870
871
872
873
874
875

    # Filter: Determinism
    # backend                      | deterministic
    # ---------------------------------------------
    # FlashAttention               |
    #     flash-attn >=2.0, <2.4.1 | no
    #     flash-attn >=2.4.1       | yes
    # FusedAttention               |
    #     sub-backend 0            | yes
    #     sub-backend 1            | workspace optimization path and sm90+: yes;
    #                              | otherwise: no
    #     sub-backend 2            | no
    # UnfusedDotProductAttention   | yes
876
877
878
879
880
881
882
883
884
885
    if use_flash_attention and deterministic:
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.4.1")
        elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3:
            logger.warning(
                "Disabling FlashAttention as version <2.4.1 does not support deterministic "
                "execution. To use FlashAttention with deterministic behavior, "
                "please install flash-attn >= 2.4.1."
            )
            use_flash_attention = False
886
887
888
889
890
891
892
893
894
895
896
    if use_fused_attention and deterministic:
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
            and (
                device_compute_capability < (9, 0)
                or core_attention_bias_requires_grad
                or cudnn_version < (8, 9, 5)
897
            )
898
899
900
        ):
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
901
902
903

    # All available backends
    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920

    # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
    # When `FusedAttention` does not support the provided attention params, and `FlashAttention`
    # does, we recommend users to install flash-attn if not installed already.
    if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed:
        logger.warning(
            "flash-attn may provide important feature support or performance improvement."
            " Please install flash-attn %s.",
            _get_supported_versions(
                _flash_attn_version_required,
                _flash_attn_max_version,
            ),
        )
    if use_flash_attention and not _flash_attn_is_installed:
        use_flash_attention = False
        available_backends[0] = False

921
922
923
924
925
926
927
928
929
930
931
932
    logger.debug(
        "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
        " UnfusedDotProductAttention=%s}",
        bool(available_backends[0]),
        bool(available_backends[1]),
        (
            f" (sub-backend {int(fused_attention_backend)})"
            if fused_attention_backend is not None
            else ""
        ),
        bool(available_backends[2]),
    )
933
934
935
936
937
938
939
940
941
942
943
944
945

    # Select FusedAttention for performance
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    ):
        if device_compute_capability == (9, 0):
            logger.debug(
                "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                "for performance reasons"
            )
            use_flash_attention = False
946
947
948
949
950
951
952
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["FP8"]
        and _use_flash_attn_3
    ):
        logger.debug(
953
954
            "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
            "in FP8 execution"
955
956
957
        )
        use_flash_attention = False

958
959
960
961
962
963
    # Selected backend
    if use_flash_attention:
        use_fused_attention = False
        use_unfused_attention = False
    elif use_fused_attention:
        use_unfused_attention = False
964
    selected_backend = "NoBackend"
965
966
967
968
969
970
    if use_flash_attention:
        selected_backend = "FlashAttention"
    elif use_fused_attention:
        selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
    elif use_unfused_attention:
        selected_backend = "UnfusedDotProductAttention"
971
    logger.debug("Selected backend = %s", selected_backend)
972

973
974
975
976
977
978
    global _attention_backends
    _attention_backends["use_flash_attention"] = use_flash_attention
    _attention_backends["use_fused_attention"] = use_fused_attention
    _attention_backends["fused_attention_backend"] = fused_attention_backend
    _attention_backends["use_unfused_attention"] = use_unfused_attention
    _attention_backends["backend_selection_requires_update"] = False
979
980
981
982

    return (
        use_flash_attention,
        use_fused_attention,
983
        fused_attention_backend,
984
985
986
987
988
        use_unfused_attention,
        available_backends,
    )


class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Both offsets start at zero; this class never updates them itself.
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # Maps a layer key (iterated as `layer_number` below) to a
        # (key_memory, value_memory) pair of tensors.
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Every (key, value) pair stored in `key_value_memory_dict` is re-indexed
        along dimension 1, which this method treats as the batch dimension.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty.
        AssertionError
            If `len(batch_indices)` differs from dimension 1 of a cached key tensor.
        """
        if not self.key_value_memory_dict:
            raise ValueError("should not swap when dict is empty")

        # Only values of existing keys are replaced, so mutating the dict while
        # iterating over its items is safe (its size never changes).
        for layer_number, inference_memory in self.key_value_memory_dict.items():
            inference_key_memory, inference_value_memory = inference_memory
            # make sure batch size is the same
            assert len(batch_indices) == inference_key_memory.shape[1], (
                f"batch_indices has {len(batch_indices)} entries but the KV cache of layer "
                f"{layer_number} has batch size {inference_key_memory.shape[1]}"
            )
            self.key_value_memory_dict[layer_number] = (
                inference_key_memory[:, batch_indices],
                inference_value_memory[:, batch_indices],
            )


@torch.no_grad()
def get_full_mask(
    max_seqlen_q: int,
    max_seqlen_kv: int,
    attn_mask_type: str = "no_mask",
    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
    window_size: Tuple[int, int] = None,
    attention_type: str = "self",
    bottom_right_alignment: bool = True,
) -> Tuple[str, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
    `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
    on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::

       attn_mask_type              output shape                                 diagonal alignment
       --------------------------------------------------------------------------------------------
       no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]          follow bottom_right_alignment
       causal                      [1, 1, max_seqlen_q, max_seqlen_kv]          always top left
       causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]          always bottom right
       padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
       padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left
       padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right
       arbitrary                   same as attention_mask                       follow bottom_right_alignment

    .. note::

    For "padding_causal_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True,
    the bottom right diagonal comes from the bottom right corner of the
    [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix, i = 0,...,batch_size-1, not the
    [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4, max_seqlen_kv = 4,
    attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
    [[False, False,  True, True], [False, False, False, False]],
    [[False, False, False, True], [False,  True,  True,  True]]), the returned full attention mask
    has [2, 1, 4, 4] shape and is (singleton head dimension omitted),::

      [[[False, False, False, True],
        [False, False, False, True],
        [ True,  True,  True, True],
        [ True,  True,  True, True]],
       [[False,  True,  True, True],
        [False,  True,  True, True],
        [False,  True,  True, True],
        [False,  True,  True, True]]]

    All intermediate tensors are created on the device of `attention_mask` (or of its first
    element for cross attention); when no mask is given they are created on CUDA.

    Parameters
    ----------
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
        "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        default = `None`
        Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
        for the requirements of `attention_mask` for different `attn_mask_type`s.
    window_size: Tuple[int, int], default = `None`
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
        `attn_mask_type`.
    attention_type: str, default = "self"
        Attention type, {"self", "cross"}
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
        or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
        specifies "causal" or "causal_bottom_right".

    Returns
    ----------
    attn_mask_type: str
        "arbitrary" when `window_size` is given and is neither (-1, -1) nor (-1, 0);
        otherwise, the same as input `attn_mask_type`
    attention_mask: torch.Tensor
        The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`.
        `True` marks positions that are masked out.
    actual_seqlens_q: Optional[torch.Tensor]
        For padding masks, the actual sequence lengths for queries, in shape [batch_size].
        For other masks, `None`.
    actual_seqlens_kv: Optional[torch.Tensor]
        For padding masks, the actual sequence lengths for keys and values, in shape [batch_size].
        For other masks, `None`.

    Raises
    ------
    ValueError
        If `attn_mask_type` is not one of the supported types listed above.
    """
    # Build the index grid on the same device as the user-provided mask(s) so the logical ops
    # below never mix devices; without a mask, default to CUDA.
    if isinstance(attention_mask, torch.Tensor):
        device = attention_mask.device
    elif (
        isinstance(attention_mask, (tuple, list))
        and len(attention_mask) > 0
        and isinstance(attention_mask[0], torch.Tensor)
    ):
        device = attention_mask[0].device
    else:
        device = "cuda"

    # perform basic checks
    # A real sliding window (anything other than "no window" or "causal") turns the mask
    # into an "arbitrary" one for the caller.
    change_type = window_size is not None and (
        window_size[0] != -1 or window_size[1] not in [-1, 0]
    )
    if window_size is None:
        window_size = (-1, -1)
    if "causal" in attn_mask_type:
        window_size = (window_size[0], 0)
    # Replace -1 ("unbounded") with a bound wide enough that it never masks anything.
    window_size = (
        max_seqlen_kv if window_size[0] == -1 else window_size[0],
        max_seqlen_q if window_size[1] == -1 else window_size[1],
    )

    # apply padding mask
    actual_seqlens_q = None
    actual_seqlens_kv = None
    if "padding" in attn_mask_type:
        # Combine the per-token q and kv masks into one [b, 1, sq, skv] mask.
        if attention_type == "self":
            attention_mask = torch.logical_or(
                attention_mask.squeeze(1).unsqueeze(3), attention_mask
            )
        else:
            attention_mask = torch.logical_or(
                attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
            )
        m = attention_mask.logical_not()
        # NOTE(review): counting unmasked entries of column 0 / row 0 yields the true lengths
        # only if the first query and key of every sample are valid (padding at the end).
        actual_seqlens_q = m[:, 0, :, 0].sum(dim=1)
        actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1)

    # apply SWA mask
    # mask[..., i, j] = i - j
    mask = torch.arange(max_seqlen_q, dtype=torch.int32, device=device).view(
        1, 1, max_seqlen_q, 1
    ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device=device).view(1, 1, 1, max_seqlen_kv)
    if attn_mask_type == "causal_bottom_right" or (
        attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment
    ):
        swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
        swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
    elif attn_mask_type in ["causal", "padding_causal"] or (
        attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment
    ):
        swa_left = mask - window_size[0]
        swa_right = mask + window_size[1]
    elif attn_mask_type == "padding_causal_bottom_right" or (
        attn_mask_type == "padding" and bottom_right_alignment
    ):
        # Per-sample diagonal offset based on the actual (unpadded) lengths.
        batch_size = attention_mask.shape[0]
        swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
            actual_seqlens_kv - actual_seqlens_q - window_size[0]
        ).view(batch_size, 1, 1, 1)
        swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
            actual_seqlens_kv - actual_seqlens_q + window_size[1]
        ).view(batch_size, 1, 1, 1)
    else:
        raise ValueError(f"Unsupported attn_mask_type: {attn_mask_type}")
    # Key j is kept for query i iff swa_left <= 0 <= swa_right. With non-negative window sizes
    # the difference below is 1 exactly for kept positions and 0 otherwise, so logical_not
    # marks the masked-out positions as True.
    swa_mask = torch.logical_not(
        torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
    )
    if attention_mask is not None:
        attention_mask = torch.logical_or(swa_mask, attention_mask)
    else:
        attention_mask = swa_mask

    # change mask type
    if change_type:
        attn_mask_type = "arbitrary"

    return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv


@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    actual_seqlens_q: Optional[torch.Tensor] = None,
    actual_seqlens_kv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
    bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get the ALiBi slopes and the ALiBi bias, backed by the module-level `_alibi_cache`.

    The slopes are only (re)computed when `_alibi_cache["_alibi_slopes_require_update"]` is set,
    and the bias only when `_alibi_cache["_alibi_bias_require_update"]` is set. This function
    clears those flags after recomputing but never sets them; otherwise the cached tensors are
    returned unchanged, regardless of the arguments passed in.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    actual_seqlens_q: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for queries, in shape [batch_size].
    actual_seqlens_kv: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for keys and values, in shape [batch_size].
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
        If `None`, default slopes are generated: with n the largest power of two not exceeding
        `num_heads`, head k (k = 1..n) gets 2**(-8k/n), and any remaining heads get the odd
        powers 1, 3, 5, ... of 2**(-4/n).
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the ALiBi bias to the bottom right corner of
        the matrix (`True`) or top left (`False`).

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. Its shape is
        (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
        and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
        (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
        [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
        `actual_seqlens_q` and `actual_seqlens_kv` are not `None`.

    Raises
    ------
    ValueError
        If the ALiBi slopes are not 1-D or 2-D, or if exactly one of `actual_seqlens_q`
        and `actual_seqlens_kv` is `None`.
    """
    global _alibi_cache
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                # Heads beyond the largest power of two get odd powers of 2**(-4/n).
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!"
        slopes = _alibi_cache["_alibi_slopes"]
        if slopes.dim() == 1:
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        elif slopes.dim() == 2:
            slopes_shape = torch.Size([*slopes.shape[:], 1, 1])
        else:
            raise ValueError("ALiBi slopes cannot exceed 2 dimensions.")

        # bias[..., i, j] = i - j (distance relative to the top-left diagonal)
        bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
            1, 1, max_seqlen_q, 1
        ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
            1, 1, 1, max_seqlen_kv
        )
        if actual_seqlens_q is None and actual_seqlens_kv is None:
            if bottom_right_alignment:
                bias = bias + max_seqlen_kv - max_seqlen_q
        elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
            # Shift each sample's diagonal by its own (actual_kv - actual_q) offset.
            batch_size = actual_seqlens_q.shape[0]
            bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
            if bottom_right_alignment:
                bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
        else:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise ValueError(
                "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
            )

        # bias = -|distance| * slope
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    Each sample's length is its number of `False` entries in `mask`. The result is
    created on `mask.device`.
    """
    mask = mask.squeeze(1).squeeze(1)
    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Prepend the leading zero on the mask's own device so torch.cat never mixes devices.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    return cu_seqlens


def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the flat indices of the valid (`False`) tokens in ascending order,
    padded at the end with the value `batch_size * max_seqlen`.

    Both tensors are created on `mask.device`.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Prepend the leading zero on the mask's own device so torch.cat never mixes devices.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # nonzero() on the flattened mask gives [num_valid, 1]; add a trailing dim -> [num_valid, 1, 1].
    mask = mask.reshape(-1)
    indices = mask.logical_not().nonzero()
    indices = indices.unsqueeze(-1)

    # Pad to a fixed [bs * seqlen, 1, 1] shape with an out-of-range sentinel index.
    num_nonzeros = indices.shape[0]
    pad_amount = bs * seqlen - num_nonzeros
    indices = F.pad(
        input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen)
    )

    return cu_seqlens, indices


def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch.

    Sample ``i`` of length ``l_i`` contributes the flattened positions
    ``i * max_seqlen + 0 .. i * max_seqlen + l_i - 1``, in ascending order. The
    remaining entries are padded with ``batch_size * max_seqlen``.

    The indices are computed with vectorized int64 tensor ops on the device of
    ``cu_seqlens``. There is no per-element Python loop, which would force a host
    sync per sample. There is also no float32 intermediate, which cannot represent
    every integer above 2**24 exactly.
    """
    bs = len(cu_seqlens) - 1
    device = cu_seqlens.device
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int64)

    positions = torch.arange(max_seqlen, dtype=torch.int64, device=device)
    batch_offsets = torch.arange(bs, dtype=torch.int64, device=device) * max_seqlen
    # [bs, max_seqlen] grid of flattened positions, and which of them are real tokens
    flat_positions = batch_offsets.unsqueeze(1) + positions.unsqueeze(0)
    valid = positions.unsqueeze(0) < seqlens.unsqueeze(1)
    # boolean indexing walks the grid in row-major order, i.e. ascending positions
    indices = flat_positions[valid].view(-1, 1, 1)

    num_nonzeros = indices.shape[0]
    pad_amount = bs * max_seqlen - num_nonzeros
    indices = F.pad(
        input=indices,
        pad=(0, 0, 0, 0, 0, pad_amount),
        mode="constant",
        value=float(bs * max_seqlen),
    )

    return indices


_cu_seqlens_cache = {}
1355
1356


1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length.

    """
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
    global _cu_seqlens_cache
    if (batch_size, max_seqlen) not in _cu_seqlens_cache:
        _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[(batch_size, max_seqlen)]


@torch.compile
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Packs the given tensor using the `indices`.

    ``tensor`` is gathered along dim 0 with ``indices``. ``indices`` is expected to
    be ``[N, 1, 1]`` (as produced by ``get_indices``) and is repeated over dims 1
    and 2. First, an all-zero row is appended at position ``tensor.shape[0]``, so
    padding entries equal to ``tensor.shape[0]`` gather zeros.

    For a ``Float8Tensor``, the gather runs on the raw ``_data`` buffer and the
    result is rewrapped with ``Float8Tensor.make_like``. The zero row is created
    from the buffer that is actually concatenated. For a Float8Tensor,
    ``tensor.dtype`` need not equal ``tensor._data.dtype``, and ``torch.cat`` would
    otherwise type-promote the 8-bit payload.
    """
    is_fp8 = isinstance(tensor, Float8Tensor)
    data = tensor._data if is_fp8 else tensor
    padding_indice = torch.zeros(
        1, data.shape[1], data.shape[2], dtype=data.dtype, device=data.device
    )
    indices = indices.repeat(1, data.shape[1], data.shape[2])
    packed = torch.gather(torch.cat((data, padding_indice), dim=0), 0, indices)
    if is_fp8:
        packed = Float8Tensor.make_like(tensor, data=packed)
    return packed


@torch.compile
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Packs the given 2 tensors using the `indices`.

    Both tensors are gathered with the same ``indices`` via ``pack_tensor``.
    Results are returned in argument order.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


@torch.compile
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Packs the given 3 tensors using the `indices`.

    Each tensor is gathered with the same ``indices`` via ``pack_tensor``.
    Results are returned in argument order.
    """
    return (
        pack_tensor(indices, t1),
        pack_tensor(indices, t2),
        pack_tensor(indices, t3),
    )


@torch.compile
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`.

    The rows of ``tensor`` are scattered into a zero buffer of ``dim0 + 1`` rows at
    the positions given by ``indices`` (repeated over dims 1 and 2). The extra last
    row receives the padding entries (``indices == dim0``) and is dropped. Positions
    not listed in ``indices`` stay zero.

    For a ``Float8Tensor``, the scatter uses ``tensor._data`` and the result is
    rewrapped with ``Float8Tensor.make_like``. The buffer is allocated with the
    dtype of the scattered payload, because ``scatter_`` requires source and
    destination dtypes to match.
    """
    is_fp8 = isinstance(tensor, Float8Tensor)
    data = tensor._data if is_fp8 else tensor
    indices = indices.repeat(1, data.shape[1], data.shape[2])
    unpacked = torch.zeros(
        dim0 + 1, data.shape[1], data.shape[2], dtype=data.dtype, device=data.device
    )
    unpacked.scatter_(0, indices, data)
    # drop the sink row that absorbed the padding indices
    unpacked = unpacked[0:-1, :, :]
    if is_fp8:
        unpacked = Float8Tensor.make_like(tensor, data=unpacked)
    return unpacked


@torch.compile
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`.

    Each packed tensor is scattered back to ``dim0`` rows with ``unpack_tensor``,
    using the shared ``indices``. Results are returned in argument order.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


@torch.compile
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`.

    Each packed tensor is scattered back to ``dim0`` rows with ``unpack_tensor``,
    using the shared ``indices``. Results are returned in argument order.
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    Forward gathers one to three tensors along dim 0 with a shared ``indices``
    tensor, via ``pack_tensor``, ``pack_2_tensors`` or ``pack_3_tensors``. Backward
    scatters the incoming gradients back to the first-dimension size of
    ``tensors[0]`` recorded in forward. ``indices`` receives no gradient.
    """

    @staticmethod
    def forward(
        ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # pylint: disable=missing-function-docstring
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        ctx.save_for_backward(indices)
        # backward restores every gradient to this single size
        ctx.dim0 = tensors[0].shape[0]
        packer = {1: pack_tensor, 2: pack_2_tensors}.get(len(tensors), pack_3_tensors)
        return packer(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        # pylint: disable=missing-function-docstring
        (indices,) = ctx.saved_tensors
        num_grads = len(grad_outputs)
        if num_grads == 1:
            return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
        unpacker = unpack_2_tensors if num_grads == 2 else unpack_3_tensors
        # the leading None is the (absent) gradient for ``indices``
        return (None, *unpacker(indices, ctx.dim0, *grad_outputs))


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    Forward scatters ``tensor`` back to ``dim0`` rows with ``unpack_tensor``.
    Backward gathers the gradient into packed form again with ``pack_tensor``.
    Only ``tensor`` receives a gradient.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        ctx.save_for_backward(indices)
        unpacked = unpack_tensor(indices, dim0, tensor)
        return unpacked

    @staticmethod
    def backward(ctx, grad_output):
        # pylint: disable=missing-function-docstring
        saved_indices = ctx.saved_tensors[0]
        grad_tensor = pack_tensor(saved_indices, grad_output)
        # no gradients for ``indices`` and ``dim0``
        return None, None, grad_tensor


def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Sends ``send_tensor`` to rank ``send_dst`` and receives into ``recv_tensor`` from
    rank ``recv_src`` over ``cp_group``. Even ranks issue the send before the
    receive; odd ranks issue the receive before the send.

    If ``batch_p2p_comm`` is truthy, both operations are submitted together through
    ``torch.distributed.batch_isend_irecv`` and its request list is returned.
    Otherwise the individual ``isend``/``irecv`` requests are returned in the order
    they were issued.
    """
    send_first = rank % 2 == 0

    if batch_p2p_comm:
        if send_first:
            send_op = torch.distributed.P2POp(
                torch.distributed.isend, send_tensor, send_dst, cp_group
            )
            recv_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_tensor, recv_src, cp_group
            )
            ordered_ops = [send_op, recv_op]
        else:
            recv_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_tensor, recv_src, cp_group
            )
            send_op = torch.distributed.P2POp(
                torch.distributed.isend, send_tensor, send_dst, cp_group
            )
            ordered_ops = [recv_op, send_op]
        return torch.distributed.batch_isend_irecv(ordered_ops)

    # list elements are evaluated left to right, so issue order follows list order
    if send_first:
        return [
            torch.distributed.isend(send_tensor, send_dst, cp_group),
            torch.distributed.irecv(recv_tensor, recv_src, cp_group),
        ]
    return [
        torch.distributed.irecv(recv_tensor, recv_src, cp_group),
        torch.distributed.isend(send_tensor, send_dst, cp_group),
    ]


@jit_fuser
def flash_attn_fwd_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    movedim_src: int,
    movedim_dst: int,
):
    """Merge partial outputs of each step in Attention with context parallelism

    Adds ``out_per_step * exp(softmax_lse_per_step - softmax_lse)`` into ``out`` in
    place. The rescaling factor is moved from dim ``movedim_src`` to ``movedim_dst``
    and given a trailing singleton dim, so that it broadcasts against
    ``out_per_step``.
    """
    rescale = torch.exp(softmax_lse_per_step - softmax_lse)
    rescale = rescale.movedim(movedim_src, movedim_dst).unsqueeze(-1)
    out.add_(out_per_step * rescale)


@jit_fuser
def flash_attn_fwd_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge softmax stats of each step in Attention with context parallelism

    Overwrites ``softmax_lse`` in place with
    ``log(exp(softmax_lse) + exp(softmax_lse_per_step))``.

    ``torch.logaddexp`` computes the same quantity as the hand-written
    ``max + log(1 + exp(min - max))``, with two differences:

    * It does not produce NaN when both inputs are the same infinity. In the
      hand-written form, ``-inf - (-inf)`` is NaN.
    * It uses a ``log1p``-based formulation, which is accurate when
      ``exp(min - max)`` is tiny.
    """
    softmax_lse.copy_(torch.logaddexp(softmax_lse, softmax_lse_per_step))


@jit_fuser
def get_cu_seqlens_on_cp_rank(
    cu_seqlens: torch.Tensor,
    cu_seqlens_padded_on_cp_rank: torch.Tensor,
    cp_size: int,
    cp_rank: int,
    first_half: bool,
    second_half: bool,
):
    """Compute cu_seqlens of a context parallelism rank

    Half of each per-sequence difference of ``cu_seqlens_padded_on_cp_rank`` is taken
    as the chunk size. The first half on this rank is chunk ``cp_rank`` and the second
    half is chunk ``2 * cp_size - cp_rank - 1``.

    For each selected half, the number of real tokens in that chunk is
    ``clamp(seqlen - chunk_id * chunk_size, 0, chunk_size)``. The counts are summed
    per sequence and returned as cumulative offsets starting at 0, with the same
    shape as ``cu_seqlens``.
    """
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    chunk_sizes = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
    lower_bound = torch.zeros_like(seqlens)
    result = torch.zeros_like(cu_seqlens)
    if first_half:
        tokens_in_first = (seqlens - cp_rank * chunk_sizes).clamp(lower_bound, chunk_sizes)
        result[1:].add_(tokens_in_first)
    if second_half:
        second_chunk_id = 2 * cp_size - cp_rank - 1
        tokens_in_second = (seqlens - second_chunk_id * chunk_sizes).clamp(
            lower_bound, chunk_sizes
        )
        result[1:].add_(tokens_in_second)
    result.cumsum_(dim=0)
    return result


@torch.compile
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
    """
    Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
    To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks
    before or after CP communications (e.g., all-gather, all-to-all). This function is to compute
    sequence chunk ids for reordering.

    Returns an int32 tensor of length ``2 * cp_size`` on ``device``:

    * ``to_contiguous=True``: entry ``r`` is ``2 * r`` and entry ``cp_size + r`` is
      ``2 * cp_size - 2 * r - 1``.
    * ``to_contiguous=False``: entries ``2 * r`` and ``2 * r + 1`` are ``r`` and
      ``2 * cp_size - r - 1``.
    """
    ranks = torch.arange(cp_size, dtype=torch.int32, device=device)
    if to_contiguous:
        return torch.cat((2 * ranks, 2 * cp_size - 2 * ranks - 1))
    # interleave each rank's pair of chunk ids
    return torch.stack((ranks, 2 * cp_size - ranks - 1), dim=1).reshape(-1)


@torch.compile
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
    """Reorder sequence chunk for A2A communication.

    ``before_attn`` truthy: the leading CP dimension of ``x`` is moved to ``seq_dim``.
    The merged sequence axis is then split into ``2 * cp_size`` chunks, and the
    chunks are reordered along ``seq_dim`` by ``chunk_ids_for_a2a``.

    ``before_attn`` falsy: the ``2 * cp_size`` chunk axis at ``seq_dim`` is moved to
    the front, reordered by ``chunk_ids_for_a2a``, and split into ``[cp_size, 2]``.
    """
    if not before_attn:
        # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
        # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
        chunks_first = x.movedim(seq_dim, 0).contiguous()
        # reorder the sequence chunks
        reordered = torch.index_select(chunks_first, dim=0, index=chunk_ids_for_a2a)
        # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn]
        # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
        return reordered.view(cp_size, 2, *reordered.shape[1:])

    # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
    # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
    seq_major = x.movedim(0, seq_dim).contiguous()
    # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
    # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
    lead_dims = seq_major.shape[:seq_dim]
    tail_dims = seq_major.shape[(seq_dim + 2) :]
    chunked = seq_major.view(*lead_dims, cp_size * 2, -1, *tail_dims)
    # reorder the sequence chunks
    return torch.index_select(chunked, dim=seq_dim, index=chunk_ids_for_a2a)


def flash_attn_a2a_communicate(
    a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
    chunk_ids_for_a2a: torch.Tensor,
    seq_dim: int,
    cp_size: int,
    cp_group: dist_group_type,
    cp_stream: torch.cuda.Stream,
    before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """A2A communication for context parallelism.

    Exchanges attention heads for sequence chunks with ``all_to_all_single`` over
    ``cp_group``.

    * ``before_attn=True``: each input ``[b, s, np, hn]`` (or ``[s, b, np, hn]``) is
      split into ``cp_size`` head groups and exchanged. The sequence chunks are then
      reordered with ``chunk_ids_for_a2a``, producing ``[b, cp*s, np//cp, hn]``
      (or ``[cp*s, b, np//cp, hn]``).
    * ``before_attn=False``: the inverse. The chunks are reordered and exchanged,
      and the heads are merged back, producing ``[b*s, np, hn]`` (or ``[s*b, np, hn]``).

    The inputs are processed as a software pipeline. Iteration ``i`` does three
    things:

    * launches the async A2A for input ``i - 1``;
    * prepares input ``i`` on the current stream;
    * waits for and post-processes output ``i - 2`` under ``cp_stream``.

    At the end, the current stream waits on ``cp_stream``.

    Returns a single tensor when given a single tensor (or a one-element list),
    otherwise a list of outputs. The caller's list is not modified.
    """
    # Work on a private copy: the loops below overwrite entries with reshaped
    # tensors, which would otherwise leak into the caller's list.
    a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else list(a2a_inputs)
    a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
    if before_attn:
        for i in range(len(a2a_inputs) + 2):
            # launch the async A2A of the input prepared in the previous iteration
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            # complete the A2A launched one iteration earlier, on cp_stream
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # reorder the sequence chunks
                    x = reorder_seq_chunks_for_a2a(
                        x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
                    )
                    # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                    # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
                    a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
            # prepare the next input on the current stream
            if i < len(a2a_inputs):
                x = a2a_inputs[i]
                # [b, s, np, hn] -> [b, s, cp, np//cp, hn]
                # or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
                x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
                # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn]
                # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
                a2a_inputs[i] = x.movedim(-3, 0).contiguous()
    else:
        for i in range(len(a2a_inputs) + 2):
            # launch the async A2A of the input prepared in the previous iteration
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            # prepare the next input on the current stream
            if i < len(a2a_inputs):
                x = a2a_inputs[i]
                # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
                # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
                x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
                # reorder the sequence chunks
                a2a_inputs[i] = reorder_seq_chunks_for_a2a(
                    x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
                )
            # complete the A2A launched one iteration earlier, on cp_stream
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn]
                    # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
                    x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
                    # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn]
                    # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
                    a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
    # make later work on the current stream see the cp_stream post-processing
    torch.cuda.current_stream().wait_stream(cp_stream)
    return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs


class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
1750
    """
1751
1752
1753
    Attention implementation with context parallelism. Exchange KV between CP ranks
    with P2P in ring topology. Split attention compute into multiple steps, and overlap
    current-step compute with next-step communication.
1754
1755
1756
1757
1758

    This implementation also supports hierarchical CP, which parallelizes attention
    heads in low-level CP groups and parallelizes sequence dimension in high-level CP
    groups. For more details, please refer to `LongVILA <https://arxiv.org/abs/2408.10188>`_
    and `USP <https://arxiv.org/abs/2405.07719>`_.
1759
1760
1761
    """

    @staticmethod
1762
1763
1764
1765
1766
1767
1768
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
1769
        cu_seqlens_kv,
1770
        max_seqlen_q,
1771
        max_seqlen_kv,
1772
1773
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
1774
1775
1776
1777
1778
1779
1780
1781
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
1782
1783
        fp8,
        fp8_meta,
1784
1785
1786
        cp_group,
        cp_global_ranks,
        cp_stream,
1787
    ):
1788
        # pylint: disable=missing-function-docstring
1789
1790
1791
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
        if isinstance(cp_group, list):
            assert (
                qkv_format != "thd"
            ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!"
            assert attn_bias_type == "no_bias", (
                f"{attn_bias_type} bias type is not supported with hierarchical CP implementation"
                " yet!"
            )
            cp_group_a2a = cp_group[0]
            cp_size_a2a = get_distributed_world_size(cp_group_a2a)
            rank_a2a = get_distributed_rank(cp_group_a2a)
            cp_group = cp_group[1]
        else:
            cp_group_a2a = None
            cp_size_a2a = 1
            rank_a2a = 0

1809
1810
        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
1811
1812
        send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
1813
1814
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1815
1816
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
1817

1818
        seq_dim = None
1819
        if qkv_format in ["bshd", "sbhd"]:
1820
            seq_dim = qkv_format.index("s")
1821
1822
1823
1824
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        else:
            qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

1825
1826
1827
1828
1829
1830
        pad_between_seqs_q = cu_seqlens_q_padded is not None and not torch.equal(
            cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]
        )
        pad_between_seqs_kv = cu_seqlens_kv_padded is not None and not torch.equal(
            cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]
        )
1831
1832
        max_seqlen_q = max_seqlen_q // cp_size
        max_seqlen_kv = max_seqlen_kv // cp_size
1833
1834
1835
1836
1837
1838
        cu_seqlens_q_padded = (
            None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // cp_size
        )
        cu_seqlens_kv_padded = (
            None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // cp_size
        )
1839
1840
        cu_seqlens_q_per_step = [None for _ in range(cp_size)]
        cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
1841

1842
1843
1844
        fused_attn_qkv_dtype = None
        fused_attn_backend = None
        amax_per_step = None
1845
1846
1847
1848
        qkv_dtype = q.dtype
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
        is_input_fp8 = False
        is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
1849
1850
1851
1852
1853
        if fp8:
            if use_fused_attention:
                fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                fused_attn_qkv_dtype = fp8_dtype_forward
                fused_attn_backend = FusedAttnBackend["FP8"]
1854
1855
1856
1857
1858
                assert isinstance(k, q.__class__) and isinstance(
                    v, q.__class__
                ), "q, k, and v must have the same type."
                is_input_fp8 = isinstance(q, Float8Tensor)
                if is_input_fp8:
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
                    fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                else:
                    q_f16, k_f16, v_f16 = q, k, v
                    if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                        q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                    if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                        k, v = [
                            cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                            for x in [k_f16, v_f16]
                        ]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
                fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            q_f16 = q
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True)
            q, k, v = flash_attn_a2a_communicate(
                [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
            )
            if not fp8:
                q_f16 = q
1897
            elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
1898
1899
1900
                q_f16 = q
                q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)

1901
1902
1903
        assert qkv_format == "thd" or (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"
1904
        if causal:
1905
1906
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
1907
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
1908
1909
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
1910
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
1911
        if attn_bias is not None:
1912
            assert len(attn_bias.shape) == 4, (
1913
1914
1915
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
1916
1917
1918
            assert (
                attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0
            ), "Sequence length does not meet divisible requirements!"
1919
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
1920
1921
1922
1923
1924
1925
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
1926
1927
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
1928
1929
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
1930
            )
1931
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
1932

1933
1934
1935
1936
1937
1938
1939
        softmax_lse_in_packed_format = False
        if qkv_format == "thd":
            if use_fused_attention:
                softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
            else:
                softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3

1940
        flash_attn_fwd = None
1941
1942
1943
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
1944
1945
1946
1947
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
1948
1949
                fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
            else:
1950
1951
1952
1953
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
1954
1955
1956
1957
1958
1959
1960
1961
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
                if _flash_attn_2_3_plus:
                    fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_5_7_plus:
                    fa_forward_kwargs["block_table"] = None
1962

1963
1964
1965
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
1966
        attn_bias_inputs = [None, None]
1967
1968
1969
1970
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
1971
        attn_biases = [None for _ in range(cp_size)]
1972
1973
1974
1975
1976
1977
1978

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
1979
        if qkv_format in ["bshd", "sbhd"]:
1980
1981
1982
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
        else:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
1983
1984
        send_recv_reqs = [[], []]

1985
1986
        softmax_lse_ = None
        out = None
1987
        for i in range(cp_size + 1):
1988
            if i < cp_size:
1989
                with torch.cuda.stream(flash_attn_streams[i % 2]):
1990
                    # wait until KV is received
1991
                    for req in send_recv_reqs[(i + 1) % 2]:
1992
1993
                        req.wait()

1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

2006
                    if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
                        kv_inputs[i % 2] = p2p_comm_buffers[i]
                    else:
                        # KV exchange is in BF16/FP16, cast received KV in each step
                        kv_inputs[i % 2] = cast_to_fp8(
                            p2p_comm_buffers[i],
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                        )
                    if fp8 and use_fused_attention:
2017
2018
2019
2020
                        fp8_meta_kwargs["amax_s"] = amax_per_step
                        fp8_meta_kwargs["amax_s_offset"] = i
                        fp8_meta_kwargs["amax_o"] = amax_per_step
                        fp8_meta_kwargs["amax_o_offset"] = cp_size + i
2021
2022
                    if causal:
                        if i == 0:
2023
2024
2025
2026
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
2027
                            elif use_fused_attention or qkv_format == "thd":
2028
2029
2030
2031
2032
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
                                )
2033
                            elif use_fused_attention or qkv_format == "thd":
2034
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    k.shape[0], -1, 2, *k.shape[-2:]
                                )
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                                # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    -1, k.shape[2], 2, *k.shape[-2:]
                                )
                            elif qkv_format == "thd":
                                q_inputs[i % 2] = q
2051
                            if use_fused_attention:
2052
2053
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2054
2055
2056
2057
2058
2059
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
2060
                                    ).contiguous()
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
2089
                                )
2090
2091
2092
2093
2094
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2095
                            else:
2096
2097
2098
2099
2100
2101
2102
2103
                                fa_forward_args_thd = []
                                if qkv_format == "thd":
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q,
                                        max_seqlen_kv,
                                    ]
2104
                                fa_outputs = flash_attn_fwd(
2105
                                    q_inputs[i % 2],
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
2117
                                    causal=True,
2118
                                    **fa_forward_kwargs,
2119
                                )
2120
2121
2122
2123
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
                                if not _use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
2124
                        elif i <= rank:
2125
2126
2127
2128
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
2129
                            elif use_fused_attention or qkv_format == "thd":
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    False,
                                )
2140
                            elif use_fused_attention or qkv_format == "thd":
2141
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...]
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                                # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2][0]
                            elif qkv_format == "thd":
                                q_inputs[i % 2] = q
                                # [2, t, np, hn] -> [2, t/2, np, hn]
                                kv_inputs[i % 2] = tex.thd_read_half_tensor(
                                    kv_inputs[i % 2], cu_seqlens_kv_padded, 0
                                )
2158
                            if use_fused_attention:
2159
                                kv_inputs[i % 2] = kv_inputs[i % 2].contiguous()
2160
2161
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2162
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv // 2,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=(
                                        None
                                        if cu_seqlens_kv_padded is None
                                        else cu_seqlens_kv_padded // 2
                                    ),
                                    **fp8_meta_kwargs,
2195
                                )
2196
2197
2198
2199
2200
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2201
                            else:
2202
                                fa_forward_args_thd = []
2203
                                if qkv_format == "thd":
2204
2205
2206
2207
2208
2209
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q,
                                        max_seqlen_kv // 2,
                                    ]
2210
2211
2212
                                if _use_flash_attn_3 or _flash_attn_2_3_plus:
                                    fa_forward_kwargs["window_size"] = (-1, -1)
                                fa_outputs = flash_attn_fwd(
2213
                                    q_inputs[i % 2],
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
2225
                                    causal=False,
2226
                                    **fa_forward_kwargs,
2227
                                )
2228
2229
2230
2231
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
                                if not _use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
2232
                        else:
2233
2234
2235
2236
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
                                )
2237
                            elif use_fused_attention or qkv_format == "thd":
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    True,
                                )
2248
                            elif use_fused_attention or qkv_format == "thd":
2249
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                                q_inputs[i % 2] = q[:, 1, ...]
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    k.shape[0], -1, 2, *k.shape[-2:]
                                )
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                                q_inputs[i % 2] = q[1]
                                # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    -1, k.shape[2], 2, *k.shape[-2:]
                                )
                            elif qkv_format == "thd":
                                # [t, np, hn] -> [t/2, np, hn]
                                q_inputs[i % 2] = tex.thd_read_half_tensor(
                                    q, cu_seqlens_q_padded, 1
                                )
2269
                            if use_fused_attention:
2270
                                q_inputs[i % 2] = q_inputs[i % 2].contiguous()
2271
2272
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2273
2274
2275
2276
2277
2278
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
2279
                                    ).contiguous()
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q // 2,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=(
                                        None
                                        if cu_seqlens_q_padded is None
                                        else cu_seqlens_q_padded // 2
                                    ),
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
2312
                                )
2313
2314
2315
2316
2317
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2318
                            else:
2319
                                fa_forward_args_thd = []
2320
                                if qkv_format == "thd":
2321
2322
2323
2324
2325
2326
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q // 2,
                                        max_seqlen_kv,
                                    ]
2327
2328
2329
                                if _use_flash_attn_3 or _flash_attn_2_3_plus:
                                    fa_forward_kwargs["window_size"] = (-1, -1)
                                fa_outputs = flash_attn_fwd(
2330
                                    q_inputs[i % 2],
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
2342
                                    causal=False,
2343
                                    **fa_forward_kwargs,
2344
                                )
2345
2346
2347
2348
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
                                if not _use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
2349
                    else:
2350
2351
2352
2353
                        if pad_between_seqs_q:
                            cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                            )
2354
                        elif use_fused_attention or qkv_format == "thd":
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
                            cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                        if pad_between_seqs_kv:
                            cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_kv,
                                cu_seqlens_kv_padded,
                                cp_size,
                                (rank - i) % cp_size,
                                True,
                                True,
                            )
2365
                        elif use_fused_attention or qkv_format == "thd":
2366
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2367
                        if use_fused_attention:
2368
2369
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
2370
2371
2372
2373
2374
2375
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
2376
                                ).contiguous()
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
                            out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                is_training,
                                max_seqlen_q,
                                max_seqlen_kv,
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
                                q,
                                (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                ),
                                (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                ),
                                fused_attn_qkv_dtype,
                                fused_attn_backend,
                                attn_scale=softmax_scale,
                                dropout=dropout_p,
                                qkv_layout=qkv_layout,
                                attn_mask_type=attn_mask_type,
                                attn_bias_type=attn_bias_type,
                                attn_bias=attn_bias_inputs[i % 2],
                                cu_seqlens_q_padded=cu_seqlens_q_padded,
                                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                **fp8_meta_kwargs,
2405
                            )
2406
2407
2408
2409
2410
                            if fp8:
                                softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                            else:
                                softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                attn_biases[i] = rest[0] if len(rest) > 0 else None
2411
                        else:
2412
2413
2414
2415
2416
2417
2418
2419
                            fa_forward_args_thd = []
                            if qkv_format == "thd":
                                fa_forward_args_thd = [
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                ]
2420
                            fa_outputs = flash_attn_fwd(
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
                                q,
                                (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                ),
                                (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                ),
                                *fa_forward_args_thd,
2433
                                causal=False,
2434
                                **fa_forward_kwargs,
2435
                            )
2436
2437
2438
2439
                            out_per_step[i] = fa_outputs[4]
                            softmax_lse_per_step[i] = fa_outputs[5]
                            if not _use_flash_attn_3:
                                rng_states[i] = fa_outputs[7]
2440
2441
2442
2443

            if i > 0:
                # wait until fwd restuls correction of last step is done
                if i > 1:
2444
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
2445

2446
                if use_fused_attention:
2447
2448
                    # [b, np, sq, 1] -> [b, np, sq] or
                    # [t, np, 1] -> [t, np]
2449
                    softmax_lse_per_step[i - 1].squeeze_(-1)
2450
2451
2452
2453
                    if softmax_lse_in_packed_format:
                        softmax_lse_per_step[i - 1] = (
                            softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
                        )
2454

2455
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
2456
2457
2458
2459
2460
2461
2462
2463
                    if fp8:
                        out_per_step[i - 1] = cast_from_fp8(
                            out_per_step[i - 1],
                            fp8_meta["scaling_fwd"],
                            META_O_CP,
                            fp8_dtype_forward,
                            TE_DType[torch.float32],
                        )
2464
                    if i == 1:
2465
                        out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
2466
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
2467
                        if causal and qkv_format != "thd":
2468
                            # [b, np, sq] -> [b, np, 2, sq//2]
2469
                            softmax_lse_ = softmax_lse.view(
2470
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
2471
                            )
2472
2473
2474
2475
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
2476
                    else:
2477
                        if qkv_format == "thd":
2478
                            tex.thd_second_half_lse_correction(
2479
2480
2481
                                softmax_lse,
                                softmax_lse_per_step[i - 1],
                                cu_seqlens_q_padded,
2482
                                softmax_lse_in_packed_format,
2483
                            )
2484
                        else:
2485
2486
2487
                            flash_attn_fwd_softmax_lse_correction(
                                softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1]
                            )
2488
2489

                if i < cp_size:
2490
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
2491
2492
2493

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

2494
2495
2496
2497
        second_half_lse_seqlen = None
        if causal and rank < (cp_size - 1):
            second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1]

2498
2499
        softmax_lse = softmax_lse.to(torch.float)
        for i in range(cp_size):
2500
            if i <= rank or not causal:
2501
                if qkv_format in ["bshd", "sbhd"]:
2502
2503
2504
2505
2506
                    flash_attn_fwd_out_correction(
                        out.view(*out_per_step[i].shape),
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2507
2508
                        0 if softmax_lse_in_packed_format else 2,
                        2 if softmax_lse_in_packed_format else seq_dim,
2509
                    )
2510
                elif qkv_format == "thd":
2511
2512
2513
2514
2515
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2516
                        cu_seqlens_q_padded,
2517
                        False,
2518
                        softmax_lse_in_packed_format,
2519
                    )
2520
            else:
2521
                if qkv_format in ["bshd", "sbhd"]:
2522
                    out_ = out.select(seq_dim, 1)
2523
2524
2525
2526
2527
                    flash_attn_fwd_out_correction(
                        out_,
                        out_per_step[i],
                        softmax_lse_[..., 1, :],
                        softmax_lse_per_step[i],
2528
2529
                        0 if softmax_lse_in_packed_format else 2,
                        2 if softmax_lse_in_packed_format else seq_dim,
2530
                    )
2531
                elif qkv_format == "thd":
2532
2533
2534
2535
2536
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2537
                        cu_seqlens_q_padded,
2538
                        True,
2539
                        softmax_lse_in_packed_format,
2540
                    )
2541
2542

        kv = p2p_comm_buffers[-1]
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
        if qkv_format == "bshd":
            out = out.view(out.shape[0], -1, *out.shape[-2:])
            ctx.batch_size = out.shape[0]
        elif qkv_format == "sbhd":
            out = out.view(-1, *out.shape[-3:])
            ctx.batch_size = out.shape[1]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
            out = flash_attn_a2a_communicate(
                out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
            )
            if use_fused_attention:
                if qkv_format == "bshd":
                    # [b*s, np, hn] -> [b, s, np, hn]
                    out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                elif qkv_format == "sbhd":
                    # [s*b, np, hn] -> [s, b, np, hn]
                    out = out.view(-1, ctx.batch_size, *out.shape[-2:])
        elif not use_fused_attention:
2563
            out = out.view(-1, *out.shape[-2:])
2564

2565
2566
2567
2568
2569
        if fp8 and use_fused_attention:
            amax_cp_fwd = amax_per_step.amax(dim=1)
            fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0]
            fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]

2570
        out_fp8 = None
2571
2572
        out_f16 = out.to(qkv_dtype)
        if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
2573
2574
            out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)

2575
        if fp8 and is_output_fp8:
2576
2577
2578
2579
2580
2581
            out_ret = Float8Tensor(
                data=out_fp8,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_O,
                fp8_dtype=fp8_dtype_forward,
2582
                dtype=qkv_dtype,
2583
2584
2585
2586
2587
2588
2589
2590
            )
        else:
            out_ret = out_f16

        if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
            q_save, kv_save, out_save = q, kv, out_fp8
            fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
            fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
2591
        elif fp8 and is_input_fp8:
2592
2593
2594
2595
2596
2597
2598
2599
            q_fp8 = Float8Tensor(
                data=q,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_QKV,
                fp8_dtype=fp8_dtype_forward,
                dtype=q_fp8.dtype,
            )
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
            kv_fp8 = Float8Tensor(
                data=kv,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_QKV,
                fp8_dtype=fp8_dtype_forward,
                dtype=k_fp8.dtype,
            )
            q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None
        else:
2611
            q_f16 = q_f16.view(q.shape)
2612
2613
2614
            q_save, kv_save, out_save = q_f16, kv, out_f16
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None

2615
        ctx.save_for_backward(
2616
2617
2618
            q_save,
            kv_save,
            out_save,
2619
            softmax_lse,
2620
2621
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
2622
2623
            fp8_fwd_scales,
            fp8_fwd_scale_invs,
2624
2625
            *cu_seqlens_q_per_step,
            *cu_seqlens_kv_per_step,
2626
2627
            *rng_states,
            *attn_biases,
2628
        )
2629
2630
2631
        ctx.cp_group_a2a = cp_group_a2a
        ctx.cp_size_a2a = cp_size_a2a
        ctx.rank_a2a = rank_a2a
2632
2633
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
2634
        ctx.cp_stream = cp_stream
2635
2636
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
2637
        ctx.max_seqlen_kv = max_seqlen_kv
2638
        ctx.softmax_scale = softmax_scale
2639
        ctx.qkv_format = qkv_format
2640
        ctx.attn_mask_type = attn_mask_type
2641
2642
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
2643
        ctx.deterministic = deterministic
2644
        ctx.use_fused_attention = use_fused_attention
2645
        ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format
2646
        ctx.second_half_lse_seqlen = second_half_lse_seqlen
2647
2648
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
2649
2650
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
2651
        return out_ret
2652
2653
2654

    @staticmethod
    def backward(ctx, dout):
2655
        # pylint: disable=missing-function-docstring
2656
2657
2658
        cp_size_a2a = ctx.cp_size_a2a
        rank_a2a = ctx.rank_a2a

2659
2660
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
2661
2662
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
2663
2664
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

2665
2666
2667
2668
2669
2670
2671
        (*saved_tensors,) = ctx.saved_tensors
        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6]
        (fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8]
        cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size]
        cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2]
        rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
        attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
2672

2673
2674
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
2675
2676

        seq_dim = None
2677
        if ctx.qkv_format in ["bshd", "sbhd"]:
2678
            seq_dim = ctx.qkv_format.index("s")
2679
2680
2681
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
        else:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
2682

2683
        if attn_biases[0] is not None:
2684
2685
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
2686
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
2687
2688
2689
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
2690
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
2691
2692
2693
            )
        else:
            attn_dbias = None
2694
            attn_dbias_ = None
2695

2696
2697
        softmax_lse_ = None
        if causal and ctx.second_half_lse_seqlen is not None:
2698
            if ctx.qkv_format == "thd":
2699
                softmax_lse_ = tex.thd_read_second_half_lse(
2700
2701
2702
2703
                    softmax_lse,
                    cu_seqlens_q_padded,
                    ctx.softmax_lse_in_packed_format,
                    ctx.second_half_lse_seqlen,
2704
                )
2705
2706
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
2707
2708
2709
                softmax_lse_ = softmax_lse.view(
                    *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
                )
2710
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
2711
2712
2713
2714
2715
2716
            if ctx.use_fused_attention:
                if ctx.softmax_lse_in_packed_format:
                    softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous()
                # [b, np, sq//2] -> [b, np, sq//2, 1] or
                # [t//2, np] -> [t//2, np, 1]
                softmax_lse_.unsqueeze_(-1)
2717
        if ctx.use_fused_attention:
2718
2719
2720
2721
            if ctx.softmax_lse_in_packed_format:
                softmax_lse = softmax_lse.transpose(0, 1).contiguous()
            # [b, np, sq] -> [b, np, sq, 1] or
            # [t, np] -> [t, np, 1]
2722
            softmax_lse.unsqueeze_(-1)
2723

2724
        dq = None
2725
        dout_dtype = dout.dtype
2726
2727
2728
2729
2730
        fused_attn_backend = None
        fused_attn_qkv_dtype = None
        fused_attn_dqkv_dtype = None
        amax_per_step = None
        dout_fp8_dtype = None
2731
2732
        if ctx.fp8:
            if ctx.use_fused_attention:
2733
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
2734
                fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
2735
                fused_attn_qkv_dtype = fp8_dtype_forward
2736
2737
2738
2739
2740
                fused_attn_dqkv_dtype = fp8_dtype_backward
                fused_attn_backend = FusedAttnBackend["FP8"]
                dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
                dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
                dkv_fp8_ = torch.empty_like(dkv_fp8)
2741
                if ctx.is_output_fp8:
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
                    dout = dout._data
                else:
                    dout = cast_to_fp8(
                        dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                    )
                p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
                fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
                fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
                fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
                fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
                fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
                fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
                fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP]
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
2763
            if ctx.fp8_meta is not None and ctx.is_input_fp8:
2764
2765
2766
2767
2768
2769
2770
                q, kv = [x.from_float8(x.dtype) for x in [q, kv]]
                if cp_size_a2a == 1:
                    dout = dout.from_float8(dout_dtype)
                else:
                    dout_fp8_dtype = dout._fp8_dtype
                    dout_scale_inv = dout._scale_inv
                    dout = dout._data
2771
2772
2773
2774
2775
2776
2777
2778
2779
            dq = torch.empty_like(q)
            p2p_comm_buffers = [
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            ]
            p2p_comm_buffers[0][0].copy_(kv)
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
2780
                fused_attn_dqkv_dtype = TE_DType[dout_dtype]
2781
2782
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
        if cp_size_a2a > 1:
            if not ctx.use_fused_attention:
                out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                dout = dout.view(*out.shape)
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True)
            out, dout = flash_attn_a2a_communicate(
                [out, dout],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                True,
            )
2797
            if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
2798
                dout = cast_from_fp8(
2799
2800
2801
2802
2803
2804
                    dout,
                    None,
                    None,
                    dout_fp8_dtype,
                    TE_DType[dout_dtype],
                    scale_inv=dout_scale_inv,  # pylint: disable=used-before-assignment
2805
2806
                )

2807
2808
2809
2810
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        send_recv_reqs = []

2811
        flash_attn_bwd = None
2812
2813
2814
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
2815
2816
2817
2818
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
2819
2820
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
2821
2822
2823
2824
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
2825
2826
2827
2828
2829
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
2830

2831
2832
2833
2834
2835
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

2836
2837
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
2838
2839
2840
2841
2842
2843
2844
2845
2846
2847
2848
2849
2850
2851
2852
2853
2854
2855
2856
2857
2858
2859
2860
2861
2862
2863
2864
2865
2866
            if ctx.fp8:
                if i < cp_size - 1:
                    send_recv_reqs = flash_attn_p2p_communicate(
                        rank,
                        send_tensor[0],
                        send_dst,
                        recv_tensor[0],
                        recv_src,
                        ctx.cp_group,
                        batch_p2p_comm,
                    )
                else:
                    dkv_a2a_req = torch.distributed.all_to_all_single(
                        dkv_fp8,
                        dkv_fp8_,
                        group=ctx.cp_group,
                        async_op=True,
                    )
                    send_recv_reqs = [dkv_a2a_req]
            else:
                if i == 0:
                    send_tensor = send_tensor[0]
                    recv_tensor = recv_tensor[0]
                if i == (cp_size - 1):
                    send_tensor = send_tensor[1]
                    recv_tensor = recv_tensor[1]
                send_recv_reqs = flash_attn_p2p_communicate(
                    rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
                )
2867

2868
            kv = p2p_comm_buffers[i % 2][0]
2869
2870
            q_, kv_, out_, dout_ = None, None, None, None
            dq_, dk_, dv_ = None, None, None
2871
2872
2873
            if ctx.fp8 and ctx.use_fused_attention:
                fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i]
                fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i]
2874
            # In reversed order of fwd
2875
            if causal:
2876
                if i == (cp_size - 1):
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_, out_, dout_ = [
                            x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
                        ]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                        kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                        q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
                        # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                        kv_ = kv.view(-1, *kv.shape[-4:])
                    elif ctx.qkv_format == "thd":
                        q_, kv_, out_, dout_ = q, kv, out, dout
2891
                    if ctx.use_fused_attention:
2892
2893
2894
2895
2896
2897
2898
2899
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2900
                        if attn_dbias is not None:
2901
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2902
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2903
                            ctx.max_seqlen_q,
2904
2905
2906
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2907
                            q_,
2908
2909
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2910
2911
                            out_,
                            dout_,
2912
2913
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2914
                            aux_ctx_tensors,
2915
                            fused_attn_backend,
2916
2917
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2918
2919
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2920
                            qkv_layout=qkv_layout,
2921
                            attn_mask_type=ctx.attn_mask_type,
2922
                            attn_bias_type=ctx.attn_bias_type,
2923
2924
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2925
2926
                        )
                    else:
2927
                        dq_ = torch.empty_like(q_)
2928
                        dkv_ = torch.empty_like(kv_)
2929
2930
2931
2932
2933
2934
2935
2936
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q,
                                ctx.max_seqlen_kv,
                            ]
2937
2938
2939
2940
2941
                        if _use_flash_attn_3 or _flash_attn_2_3_plus:
                            fa_backward_kwargs["window_size"] = (-1, 0)
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
2942
2943
                            dout_,
                            q_,
2944
2945
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2946
2947
2948
                            out_,
                            softmax_lse,
                            dq_,
2949
2950
2951
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
2952
2953
                            causal=True,
                            **fa_backward_kwargs,
2954
                        )
2955
                elif i >= (cp_size - rank - 1):
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
2968
2969
2970
2971
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_, out_, dout_ = [
                            x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
                        ]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                        kv_ = kv[:, 0]
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                        q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
                        # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                        kv_ = kv[0]
                    elif ctx.qkv_format == "thd":
                        q_, out_, dout_ = q, out, dout
                        # [2, t, np, hn] -> [2, t/2, np, hn]
                        kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
2972
                    if ctx.use_fused_attention:
2973
                        kv_ = kv_.contiguous()
2974
2975
2976
2977
2978
2979
2980
2981
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2982
                        if attn_dbias is not None:
2983
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2984
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2985
                            ctx.max_seqlen_q,
2986
2987
2988
                            ctx.max_seqlen_kv // 2,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2989
                            q_,
2990
2991
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2992
2993
                            out_,
                            dout_,
2994
2995
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2996
                            aux_ctx_tensors,
2997
                            fused_attn_backend,
2998
2999
3000
3001
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=(
                                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
                            ),
3002
3003
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
3004
                            qkv_layout=qkv_layout,
3005
                            attn_mask_type="padding" if padding else "no_mask",
3006
                            attn_bias_type=ctx.attn_bias_type,
3007
3008
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
3009
3010
                        )
                    else:
3011
                        dq_ = torch.empty_like(q_)
3012
                        dkv_ = torch.empty_like(kv_)
3013
3014
3015
3016
3017
3018
3019
3020
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q,
                                ctx.max_seqlen_kv // 2,
                            ]
3021
3022
3023
3024
3025
                        if _use_flash_attn_3 or _flash_attn_2_3_plus:
                            fa_backward_kwargs["window_size"] = (-1, -1)
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
3026
3027
                            dout_,
                            q_,
3028
3029
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
3030
3031
3032
                            out_,
                            softmax_lse,
                            dq_,
3033
3034
3035
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
3036
3037
                            causal=False,
                            **fa_backward_kwargs,
3038
3039
                        )
                else:
3040
3041
3042
3043
3044
3045
3046
3047
3048
3049
3050
3051
3052
3053
3054
3055
3056
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                        kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                        q_, out_, dout_ = q[1], out[1], dout[1]
                        # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                        kv_ = kv.view(-1, *kv.shape[-4:])
                    elif ctx.qkv_format == "thd":
                        # [t, np, hn] -> [t/2, np, hn]
                        q_, out_, dout_ = [
                            tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1)
                            for x in [q, out, dout]
                        ]
                        kv_ = kv
3057
                    if ctx.use_fused_attention:
3058
                        q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]]
3059
3060
3061
3062
3063
3064
3065
3066
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse_,
                                softmax_lse_,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
3067
                        if attn_dbias is not None:
3068
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3069
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3070
                            ctx.max_seqlen_q // 2,
3071
3072
3073
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
3074
                            q_,
3075
3076
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
3077
3078
                            out_,
                            dout_,
3079
3080
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
3081
                            aux_ctx_tensors,
3082
                            fused_attn_backend,
3083
3084
3085
3086
                            cu_seqlens_q_padded=(
                                None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
                            ),
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
3087
3088
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
3089
                            qkv_layout=qkv_layout,
3090
                            attn_mask_type="padding" if padding else "no_mask",
3091
                            attn_bias_type=ctx.attn_bias_type,
3092
3093
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
3094
3095
                        )
                    else:
3096
                        dq_ = torch.empty_like(q_)
3097
                        dkv_ = torch.empty_like(kv_)
3098
                        fa_backward_args_thd = []
3099
                        if ctx.qkv_format == "thd":
3100
3101
3102
3103
3104
3105
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q // 2,
                                ctx.max_seqlen_kv,
                            ]
3106
3107
3108
3109
3110
                        if _use_flash_attn_3 or _flash_attn_2_3_plus:
                            fa_backward_kwargs["window_size"] = (-1, -1)
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
3111
3112
                            dout_,
                            q_,
3113
3114
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
3115
3116
3117
                            out_,
                            softmax_lse_,
                            dq_,
3118
3119
3120
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
3121
3122
                            causal=False,
                            **fa_backward_kwargs,
3123
3124
3125
                        )
            else:
                if ctx.use_fused_attention:
3126
3127
3128
3129
                    if ctx.fp8:
                        aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]]
                    else:
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
3130
                    if attn_dbias is not None:
3131
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3132
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3133
                        ctx.max_seqlen_q,
3134
3135
3136
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
3137
                        q,
3138
3139
                        kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
3140
3141
                        out,
                        dout,
3142
3143
                        fused_attn_qkv_dtype,
                        fused_attn_dqkv_dtype,
3144
                        aux_ctx_tensors,
3145
                        fused_attn_backend,
3146
3147
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
3148
3149
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
3150
                        qkv_layout=qkv_layout,
3151
                        attn_mask_type=ctx.attn_mask_type,
3152
                        attn_bias_type=ctx.attn_bias_type,
3153
3154
                        deterministic=ctx.deterministic,
                        **fp8_meta_kwargs,
3155
3156
                    )
                else:
3157
3158
3159
3160
3161
3162
3163
3164
3165
3166
                    dq_ = torch.empty_like(q)
                    dkv_ = torch.empty_like(kv)
                    fa_backward_args_thd = []
                    if ctx.qkv_format == "thd":
                        fa_backward_args_thd = [
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_kv,
                        ]
3167
3168
3169
3170
3171
                    if _use_flash_attn_3 or _flash_attn_2_3_plus:
                        fa_backward_kwargs["window_size"] = (-1, -1)
                    if not _use_flash_attn_3:
                        fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                    flash_attn_bwd(
3172
3173
3174
3175
3176
                        dout,
                        q,
                        kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
                        out,
3177
3178
                        softmax_lse,
                        dq_,
3179
3180
3181
                        dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                        dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                        *fa_backward_args_thd,
3182
3183
                        causal=False,
                        **fa_backward_kwargs,
3184
3185
                    )

3186
3187
            if ctx.fp8:
                dq = dq_fp8[(rank + i + 1) % cp_size]
3188
3189
3190
            if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1):
                # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or
                # [sq, b, np, hn] -> [2, sq//2, b, np, hn]
3191
                dq_ = dq_.view(*dq.shape)
3192

3193
3194
3195
3196
3197
3198
3199
3200
3201
3202
3203
            if ctx.fp8:
                if i >= (cp_size - rank - 1) or not causal:
                    dq.copy_(dq_)
                else:
                    if ctx.qkv_format == "bshd":
                        dq[:, 0, ...].fill_(0)
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[0].fill_(0)
                        dq[1].copy_(dq_)
            elif causal:
3204
                if i > (cp_size - rank - 1):
3205
                    dq.add_(dq_)
3206
3207
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
3208
3209
                        dq.copy_(dq_)
                    else:
3210
3211
3212
3213
3214
3215
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
3216
                        elif ctx.qkv_format == "thd":
3217
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
3218
                elif i > 0:
3219
3220
3221
3222
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
3223
                    elif ctx.qkv_format == "thd":
3224
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
3225
                else:
3226
3227
3228
3229
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
3230
                    elif ctx.qkv_format == "thd":
3231
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
3232
3233
3234
3235
3236
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
3237

3238
            if attn_dbias is not None:
3239
                idx = (rank + i + 1) % cp_size
3240
                if i == (cp_size - 1) or not causal:
3241
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
3242
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
3243
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
3244
3245
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
3246
3247
3248
3249
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
3250
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
3251
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
3252
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
3253

3254
3255
3256
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
3257

3258
3259
3260
3261
3262
3263
3264
            if ctx.fp8:
                if i < cp_size - 1:
                    dkv = dkv_fp8_[(rank + i + 1) % cp_size]
                else:
                    dkv = dkv_fp8[(rank + i + 1) % cp_size]
            else:
                dkv = p2p_comm_buffers[(i + 1) % 2][1]
3265
            if ctx.use_fused_attention:
3266
                if ctx.qkv_format in ["bshd", "sbhd"]:
3267
3268
3269
3270
3271
3272
3273
3274
3275
3276
3277
3278
3279
3280
                    dkv_ = _combine_tensors([dk_, dv_], -2)
                elif ctx.qkv_format == "thd":
                    dkv_ = torch.cat(
                        (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
                    )  # pylint: disable=used-before-assignment
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
                # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
                dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
                dkv_ = dkv_.movedim(-3, 0)
                if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
                    # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or
                    # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(*dkv.shape)
3281

3282
3283
3284
3285
3286
3287
3288
3289
3290
3291
3292
            if ctx.fp8:
                if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
                    if ctx.qkv_format == "bshd":
                        dkv[:, :, 0, ...].copy_(dkv_)
                        dkv[:, :, 1, ...].fill_(0)
                    elif ctx.qkv_format == "sbhd":
                        dkv[:, 0, ...].copy_(dkv_)
                        dkv[:, 1, ...].fill_(0)
                else:
                    dkv.copy_(dkv_)
            elif causal:
3293
                if i == (cp_size - 1):
3294
                    if rank == 0:
3295
3296
3297
3298
3299
3300
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
3301
                        elif ctx.qkv_format == "thd":
3302
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
3303
3304
                    else:
                        dkv.add_(dkv_)
3305
3306
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
3307
3308
3309
3310
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
3311
                        elif ctx.qkv_format == "thd":
3312
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
3313
                    else:
3314
3315
3316
3317
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
3318
                        elif ctx.qkv_format == "thd":
3319
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
3320
3321
3322
3323
3324
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
3325
3326
3327
3328
3329
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
3346
3347
3348
3349
        if ctx.fp8 and ctx.use_fused_attention:
            amax_cp_bwd = amax_per_step.amax(dim=1)
            ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0]
            ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1]
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
                # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
                dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
            dq, dkv = [
                cast_from_fp8(
                    x,
                    ctx.fp8_meta["scaling_bwd"],
                    META_DQKV_CP,
                    fp8_dtype_backward,
                    TE_DType[torch.float32],
                )
                for x in [dq_fp8, dkv_fp8]
            ]
            dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]

3350
        if causal:
3351
3352
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
3353
                dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
3354
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
3355
                dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
3356
3357
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
3358
                dq = dq.view(-1, *dq.shape[-3:])
3359
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
3360
3361
                dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])

3362
3363
3364
        if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
            dq[cu_seqlens_q_padded[-1] :].fill_(0)
            dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0)
3365

3366
        if ctx.fp8 and ctx.is_input_fp8:
3367
3368
3369
3370
            dq, dkv = [
                cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward)
                for x in [dq, dkv]
            ]
3371
3372
3373
3374
3375
3376
3377
3378
3379
3380
3381
3382
3383
3384
3385
3386
3387
3388
        dk, dv = dkv[0], dkv[1]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False)
            dq, dk, dv = flash_attn_a2a_communicate(
                [dq, dk, dv],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                False,
            )
            if ctx.qkv_format == "bshd":
                dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
            elif ctx.qkv_format == "sbhd":
                dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

3389
        if ctx.fp8 and ctx.is_input_fp8:
3390
3391
3392
3393
3394
3395
3396
3397
3398
            dq, dk, dv = [
                Float8Tensor(
                    data=x,
                    fp8_meta=ctx.fp8_meta,
                    fp8_meta_forward=False,
                    fp8_meta_index=META_DQKV,
                    fp8_dtype=fp8_dtype_backward,
                    dtype=dout_dtype,
                )
3399
                for x in [dq, dk, dv]
3400
3401
            ]

3402
3403
3404
3405
        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)

3406
3407
3408
        return (
            None,
            dq,
3409
3410
            dk,
            dv,
3411
3412
3413
3414
3415
3416
3417
3418
3419
3420
3421
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
3422
            attn_dbias,
3423
3424
3425
3426
3427
            None,
            None,
            None,
            None,
            None,
3428
3429
            None,
            None,
3430
        )
3431
3432


def get_kv_seq_info_after_all_gather(
    local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal
):
    """Compute KV sequence index range and update window size after all-gather.

    The all-gathered KV sequence is treated as ``2 * cp_size`` consecutive chunks of
    ``max_seqlen_kv`` tokens each. The query chunk ``local_chunk_id`` is assumed to end
    where KV chunk ``local_chunk_id`` ends (NOTE(review): this assumes equal-length
    q/kv chunks, i.e. no cross-attention length mismatch; confirm at call sites).

    Args:
        local_chunk_id: index (in ``[0, 2 * cp_size)``) of the query chunk being processed.
        cp_size: context-parallel world size.
        max_seqlen_q: number of query tokens in the local chunk.
        max_seqlen_kv: number of KV tokens per chunk.
        window_size: ``(left, right)`` sliding-window sizes, ``-1`` meaning unbounded,
            or ``None`` to derive the default from ``causal``.
        causal: when ``window_size`` is ``None``, selects ``(-1, 0)`` if true and
            ``(-1, -1)`` otherwise.

    Returns:
        ``((seq_start_idx, seq_end_idx), (window_size_left, window_size_right))``.
        ``[seq_start_idx, seq_end_idx)`` is the slice of the all-gathered KV that the
        query chunk can attend to. The returned window sizes are re-expressed relative
        to the end of that slice: ``left`` grows and ``right`` shrinks by
        ``seq_end_idx - local_chunk_end_idx``. Presumably this is so they hold under a
        bottom-right-aligned mask. Unbounded (``-1``) sides stay ``-1``.
    """
    # exclusive end position of the local chunk inside the all-gathered sequence
    local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv
    # total length of the all-gathered KV sequence (2 chunks per CP rank)
    full_seq_end_idx = max_seqlen_kv * cp_size * 2

    if window_size is None:
        window_size = (-1, 0) if causal else (-1, -1)

    if window_size[1] == -1:
        # unbounded on the right: keep every key up to the end of the gathered sequence
        seq_end_idx = full_seq_end_idx
        window_size_right = -1
    else:
        # the last query (at local_chunk_end_idx - 1) sees keys up to +window_size[1]
        seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1])
        # shift the right window so it is measured from the end of the KV slice
        window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx

    if window_size[0] == -1:
        # unbounded on the left: keep every key from the start of the gathered sequence
        seq_start_idx = 0
        window_size_left = -1
    else:
        # the first query (at local_chunk_end_idx - max_seqlen_q) sees keys back to -window_size[0]
        seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0])
        # shift the left window by the same offset applied to the right window
        window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx

    return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right)


class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    """
    Attention implementation with context parallelism. KV all-gather between CP ranks is exposed.
    Refer to section 3.3.2 of `The Llama 3 Herd of Models <https://arxiv.org/abs/2407.21783>`_.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        cp_group,
        cp_stream,
    ):
        """Context-parallel attention forward with an exposed KV all-gather.

        Each CP rank owns two sequence chunks, ``rank`` and ``2 * cp_size - rank - 1``,
        of a sequence split into ``2 * cp_size`` chunks. K/V are all-gathered across
        ``cp_group`` and reordered with ``get_seq_chunk_ids_for_reordering``. Each local
        query half then attends to the KV slice returned by
        ``get_kv_seq_info_after_all_gather``. The two halves run on the current stream
        and on ``cp_stream`` respectively.

        Restrictions enforced below: no padding mask, ``attn_bias_type == "no_bias"``,
        head size a multiple of 8, ``qkv_format`` not ``"thd"``, and local sequence
        lengths divisible by 2. ``softmax_scale`` defaults to ``head_dim ** -0.5``.

        Returns:
            Attention output. For the fused-attention path it keeps the input layout
            (``bshd`` or ``sbhd``). For the flash-attention path it is flattened to
            ``(-1, np, hn)``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        # The per-step window sizes computed below are offsets measured from the end of
        # each sliced KV range, so fused attention uses the bottom-right causal variant.
        if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
            attn_mask_type = attn_mask_type + "_bottom_right"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            use_fused_attention or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
            else:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
                # these kwargs are only passed on the non-FA3 path
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_5_7_plus:
                    fa_forward_kwargs["block_table"] = None

        # NOTE: this rejects "thd", so the "thd" branches in this method are not reached.
        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        # Per-chunk lengths: the full sequence is split into 2 * cp_size chunks.
        max_seqlen_q = max_seqlen_q // (2 * cp_size)
        max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
        if use_fused_attention or qkv_format == "thd":
            cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
        cu_seqlens_q_padded = (
            None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // (2 * cp_size)
        )

        # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
        q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
        # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn]
        k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]]

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, cp_group)
        v_ag, _ = gather_along_first_dim(v, cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        # presumably restores sequence order of the load-balanced chunks; verify helper
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        # step 1 runs on cp_stream and reads k_ag/v_ag produced on the current stream
        cp_stream.wait_stream(torch.cuda.current_stream())

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]

        # the two sequence chunks held by this rank
        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
        kv_seq_range_per_step = [None, None]
        window_size_per_step = [None, None]
        cu_seqlens_kv_per_step = [None, None]
        out_per_step = [None, None]
        softmax_lse_per_step = [None, None]
        rng_states = [None, None]
        out = torch.empty_like(q)

        # One extra iteration: iteration i computes step i (if any) and copies the output
        # of step i-1 on the stream that produced it, so the copy is ordered after it.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    kv_seq_range_per_step[i], window_size_per_step[i] = (
                        get_kv_seq_info_after_all_gather(
                            local_seq_chunk_ids[i],
                            cp_size,
                            max_seqlen_q,
                            max_seqlen_kv,
                            window_size,
                            causal,
                        )
                    )
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i][0],
                        kv_seq_range_per_step[i][1],
                    )
                    max_seqlen_kv_ = seq_end_idx - seq_start_idx
                    if use_fused_attention or qkv_format == "thd":
                        # k is [s, b, np, hn] here, so k.shape[1] is the batch size
                        cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens(
                            k.shape[1], max_seqlen_kv_, k.device
                        )
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    if use_fused_attention:
                        out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
                            is_training,
                            max_seqlen_q,
                            max_seqlen_kv_,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            TE_DType[q.dtype],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=softmax_scale,
                            dropout=dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=attn_mask_type,
                            attn_bias_type=attn_bias_type,
                            attn_bias=attn_bias,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            window_size=window_size_per_step[i],
                        )
                    else:
                        fa_forward_args_thd = []
                        if qkv_format == "thd":
                            fa_forward_args_thd = [
                                cu_seqlens_q,
                                cu_seqlens_kv_per_step[i],
                                max_seqlen_q,
                                max_seqlen_kv_,
                            ]
                        fa_outputs = flash_attn_fwd(
                            q_,
                            k_,
                            v_,
                            *fa_forward_args_thd,
                            causal=causal,
                            window_size=window_size_per_step[i],
                            **fa_forward_kwargs,
                        )
                        out_per_step[i] = fa_outputs[4]
                        softmax_lse_per_step[i] = fa_outputs[5]
                        if not _use_flash_attn_3:
                            rng_states[i] = fa_outputs[7]

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if qkv_format == "bshd":
                        out[:, i - 1].copy_(out_per_step[i - 1])
                    elif qkv_format == "sbhd":
                        out[i - 1].copy_(out_per_step[i - 1])

        # step 1 compute and its output copy ran on cp_stream
        torch.cuda.current_stream().wait_stream(cp_stream)

        if use_fused_attention:
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
        else:
            out = out.view(-1, *out.shape[-2:])

        ctx.save_for_backward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_q_padded,
            *cu_seqlens_kv_per_step,
            *out_per_step,
            *softmax_lse_per_step,
            *rng_states,
        )
        ctx.kv_seq_range_per_step = kv_seq_range_per_step
        ctx.window_size_per_step = window_size_per_step
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
        return out

    @staticmethod
    def backward(ctx, dout):
3692
        # pylint: disable=missing-function-docstring
3693
3694
3695
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)

3696
3697
3698
3699
3700
3701
        (*saved_tensors,) = ctx.saved_tensors
        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
        cu_seqlens_kv_per_step = saved_tensors[5:7]
        out_per_step = saved_tensors[7:9]
        softmax_lse_per_step = saved_tensors[9:11]
        rng_states = saved_tensors[11:13]
3702
3703
        kv_seq_range_per_step = ctx.kv_seq_range_per_step
        window_size_per_step = ctx.window_size_per_step
3704

3705
        seq_dim = ctx.qkv_format.index("s")
3706
3707
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

3708
        dout = dout.view(q.shape)
3709
        dq = torch.empty_like(q)
3710
        dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device)
3711
3712
3713
3714
3715
3716
3717
3718
3719
3720
        dv = torch.zeros_like(dk)
        dq_per_step = [None, None]
        dk_per_step = [None, None]
        dv_per_step = [None, None]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream]
        # synchronize dkv update across steps
        dkv_update_done = torch.cuda.Event()

3721
        # [s, b, np, hn] -> [cp, s, b, np, hn]
3722
3723
        k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
        v_ag, _ = gather_along_first_dim(v, ctx.cp_group)
3724
3725

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
3726
3727
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
3728
3729
3730
3731
3732
3733
3734
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        ctx.cp_stream.wait_stream(torch.cuda.current_stream())
3735
3736
3737

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]

3738
        flash_attn_bwd = None
3739
3740
3741
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
3742
3743
3744
3745
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
3746
3747
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
3748
3749
3750
3751
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
3752
3753
3754
3755
3756
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
3757
3758
3759
3760

        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
3761
3762
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
3763
3764
3765
3766
3767
3768
3769
3770
3771
                    q_ = q.select(seq_dim, i).contiguous()
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i][0],
                        kv_seq_range_per_step[i][1],
                    )
                    max_seqlen_kv = seq_end_idx - seq_start_idx
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
3772
                    out_ = out_per_step[i]
3773
                    dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
3774
3775
3776
3777
                    if ctx.use_fused_attention:
                        aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
                            ctx.max_seqlen_q,
3778
                            max_seqlen_kv,
3779
                            cu_seqlens_q,
3780
                            cu_seqlens_kv_per_step[i],
3781
3782
3783
3784
3785
3786
                            q_,
                            k_,
                            v_,
                            out_,
                            dout_,
                            TE_DType[q.dtype],
3787
                            TE_DType[dout.dtype],
3788
3789
3790
                            aux_ctx_tensors,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
3791
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
3792
3793
3794
3795
3796
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=ctx.attn_mask_type,
                            attn_bias_type=ctx.attn_bias_type,
3797
3798
                            window_size=window_size_per_step[i],
                            deterministic=ctx.deterministic,
3799
3800
3801
3802
3803
                        )
                    else:
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
3804
3805
3806
3807
3808
3809
3810
3811
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q,
                                cu_seqlens_kv_per_step[i],
                                ctx.max_seqlen_q,
                                max_seqlen_kv,
                            ]
3812
3813
3814
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[i]
                        flash_attn_bwd(
3815
3816
3817
3818
3819
3820
3821
3822
3823
                            dout_,
                            q_,
                            k_,
                            v_,
                            out_,
                            softmax_lse_per_step[i],
                            dq_per_step[i],
                            dk_per_step[i],
                            dv_per_step[i],
3824
                            *fa_backward_args_thd,
3825
                            causal="causal" in ctx.attn_mask_type,
3826
                            window_size=window_size_per_step[i],
3827
                            **fa_backward_kwargs,
3828
3829
3830
3831
3832
                        )

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if ctx.qkv_format == "bshd":
3833
                        dq[:, i - 1].copy_(dq_per_step[i - 1])
3834
                    elif ctx.qkv_format == "sbhd":
3835
3836
3837
3838
3839
3840
                        dq[i - 1].copy_(dq_per_step[i - 1])
                    # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn]
                    dk_per_step[i - 1], dv_per_step[i - 1] = [
                        x.movedim(seq_dim, 0).contiguous()
                        for x in [dk_per_step[i - 1], dv_per_step[i - 1]]
                    ]
3841
3842
3843
                    # wait until dkv update of last step is done
                    if i > 1:
                        flash_attn_streams[i - 1].wait_event(dkv_update_done)
3844
3845
3846
3847
3848
3849
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i - 1][0],
                        kv_seq_range_per_step[i - 1][1],
                    )
                    dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1])
                    dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1])
3850
3851
3852
3853
3854
                    if i < len(local_seq_chunk_ids):
                        flash_attn_streams[i - 1].record_event(dkv_update_done)

        torch.cuda.current_stream().wait_stream(ctx.cp_stream)

3855
3856
3857
3858
3859
3860
3861
        # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
        dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False)
        dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
        dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
3862
3863
3864
3865
3866
        dk = dk.view(-1, *dk.shape[-3:])
        dv = dv.view(-1, *dv.shape[-3:])
        dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
        dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)

3867
3868
3869
3870
3871
3872
3873
3874
3875
3876
3877
3878
3879
3880
3881
3882
3883
3884
3885
3886
3887
3888
3889
3890
3891
3892
3893
3894
3895
3896
3897
3898
3899
3900
3901
3902
3903
3904
3905
3906
3907
3908
3909
3910
3911
3912
3913
3914
3915
3916
3917
3918
3919
3920
3921
3922
3923
3924
3925
3926
        dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
        dk = dk.movedim(0, seq_dim).contiguous()
        dv = dv.movedim(0, seq_dim).contiguous()

        return (
            None,
            dq,
            dk,
            dv,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
    """
    Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO.
    Refer the paper `DeepSpeed Ulysses <https://arxiv.org/abs/2309.14509>`_.

    Forward: Q, K and V are exchanged with an all-to-all over ``cp_group`` (the number of
    attention heads must be divisible by the CP size), one local attention call is made with
    FusedAttention or FlashAttention, and the output is exchanged back with a second
    all-to-all. Backward repeats the exchange for ``out``/``dout`` before the local attention
    backward and for ``dq``/``dk``/``dv`` after it.

    Restrictions enforced here: no ``thd`` format, ``no_bias`` only, no padding mask types,
    head dimension a multiple of 8, local sequence length divisible by 2. FP8 requires
    FusedAttention; the ``NVTE_FP8_DPA_BWD`` environment variable (default ``1``) selects
    whether the backward pass also runs in FP8.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        fp8,
        fp8_meta,
        cp_group,
        cp_stream,
    ):
        # pylint: disable=missing-function-docstring
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            window_size == (-1, 0)
            or window_size == (-1, -1)
            or use_fused_attention
            or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

        # Select the FlashAttention entry point and its version-dependent keyword arguments.
        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
                fa_forward_kwargs["window_size"] = window_size
            else:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
                if _flash_attn_2_3_plus:
                    fa_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_5_7_plus:
                    fa_forward_kwargs["block_table"] = None

        # The A2A splits attention heads across CP ranks, so heads must divide evenly.
        assert (
            q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
        ), "The number of attention heads needs to be divisible by CP size!"

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        batch_dim = qkv_format.index("b")
        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        # Nominal (high-precision) dtype of the inputs; for a Float8Tensor input this is the
        # dtype it reports, which is what FP8 outputs/gradients are converted back to.
        qkv_dtype = q.dtype
        fused_attn_backend = None
        fused_attn_qkv_dtype = None
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
        is_input_fp8 = False
        is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
        # Whether backward also runs in FP8 (NVTE_FP8_DPA_BWD, default on). Read once so that
        # the cast placement, the tensors saved for backward and ctx.fp8 all agree.
        fp8_dpa_bwd = bool(fp8) and bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
        if fp8:
            if use_fused_attention:
                fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                fused_attn_qkv_dtype = fp8_dtype_forward
                fused_attn_backend = FusedAttnBackend["FP8"]
                assert isinstance(k, q.__class__) and isinstance(
                    v, q.__class__
                ), "q, k, and v must have the same type."
                is_input_fp8 = isinstance(q, Float8Tensor)
                if is_input_fp8:
                    fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                elif fp8_dpa_bwd:
                    # Cast before the A2A so the exchange moves FP8 data.
                    q_f16, k_f16, v_f16 = q, k, v
                    q, k, v = [
                        cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                        for x in [q_f16, k_f16, v_f16]
                    ]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
                fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_o_offset"] = META_O
                fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history
                fp8_meta_kwargs["amax_s_offset"] = META_S
                fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history
                fp8_meta_kwargs["amax_o_offset"] = META_O
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True)
        q, k, v = flash_attn_a2a_communicate(
            [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
        )

        if fp8 and not is_input_fp8 and not fp8_dpa_bwd:
            # Backward runs in high precision: keep the post-A2A f16 tensors for saving and
            # cast only now, after the exchange.
            q_f16, k_f16, v_f16 = q, k, v
            q, k, v = [
                cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                for x in [q_f16, k_f16, v_f16]
            ]

        batch_size = q.shape[batch_dim]
        if use_fused_attention:
            out, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                fused_attn_qkv_dtype,
                fused_attn_backend,
                attn_scale=softmax_scale,
                dropout=dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                attn_bias=attn_bias,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                window_size=window_size,
                **fp8_meta_kwargs,
            )
        else:
            fa_forward_args_thd = []
            if qkv_format == "thd":
                fa_forward_args_thd = [
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                ]
            fa_outputs = flash_attn_fwd(
                q,
                k,
                v,
                *fa_forward_args_thd,
                causal=causal,
                **fa_forward_kwargs,
            )
            out, softmax_lse = fa_outputs[4], fa_outputs[5]
            # FlashAttention 3 does not return an RNG state at index 7.
            rng_state = fa_outputs[7] if not _use_flash_attn_3 else None
            aux_ctx_tensors = [softmax_lse, rng_state]

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False)
        out = flash_attn_a2a_communicate(
            out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
        )

        if use_fused_attention:
            if qkv_format == "bshd":
                # [b*s, np, hn] -> [b, s, np, hn]
                out = out.view(batch_size, -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                # [s*b, np, hn] -> [s, b, np, hn]
                out = out.view(-1, batch_size, *out.shape[-2:])

        if fp8:
            if is_output_fp8:
                out_fp8 = Float8Tensor(
                    data=out,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=qkv_dtype,
                )
                out = out_fp8._data
                out_ret = out_fp8
            else:
                # qkv_dtype (not q_f16.dtype): q_f16 is unbound when the input was a
                # Float8Tensor, and the two are equal whenever q_f16 exists.
                out_f16 = cast_from_fp8(
                    out,
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    TE_DType[qkv_dtype],
                )
                out_ret = out_f16
        else:
            out_ret = out

        # Choose what backward receives: raw FP8 data when backward runs in FP8, otherwise
        # Float8Tensor wrappers (FP8 inputs) or the high-precision tensors.
        if fp8:
            if fp8_dpa_bwd:
                q_save, k_save, v_save, out_save = q, k, v, out
            elif is_input_fp8:
                q_fp8, k_fp8, v_fp8 = [
                    Float8Tensor(
                        data=x,
                        fp8_meta=fp8_meta,
                        fp8_meta_forward=True,
                        fp8_meta_index=META_QKV,
                        fp8_dtype=fp8_dtype_forward,
                        dtype=qkv_dtype,
                    )
                    for x in [q, k, v]
                ]
                q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8
            else:
                q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16
        else:
            q_save, k_save, v_save, out_save = q, k, v, out

        if fp8_dpa_bwd:
            # Snapshot forward scales: backward must use the scales this forward actually used.
            fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
            fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
        else:
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None

        ctx.save_for_backward(
            q_save,
            k_save,
            v_save,
            out_save,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            fp8_fwd_scales,
            fp8_fwd_scale_invs,
            *aux_ctx_tensors,
        )
        ctx.batch_size = batch_size
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_mask_type = attn_mask_type
        ctx.attn_bias_type = attn_bias_type
        ctx.deterministic = deterministic
        ctx.window_size = window_size
        ctx.use_fused_attention = use_fused_attention
        ctx.fp8 = fp8_dpa_bwd
        ctx.fp8_meta = fp8_meta
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
        return out_ret

    @staticmethod
    def backward(ctx, dout):
        # pylint: disable=missing-function-docstring
        cp_size = get_distributed_world_size(ctx.cp_group)

        # Layout matches the order used by save_for_backward in forward().
        (*saved_tensors,) = ctx.saved_tensors
        q, k, v, out = saved_tensors[:4]
        cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8]
        fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10]
        aux_ctx_tensors = saved_tensors[10:]

        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
        causal = "causal" in ctx.attn_mask_type
        seq_dim = ctx.qkv_format.index("s")

        fused_attn_backend = None
        fused_attn_dqkv_dtype = None
        fused_attn_qkv_dtype = None
        # Nominal dtype of the incoming gradient; gradients are returned in this dtype.
        dout_dtype = dout.dtype
        if ctx.fp8:
            if ctx.use_fused_attention:
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
                fused_attn_qkv_dtype = fp8_dtype_forward
                fused_attn_dqkv_dtype = fp8_dtype_backward
                fused_attn_backend = FusedAttnBackend["FP8"]
                if ctx.is_output_fp8:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
                    dout_fp8 = dout
                    dout = dout_fp8._data
                else:
                    dout_f16 = dout
                    dout = cast_to_fp8(
                        dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                    )
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
                fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
                fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
                fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
                fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
                fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
                fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
                fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV]
                fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP]
                fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][
                    META_DQKV
                ]
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if ctx.fp8_meta is not None and ctx.is_output_fp8:
                # High-precision backward after an FP8-MHA forward: dequantize everything.
                assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]]
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
                fused_attn_dqkv_dtype = TE_DType[dout.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if not ctx.use_fused_attention:
            out = out.view(ctx.batch_size, -1, *out.shape[-2:])
        dout = dout.view(*out.shape)

        # Same direction as the forward Q/K/V exchange, so out/dout line up with saved q/k/v.
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True)
        out, dout = flash_attn_a2a_communicate(
            [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
        )

        # Select the FlashAttention backward entry point and its version-dependent kwargs.
        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
                fa_backward_kwargs["window_size"] = ctx.window_size
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if _flash_attn_2_3_plus:
                    fa_backward_kwargs["window_size"] = ctx.window_size
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic

        if ctx.use_fused_attention:
            dq, dk, dv, _ = fused_attn_bwd(
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                out,
                dout,
                fused_attn_qkv_dtype,
                fused_attn_dqkv_dtype,
                aux_ctx_tensors,
                fused_attn_backend,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                attn_scale=ctx.softmax_scale,
                dropout=ctx.dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=ctx.attn_mask_type,
                attn_bias_type=ctx.attn_bias_type,
                window_size=ctx.window_size,
                deterministic=ctx.deterministic,
                **fp8_meta_kwargs,
            )
        else:
            softmax_lse, rng_state = aux_ctx_tensors
            # FlashAttention writes gradients into preallocated output buffers.
            dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
            fa_backward_args_thd = []
            if ctx.qkv_format == "thd":
                fa_backward_args_thd = [
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    ctx.max_seqlen_q,
                    ctx.max_seqlen_kv,
                ]
            if not _use_flash_attn_3:
                fa_backward_kwargs["rng_state"] = rng_state
            flash_attn_bwd(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                *fa_backward_args_thd,
                causal=causal,
                **fa_backward_kwargs,
            )

        # Reverse exchange: return gradients to the original sequence-sharded layout.
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False)
        dq, dk, dv = flash_attn_a2a_communicate(
            [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
        )

        if ctx.qkv_format == "bshd":
            dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
        elif ctx.qkv_format == "sbhd":
            dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

        if ctx.fp8:
            if ctx.is_input_fp8:
                dq, dk, dv = [
                    Float8Tensor(
                        data=x,
                        fp8_meta=ctx.fp8_meta,
                        fp8_meta_forward=False,
                        fp8_meta_index=META_DQKV,
                        fp8_dtype=fp8_dtype_backward,
                        dtype=dout_dtype,
                    )
                    for x in [dq, dk, dv]
                ]
            else:
                dq, dk, dv = [
                    cast_from_fp8(
                        x,
                        ctx.fp8_meta["scaling_bwd"],
                        META_DQKV,
                        fp8_dtype_backward,
                        TE_DType[dout_dtype],
                    )
                    for x in [dq, dk, dv]
                ]

        # Exactly one gradient slot per forward() input (23, ctx excluded): is_training gets
        # None, q/k/v get dq/dk/dv, and the remaining 19 non-differentiable inputs get None.
        return (None, dq, dk, dv) + (None,) * 19


def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_kv,
    max_seqlen_q,
    max_seqlen_kv,
    cu_seqlens_q_padded,
    cu_seqlens_kv_padded,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    cp_comm_type,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
    window_size=None,
    fp8=False,
    fp8_meta=None,
) -> torch.Tensor:
    """
    Attention implementation with context parallelism.

    Validates the configuration, then dispatches on ``cp_comm_type``:

    * ``"p2p"`` or ``"a2a+p2p"``: ``AttnFuncWithCPAndKVP2P``. For ``"a2a+p2p"``,
      ``cp_group`` must be a list of two process groups. If the first level has a single
      rank, it falls back to ``"p2p"`` over the second group. If the second level has a
      single rank, it falls back to ``"a2a"`` over the first group.
    * ``"all_gather"``: ``AttnFuncWithCPAndKVAllGather``. ``cu_seqlens_kv``,
      ``cu_seqlens_kv_padded``, ``fp8``, ``fp8_meta`` and ``cp_global_ranks`` are not
      forwarded to it.
    * ``"a2a"``: ``AttnFuncWithCPAndQKVOA2A``. ``cp_global_ranks`` is not forwarded.

    ``window_size`` is not forwarded to the P2P implementation. Sliding-window attention
    (a ``window_size`` other than ``None``, ``(-1, 0)`` or ``(-1, -1)``) is accepted only
    with ``"a2a"`` and ``"all_gather"``.

    Returns
    -------
    torch.Tensor
        The attention output produced by the selected implementation.

    Raises
    ------
    ValueError
        If ``cp_comm_type`` is not one of the supported values.
    """

    if cp_comm_type == "a2a+p2p":
        assert isinstance(
            cp_group, list
        ), "Hierarchical CP implementation needs multi-level CP groups!"
        assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
        # Collapse to a single-level scheme when one of the levels is trivial.
        if get_distributed_world_size(cp_group[0]) == 1:
            cp_group = cp_group[1]
            cp_comm_type = "p2p"
        elif get_distributed_world_size(cp_group[1]) == 1:
            cp_group = cp_group[0]
            cp_comm_type = "a2a"
    else:
        assert isinstance(
            cp_group, dist_group_type
        ), f"Unsupported process group for CP communication type {cp_comm_type}!"

    assert qkv_format in [
        "bshd",
        "sbhd",
        "thd",
    ], f"QKV format of {qkv_format} is not supported with context parallelism!"
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "FlashAttention does not support sbhd format!"
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        """Attention bias is only supported with FusedAttention and "causal" """
        """or "no_mask" mask types!"""
    )
    assert qkv_format != "thd" or (
        cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
    ), "cu_seqlens_padded cannot be None with context parallelism + THD format!"

    sliding_window_attn = (
        window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
    )
    assert not sliding_window_attn or cp_comm_type in [
        "a2a",
        "all_gather",
    ], "The context parallel running configs cannot support sliding window attention!"

    # Attention settings passed, in this order, to every CP implementation right after the
    # sequence-length metadata.
    attn_args = [
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    ]

    if cp_comm_type in ["p2p", "a2a+p2p"]:
        out = AttnFuncWithCPAndKVP2P.apply(
            is_training,
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *attn_args,
            fp8,
            fp8_meta,
            cp_group,
            cp_global_ranks,
            cp_stream,
        )
    elif cp_comm_type == "all_gather":
        # The all-gather implementation takes no cu_seqlens_kv / cu_seqlens_kv_padded.
        out = AttnFuncWithCPAndKVAllGather.apply(
            is_training,
            q,
            k,
            v,
            cu_seqlens_q,
            max_seqlen_q,
            max_seqlen_kv,
            cu_seqlens_q_padded,
            *attn_args,
            window_size,
            cp_group,
            cp_stream,
        )
    elif cp_comm_type == "a2a":
        out = AttnFuncWithCPAndQKVOA2A.apply(
            is_training,
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *attn_args,
            window_size,
            fp8,
            fp8_meta,
            cp_group,
            cp_stream,
        )
    else:
        raise ValueError(f"Unsupported communication type: {cp_comm_type}!")

    return out


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
            Values below 1.0 shrink `dim` to `int(dim * rotary_percent)`.
        seq_len_interpolation_factor: int, default = None
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int, default = None
            pre-trained max_position_embeddings before position interpolation
        rotary_base: float, default = 10000.0
            Base of the geometric progression of inverse frequencies:
            `inv_freq[i] = 1 / rotary_base ** (2 * i / dim)`.
        """
        super().__init__()

        effective_dim = int(dim * rotary_percent) if rotary_percent < 1.0 else dim

        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base

        # Exponents 0, 2/dim, 4/dim, ... built directly on the current CUDA device.
        exponents = (
            torch.arange(
                0, effective_dim, 2, dtype=torch.float32, device=torch.cuda.current_device()
            )
            / effective_dim
        )
        self.register_buffer("inv_freq", 1.0 / (self.rotary_base**exponents))
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            Tensor of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]` holding the
            position/frequency products, duplicated along the last dimension.
        """
        positions = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        pretrained_len = self.pretrained_max_position_embeddings
        interp_factor = self.seq_len_interpolation_factor
        if pretrained_len is not None and interp_factor is not None:
            if max_seq_len > pretrained_len * interp_factor:
                # dynamic linear scaling (length > position we have learned)
                positions *= 1 / (max_seq_len / pretrained_len)
            else:
                # fixed linear scaling
                positions *= 1 / interp_factor

        # Outer product: angle[i, j] = position_i * inv_freq_j.
        angles = torch.outer(positions, self.inv_freq)

        # Both halves carry the same angles (matches the rotate-half convention),
        # giving 2 * len(inv_freq) entries in the last dimension.
        doubled = angles.repeat(1, 2)
        # [seq_length, dim] -> [seq_length, 1, 1, dim]
        return doubled[:, None, None, :]

4577
4578
4579
4580
4581
4582
4583
4584
4585
4586
4587
4588
4589
4590
4591
4592
4593

class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # The fused kernels take FP32 frequencies; upcast only when needed.
        freqs = freqs if freqs.dtype == torch.float32 else freqs.float()

        if tensor_format == "thd":
            result = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
        elif tensor_format == "bshd":
            # Run the sbhd kernel on a [s, b, ...] view, then swap back.
            result = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
        elif tensor_format == "sbhd":
            result = tex.fused_rope_forward(t, freqs, False)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")

        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        return result

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        freqs, cu_seqlens = ctx.saved_tensors
        layout = ctx.tensor_format

        if layout == "thd":
            grad_t = tex.fused_rope_thd_backward(
                grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
            )
        elif layout == "bshd":
            grad_t = tex.fused_rope_backward(grad_output.transpose(0, 1), freqs, True).transpose(
                0, 1
            )
        elif layout == "sbhd":
            grad_t = tex.fused_rope_backward(grad_output, freqs, False)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # Only `t` receives a gradient; freqs, tensor_format, cu_seqlens, cp_size
        # and cp_rank get None.
        return (grad_t,) + (None,) * 5
4633
4634


4635
4636
4637
4638
4639
4640
4641
4642
4643
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Rotate the two halves of the last dimension of `x`.

    The last dimension is viewed as `[x1, x2]` (two equal halves) and the result is
    `[-x2, x1]`. `x.shape[-1]` must be even, and `x` must support `Tensor.view`.
    """
    halves = x.view(*x.shape[:-1], 2, x.shape[-1] // 2)
    front, back = halves[..., 0, :], halves[..., 1, :]
    return torch.cat((-back, front), dim=-1)


4644
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
        Should be `cu_seqlens_padded` when cp_size > 1.
    cp_size: int, default = 1.
        Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
    cp_rank: int, default = 0.
        Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.

    Returns
    -------
    torch.Tensor
        Tensor with the same shape as `t`; the first `d2` channels are rotated and the
        remaining channels are passed through unchanged.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    is_bshd = tensor_format == "bshd"
    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1 if is_bshd else 0]

    # Only the prefix of `freqs` matching the running input's length is used.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if is_bshd:
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]

    # Take cos/sin in the frequencies' own precision, then cast to t's dtype.
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # Channels beyond rot_dim are left untouched (ideally there are none).
    rotated, passthrough = t[..., :rot_dim], t[..., rot_dim:]
    rotated = rotated * cos_ + _rotate_half(rotated) * sin_
    return torch.cat((rotated, passthrough), dim=-1)


cyanguwa's avatar
cyanguwa committed
4714
class _SplitAlongDim(torch.autograd.Function):
    """
    Autograd-aware split of a tensor along one dimension.

    Forward behaves like `torch.split`. It also accepts a `Float8Tensor`, in which case the
    raw `_data` is split and each piece is re-wrapped with `Float8Tensor.make_like`.

    Backward re-assembles the incoming gradients along `split_dim`. If the gradients are
    already back-to-back views of one storage, laid out exactly as the split produced them,
    a single view spanning that storage is returned without copying. Otherwise they are
    concatenated with `torch.cat`.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections

        if not isinstance(mixed_x_layer, Float8Tensor):
            return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)

        raw_chunks = torch.split(
            mixed_x_layer._data,
            split_size_or_sections=split_size_or_sections,
            dim=split_dim,
        )
        return tuple(Float8Tensor.make_like(mixed_x_layer, data=chunk) for chunk in raw_chunks)

    @staticmethod
    def _grads_tile_one_buffer(grads, buffers, split_dim, split_sizes) -> bool:
        """
        Return True when `grads` are back-to-back views of a single storage.

        `buffers[i]` is the tensor whose storage backs `grads[i]`: the gradient itself, or
        its `_data` for FP8 gradients. Every gradient must satisfy all of:
        - strides identical to the first gradient;
        - the expected per-piece shape;
        - the same storage pointer as the first gradient;
        - the storage offset of a contiguous layout along `split_dim`.
        """
        ref_strides = grads[0].stride()
        ref_ptr = buffers[0].untyped_storage().data_ptr()
        full_shape = list(grads[0].shape)
        # Number of elements in one slice along split_dim (dims after split_dim only).
        inner_numel = np.prod(full_shape[split_dim + 1 :])
        for idx, (grad, buf) in enumerate(zip(grads, buffers)):
            piece_shape = list(full_shape)
            piece_shape[split_dim] = split_sizes[idx]
            expected_offset = sum(split_sizes[:idx]) * inner_numel
            if (
                grad.stride() != ref_strides
                or list(grad.shape) != piece_shape
                or buf.untyped_storage().data_ptr() != ref_ptr
                or grad.storage_offset() != expected_offset
            ):
                return False
        return True

    @staticmethod
    def backward(ctx, *grad_outputs):
        # pylint: disable=missing-function-docstring
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        sections = ctx.split_size_or_sections
        if isinstance(sections, (list, tuple)):
            split_sizes = sections
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(sections, int):
            split_sizes = [sections] * len(grad_outputs)

        ndim = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + ndim) % ndim

        lead = grad_outputs[0]
        is_fp8 = isinstance(lead, Float8Tensor)
        # Tensors that actually own the storage: `_data` for FP8 gradients.
        buffers = [g._data for g in grad_outputs] if is_fp8 else list(grad_outputs)

        if _SplitAlongDim._grads_tile_one_buffer(grad_outputs, buffers, split_dim, split_sizes):
            # Zero-copy path: one view over the shared storage covering every piece.
            merged_shape = list(lead.shape)
            merged_shape[split_dim] = sum(split_sizes)
            lead_buf = buffers[0]
            merged = torch.Tensor().to(device=lead.device, dtype=lead_buf.dtype)
            merged.set_(
                lead_buf.untyped_storage(),
                lead_buf.storage_offset(),
                merged_shape,
                lead.stride(),
            )
            if is_fp8:
                return Float8Tensor.make_like(lead, data=merged), None, None
            return merged, None, None

        # Fallback: materialize the concatenation.
        if is_fp8:
            return (
                Float8Tensor.make_like(lead, data=torch.cat(buffers, dim=split_dim)),
                None,
                None,
            )
        return torch.cat(grad_outputs, dim=split_dim), None, None
4824
4825
4826
4827
4828
4829
4830
4831
4832


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2

    Parameters
    ----------
    softmax_scale: float
        Scale applied to the raw `Q @ K^T` scores.
    attention_type: str, default = "self"
        Forwarded to `get_full_mask` when building the attention mask.
    attention_dropout: float, default = 0.0
        Dropout probability applied to the attention probabilities.
    attention_dropout_ctx: Callable, default = nullcontext
        Factory returning the context manager entered around the dropout call.
    layer_number: Optional[int], default = None
        Extra divisor on the QK scores when `NVTE_APPLY_QK_LAYER_SCALING=1` and the
        inputs are FP16.
    """

    # Values of `core_attention_bias_type` that `forward` knows how to apply.
    _supported_bias_types = ("no_bias", "pre_scale_bias", "post_scale_bias", "alibi")

    def __init__(
        self,
        softmax_scale: float,
        attention_type: str = "self",
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_type = attention_type
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            Q, K, V in the `sbhd` or `bshd` format implied by `qkv_layout`.
        qkv_layout: str, default = "sbh3d"
            Must be one of `QKVLayouts`; `thd` layouts are not supported here.
        cu_seqlens_q, cu_seqlens_kv: Optional[torch.Tensor]
            Unused.
        attn_mask_type, attention_mask, window_size:
            Forwarded to `get_full_mask`, which returns the mask type and mask actually used.
        core_attention_bias_type: str, default = "no_bias"
            One of "no_bias", "pre_scale_bias", "post_scale_bias", "alibi".
        core_attention_bias: Optional[torch.Tensor]
            Required for "pre_scale_bias" and "post_scale_bias"; computed for "alibi".
        alibi_slopes: Optional[torch.Tensor]
            Forwarded to `get_alibi` when `core_attention_bias_type == "alibi"`.

        Returns
        -------
        torch.Tensor
            `[sq, b, np * hn]` for `sbhd` inputs, `[b, sq, np * hn]` for `bshd` inputs.

        Raises
        ------
        AssertionError
            On unsupported layouts, missing biases, or head counts not divisible by GQA groups.
        ValueError
            If `core_attention_bias_type` is not one of the supported values.
        """
        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
        # The 4-D reshapes below cannot handle packed [t, h, d] inputs.
        assert qkv_format != "thd", "UnfusedDotProductAttention does not support qkv_format = thd!"
        if core_attention_bias_type not in self._supported_bias_types:
            # Without this check an unknown type would leave the uninitialized
            # `torch.empty` score buffer below flowing straight into softmax.
            raise ValueError(
                "UnfusedDotProductAttention does not support "
                f"core_attention_bias_type = {core_attention_bias_type}!"
            )

        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]
        # From here on all three tensors are indexed as [s, b, np, hn].
        batch_size, max_seqlen_q, max_seqlen_kv = (
            query_layer.shape[1],
            query_layer.shape[0],
            key_layer.shape[0],
        )

        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
            max_seqlen_q,
            max_seqlen_kv,
            attn_mask_type=attn_mask_type,
            attention_mask=attention_mask,
            window_size=window_size,
            attention_type=self.attention_type,
        )

        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            # GQA/MQA: replicate K/V heads to match the number of query heads.
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                query_layer.shape[2] // key_layer.shape[2], dim=2
            )
            value_layer = value_layer.repeat_interleave(
                query_layer.shape[2] // value_layer.shape[2], dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            # Allocate next to the inputs rather than on whatever device is current.
            device=query_layer.device,
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            ).view(*output_size)

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            # Bias is added before scaling.
            matmul_result = matmul_result.view(*output_size) + core_attention_bias
            matmul_result *= scale

        else:  # "post_scale_bias" or "alibi" (validated above)
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = get_alibi(
                    output_size[1],
                    output_size[2],
                    output_size[3],
                    actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
                    actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
                    alibi_slopes=alibi_slopes,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            # Bias is added after scaling.
            matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
                dtype=query_layer.dtype
            )

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            matmul_result, attention_mask, attn_mask_type, softmax_scale
        )

        # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
        # the columns (pad tokens from k) are already zeroed out during softmax
        if "padding" in attn_mask_type:
            attention_probs = attention_probs.masked_fill(attention_mask, 0)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(max_seqlen_q, batch_size, -1)

        if qkv_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, max_seqlen_q, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """
    Repack an interleaved QKV buffer for FlashAttention.

    Forward converts QKV from interleaved (s, b, ...) layout to separate contiguous q, k, v
    tensors in (b, s, ...) layout. Backward repacks the three gradients with
    `tex.fa_prepare_bwd` and splits the result back into dq, dk, dv.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # q, k and v arrive as non-contiguous views; `query_layer` alone is handed
        # to the kernel because it addresses the full QKV memory region.
        packed = tex.fa_prepare_fwd(query_layer)
        q_chunk, k_chunk, v_chunk = split_tensor_along_dim(packed, 0, 3)
        # Drop the leading size-1 "which of q/k/v" axis from each chunk.
        return q_chunk.squeeze(0), k_chunk.squeeze(0), v_chunk.squeeze(0)

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        packed_grad = tex.fa_prepare_bwd(dq, dk, dv)
        grad_q, grad_k, grad_v = split_tensor_along_dim(packed_grad, -1, 3)
        return grad_q, grad_k, grad_v

5074

5075
def get_qkv_layout(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qkv_format: str = "sbhd",
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of tokens in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    -------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
        Query tensor. It may be different from input `q` as we try to fit tensors to
        a supported layout.
    k: torch.Tensor
        Key tensor. It may be different from input `k` as we try to fit tensors to
        a supported layout.
    v: torch.Tensor
        Value tensor. It may be different from input `v` as we try to fit tensors to
        a supported layout.

    Raises
    ------
    AssertionError
        If any of `q`, `k`, `v` does not have stride 1 in its last dimension.
    RuntimeError
        If no supported layout matches, even after making `q`, `k` and `v` contiguous.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def run_iteratively(q, k, v):
        """Classify (q, k, v) as laid out now; return "not_supported" if nothing matches."""
        # check data pointers
        q_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == q_ptr for x in [q, k, v])
        check_ptrs_qk = all(x.untyped_storage().data_ptr() == q_ptr for x in [q, k])
        k_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == k_ptr for x in [k, v])

        # check tensor shapes
        check_shapes_qkv = all(q.shape == x.shape for x in [q, k, v])
        check_shapes_kv = k.shape[:-1] == v.shape[:-1]

        # check tensor strides
        q_stride = q.stride()
        check_strides_qkv = all(q_stride == x.stride() for x in [q, k, v])
        # k and v may differ in head size, so compare strides normalized by head size
        check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
            sv / v.shape[-1] for sv in v.stride()[:-1]
        )

        # check tensor offsets for h3d and 3hd layouts
        prod_h_d = q.shape[-1] * q.shape[-2]
        check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v]))
        check_h3d_offsets = all(
            x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v])
        )

        # check tensor offsets for hd_h2d and hd_2hd layouts
        prod_all_dims = [np.prod(x.shape) for x in [q, k]]
        offset = prod_all_dims[0] if check_ptrs_qkv else 0
        prod_h_d = k.shape[-1] * k.shape[-2]
        check_2hd_offsets = all(
            x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v])
        )
        check_h2d_offsets = all(
            x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v])
        )

        # check tensor offsets for hd_hd_hd layouts
        check_hd_offsets_qkv = (
            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v]))
            if check_ptrs_qkv
            else all(x.storage_offset() == 0 for x in [q, k, v])
        )
        check_hd_offsets_qk = (
            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k]))
            if not check_ptrs_qkv and check_ptrs_qk
            else all(x.storage_offset() == 0 for x in [q, k])
        )
        check_hd_offsets_kv = (
            all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v]))
            if not check_ptrs_qkv and check_ptrs_kv
            else all(x.storage_offset() == 0 for x in [k, v])
        )

        if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets:
            # sb3hd, bs3hd, t3hd
            # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv
            return qkv_format[:-2] + "3" + qkv_format[-2:]
        if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets:
            # sbh3d, bsh3d, th3d
            # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv
            return qkv_format[:-1] + "3" + qkv_format[-1:]
        if check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets:
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv
            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
            # have the same data pointer, i.e. check_ptrs_qkv=True
            return qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        if check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets:
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv
            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
            # have the same data pointer, i.e. check_ptrs_qkv=True
            return qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
        if (
            check_strides_kv
            and check_shapes_kv
            and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk)
        ):
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            # three chunks of memory, q, k and v, which may be disjoint or consecutive, and
            # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
            # check_ptrs_qk=True or check_ptrs_kv=True
            return "_".join([qkv_format] * 3)
        return "not_supported"

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        raise RuntimeError("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v
5224

5225

5226
def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
):
    """Check if sliding window size is compliant with attention mask type.
    If not, set it to the appropriate size.

         attn_mask_type                              |   window_size
    -------------------------------------------------------------------------
    no_mask, padding, arbitrary                      | (-1, -1) or (>=0, >=0)
    causal, padding_causal                           | (-1,  0) or (>=0, 0)
    causal_bottom_right, padding_causal_bottom_right | (-1,  0) or (>=0, 0)

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type. Any type containing the substring "causal" is treated
        as causal; otherwise it must be one of "no_mask", "padding", "arbitrary".
    window_size: Optional[Tuple[int, int]], default = None
        Requested (left, right) sliding window. ``None`` selects the default for
        the mask type: (-1, 0) for causal masks, (-1, -1) otherwise.

    Returns
    -------
    window_size: Tuple[int, int]
        The compliant window size. When the input is already compliant, the input
        object itself is returned unchanged. Fixable inputs are corrected with a
        warning: a causal mask forces the right extent to 0, and a non-causal mask
        maps (-1, 0) to (-1, -1).

    Raises
    ------
    AssertionError
        If ``window_size`` cannot be reconciled with ``attn_mask_type``, or if
        ``attn_mask_type`` is not recognized. The error is raised explicitly (not via
        ``assert``) so the check still runs under ``python -O``.
    """
    orig_window_size = window_size
    if "causal" in attn_mask_type:
        if orig_window_size is None:
            window_size = (-1, 0)
        elif orig_window_size == (-1, -1) or (
            orig_window_size[0] >= 0 and orig_window_size[1] != 0
        ):
            # Fixable: keep the left extent and clamp the right extent to 0 for causal masks.
            window_size = (orig_window_size[0], 0)
            warnings.warn(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
        elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
            raise AssertionError(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
    elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
        if orig_window_size is None:
            window_size = (-1, -1)
        elif orig_window_size == (-1, 0):
            # (-1, 0) is the causal default; widen it to a full window for non-causal masks.
            window_size = (-1, -1)
            warnings.warn(
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
        elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
            raise AssertionError(
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
    else:
        raise AssertionError("Invalid attn_mask_type: " + attn_mask_type)
    return window_size
5269

5270

5271
class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention

    Parameters
    ----------
    softmax_scale: float
        Softmax scale, forwarded as ``softmax_scale`` to the flash-attn and
        context-parallel attention calls.
    attention_dropout: float, default = 0.0
        Dropout probability. It is used only while ``self.training`` is True and only
        on the flash-attn v2 and context-parallel paths; the flash-attn v3 call in
        ``forward`` does not receive it.
    attention_dropout_ctx: Optional[Callable], default = nullcontext
        Zero-argument callable whose result is used as a context manager around the
        attention kernel call.
    attention_type: str, default = "self"
        On the padded (sbhd/bshd + "padding" mask) path, "self" derives the KV
        sequence info from Q. Any other value expects separate Q/KV masks or
        cu_seqlens.
    layer_number: Optional[int], default = None
        Stored as ``self.layer_number`` (1 when None); not read by ``forward``.
    deterministic: bool, default = False
        Forwarded to flash-attn (v2.4.1+ and v3) and to the context-parallel path.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        # Fail early if an installed flash-attn falls outside the supported version range.
        if _flash_attn_is_installed:
            assert (
                _flash_attn_version >= _flash_attn_version_required
            ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
            assert (
                _flash_attn_version <= _flash_attn_max_version
            ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        self.logger = logging.getLogger("FlashAttention")
        self.logger.setLevel(_log_level)
        # Avoid attaching a duplicate handler when several instances are created.
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """flash-attn fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            CUDA tensors of dtype float16/bfloat16, or Float8Tensors, laid out
            according to ``qkv_layout``.
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]
            Used only on the padded sbhd/bshd path when cu_seqlens are not given.
            Pass a single mask for self-attention, or a (q_mask, kv_mask) pair otherwise.
        qkv_layout: str, default = "sbh3d"
            Must be in ``QKVLayouts``. The format ("sbhd", "bshd" or "thd") is the
            letters of the part before the first "_".
        cu_seqlens_q, cu_seqlens_kv: Optional[torch.Tensor]
            Cumulative sequence lengths. Required for "thd". For sbhd/bshd they are
            derived from the mask (padding) or built as full lengths when None.
        max_seqlen_q, max_seqlen_kv: Optional[int]
            Ignored for sbhd/bshd, where they are recomputed from the tensor shapes
            times the CP world size. For "thd" they are computed from cu_seqlens when None.
        attn_mask_type: str, default = "causal"
            "padding" and "causal" substrings select packing and causal masking.
        window_size: Optional[Tuple[int, int]]
            Sliding window forwarded to the kernels.
        alibi_slopes: Optional[torch.Tensor]
            Forwarded to flash-attn v2.4+. Must be None with context parallelism.
        cp_group, cp_global_ranks, cp_stream, cp_comm_type
            Context-parallel configuration. CP is active when the product of the group
            world sizes exceeds 1.
        fp8: bool, default = False
            On the flash-attn v3 path, cast inputs to FP8 (if not already Float8Tensor)
            and pass descaling factors. With ``fp8_meta["recipe"].fp8_mha`` the output
            is returned as a Float8Tensor.
        fp8_meta: Optional[Dict[str, Any]]
            FP8 metadata (``"recipe"``, ``"scaling_fwd"``); required when ``fp8``.

        Returns
        -------
        torch.Tensor
            Attention output with heads flattened into the last dim:
            [s, b, h*d] for sbhd, [b, s, h*d] for bshd, [t, h*d] for thd, where s is
            the local (per-CP-rank) sequence length.
        """

        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        # Total context-parallel size: product of the world sizes of all CP groups.
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        # Keep only the letters of the Q part, e.g. "sbh3d" -> "sbhd", "bshd_bs2hd" -> "bshd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        # sbhd inputs are transposed to bshd here; code below treats dim 0 as batch.
        if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
            if qkv_format == "sbhd":
                # For now just 128, will make it more general in the future
                if (
                    query_layer.shape[-1] == 128
                    and query_layer.shape[0] * query_layer.shape[1] >= 512
                    and qkv_layout == "sbh3d"
                ):
                    query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                        query_layer, key_layer, value_layer
                    )
                else:
                    query_layer, key_layer, value_layer = [
                        x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
                    ]
            if context_parallel:
                query_layer, key_layer, value_layer = [
                    x.contiguous() for x in (query_layer, key_layer, value_layer)
                ]
        else:
            # Float8Tensor: transpose / make contiguous the raw payload, then re-wrap.
            if qkv_format == "sbhd":
                query_layer._data, key_layer._data, value_layer._data = [
                    x.transpose(0, 1)
                    for x in (query_layer._data, key_layer._data, value_layer._data)
                ]
                query_layer, key_layer, value_layer = [
                    Float8Tensor.make_like(x, data=x._data)
                    for x in (query_layer, key_layer, value_layer)
                ]
            if context_parallel:
                query_layer._data, key_layer._data, value_layer._data = [
                    x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
                ]

        batch_size = query_layer.shape[0]

        if qkv_format in ["sbhd", "bshd"]:
            # Local sequence lengths from the shapes, scaled to the global length under CP.
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size

            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer)
            else:
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            # Derive max sequence lengths from consecutive cu_seqlens differences.
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        # Context-parallel path: only taken for non-FP8 inputs.
        if context_parallel and all(
            not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
        ):
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q if qkv_format == "thd" else None,
                    cu_seqlens_kv if qkv_format == "thd" else None,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic,
                    window_size=window_size,
                )
        else:

            from .cpu_offload import CPUOffloadEnabled

            # Tag tensors so the CPU-offload machinery may offload them.
            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            with self.attention_dropout_ctx():
                # Only pass keyword arguments the installed flash-attn version accepts.
                fa_optional_forward_kwargs = {}
                if _flash_attn_2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                # Extra positional args for the varlen API: cu_seqlens and max_seqlens.
                fa_optional_forward_args_thd = []
                if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
                    func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
                else:
                    if _flash_attn_2_5_7_plus:
                        fa_optional_forward_kwargs["block_table"] = None
                    func = (
                        flash_attn_varlen_func
                        if not _use_flash_attn_3
                        else flash_attn_varlen_func_v3
                    )
                    fa_optional_forward_args_thd.append(cu_seqlens_q)
                    fa_optional_forward_args_thd.append(cu_seqlens_kv)
                    fa_optional_forward_args_thd.append(max_seqlen_q)
                    fa_optional_forward_args_thd.append(max_seqlen_kv)
                if _use_flash_attn_3:
                    # flash-attn v3: dropout is not passed on this path.
                    fa_3_optional_forward_kwargs = {}
                    fa_3_optional_forward_kwargs["window_size"] = window_size
                    fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
                    activation_dtype = query_layer.dtype
                    if fp8:
                        fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                        torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)

                        def convert_to_torch_float8(tensor, dtype):
                            # Zero-copy: reinterpret the Float8Tensor payload as a torch fp8 tensor.
                            out = torch.Tensor().to(device=tensor.device, dtype=dtype)
                            out.set_(
                                tensor._data.untyped_storage(),
                                tensor._data.storage_offset(),
                                tensor._data.shape,
                                tensor._data.stride(),
                            )
                            return out

                        # "fp8_mha" decides outputs in fp8, while inputs are inferred from
                        # the real dtype
                        assert isinstance(key_layer, query_layer.__class__) and isinstance(
                            value_layer, query_layer.__class__
                        ), "q, k, and v must have the same type."
                        if isinstance(query_layer, Float8Tensor):
                            # Record the incoming scale_inv in the module's forward FP8 metadata.
                            fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
                        else:
                            query_layer, key_layer, value_layer = (
                                Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward)
                                for x in [query_layer, key_layer, value_layer]
                            )
                        fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv
                        fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv
                        fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv
                        query_layer, key_layer, value_layer = (
                            convert_to_torch_float8(x, torch_dtype)
                            for x in [query_layer, key_layer, value_layer]
                        )
                    try:
                        output, _ = func(
                            query_layer,
                            key_layer,
                            value_layer,
                            *fa_optional_forward_args_thd,
                            softmax_scale=self.softmax_scale,
                            causal="causal" in attn_mask_type,
                            **fa_3_optional_forward_kwargs,
                        )
                    except TypeError as e:
                        # Annotate API-mismatch errors with v3 (beta) upgrade steps, then re-raise.
                        if _flash_attn_3_0_0_beta:
                            e.args = (
                                e.args[0]
                                + ". Please update your flash-attn v3 (beta) installation as it "
                                + "may have added more supported arguments to its API. \n"
                                + _flash_attn_3_installation_steps,
                            ) + e.args[1:]
                        raise

                    if fp8 and fp8_meta["recipe"].fp8_mha:
                        output = cast_to_fp8(
                            output,
                            fp8_meta["scaling_fwd"],
                            META_O,
                            fp8_dtype_forward,
                        )
                        output = Float8Tensor(
                            data=output,
                            fp8_meta=fp8_meta,
                            fp8_meta_forward=True,
                            fp8_meta_index=META_O,
                            fp8_dtype=fp8_dtype_forward,
                            dtype=activation_dtype,
                        )
                else:
                    output = func(
                        query_layer,
                        key_layer,
                        value_layer,
                        *fa_optional_forward_args_thd,
                        self.attention_dropout if self.training else 0.0,
                        softmax_scale=self.softmax_scale,
                        causal="causal" in attn_mask_type,
                        **fa_optional_forward_kwargs,
                    )

        # Scatter packed tokens back to the padded [b * s, ...] layout.
        if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            if fp8 and fp8_meta["recipe"].fp8_mha:
                output = Float8Tensor.make_like(
                    output,
                    data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
                    .transpose(0, 1)
                    .contiguous(),
                )
            else:
                output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
        elif qkv_format == "bshd":
            # (bs)hd -> bs(hd)
            output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
        elif qkv_format == "thd":
            # thd -> t(hd)
            output = output.reshape(output.shape[0], -1)

        return output.contiguous()
5620

5621

5622
def _combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a particular dimension

    Builds a zero-copy view over the storage of ``tensors[0]``. The view has a new
    dimension of size ``len(tensors)`` inserted at ``dim``. The new dimension's stride
    is ``stride[dim - 1] // len(tensors)``, taken from the original strides of
    ``tensors[0]``.

    No data is copied and the other tensors are only counted, not read.
    NOTE(review): the result is only meaningful when the inputs are already
    interleaved in one buffer starting at ``tensors[0]`` (e.g. q/k/v views of a
    packed tensor). This is not checked here; callers must guarantee it.

    Float8Tensor inputs are handled by viewing the raw ``_data`` payload and
    re-wrapping it with ``Float8Tensor.make_like``.
    """

    num_tensors = len(tensors)
    first = tensors[0]
    new_shape = list(first.shape)
    new_shape.insert(dim, num_tensors)
    new_stride = list(first.stride())
    # Strides are ints: use exact floor division (int(a / b) loses precision for huge values).
    new_stride.insert(dim, new_stride[dim - 1] // num_tensors)

    is_fp8 = isinstance(first, Float8Tensor)
    base = first._data if is_fp8 else first
    combined_tensor = torch.Tensor().to(device=first.device, dtype=base.dtype)
    combined_tensor.set_(base.untyped_storage(), base.storage_offset(), new_shape, new_stride)
    if is_fp8:
        combined_tensor = Float8Tensor.make_like(first, data=combined_tensor)

    return combined_tensor
5649

5650

5651
5652
5653
5654
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input"""

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen,
        cu_seqlens,
        cu_seqlens_padded,
        qkv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        # pylint: disable=missing-function-docstring
        # Fused-attention forward for a packed QKV tensor.
        #
        # FP8 path (fp8 truthy):
        #   - Casts qkv to FP8 unless it is already a Float8Tensor.
        #   - Overrides the backend with FusedAttnBackend["FP8"].
        #   - Returns a Float8Tensor when fp8_meta["recipe"].fp8_mha, else casts the
        #     output back to qkv_dtype.
        # Otherwise the kernel runs directly on qkv in qkv_dtype.
        #
        # Env var NVTE_FP8_DPA_BWD (default "1") chooses between two backward modes:
        #   - non-zero: backward runs in FP8 on the saved FP8 tensors;
        #   - "0": high-precision copies of qkv / out are saved instead.

        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype.
        # Only an FP8 forward can produce an FP8 output. Gating on `fp8` keeps
        # ctx.is_output_fp8 consistent with the actual output type, so backward does not
        # demand a Float8Tensor gradient for a plain output.
        is_input_fp8 = False
        is_output_fp8 = fp8 and fp8_meta["recipe"].fp8_mha
        # Read NVTE_FP8_DPA_BWD once so the cast decision below and ctx.fp8 agree.
        fp8_dpa_bwd = 0
        if fp8:
            fp8_dpa_bwd = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
            is_input_fp8 = isinstance(qkv, Float8Tensor)
            if is_input_fp8:
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.split("_"))
            assert (
                qkv_group == 1
            ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}."
            if is_input_fp8:
                qkv_fp8 = qkv._data
            else:
                # Flatten the trailing three dims to 2D for the cast, then restore the shape.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv_fp8 = cast_to_fp8(
                    qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(qkv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
                META_QKV,  # d_scale_qkv_offset
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
                META_S,  # d_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_s
                META_S,  # q_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_o
                META_O,  # q_scale_o_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_s
                META_S,  # amax_s_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_o
                META_O,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if is_output_fp8:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=qkv.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret
            if not fp8_dpa_bwd:
                # High-precision backward: save dequantized qkv / out instead of FP8 data.
                if is_input_fp8:
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv = cast_from_fp8(
                        qkv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[qkv.dtype],
                    ).view(qkv.shape)
                if is_output_fp8:
                    out_save = cast_from_fp8(
                        out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                        fp8_meta["scaling_fwd"],
                        META_O,
                        fp8_dtype_forward,
                        qkv_dtype,
                    ).view(out_fp8.shape)
            # Snapshot scales (clone) so later scaling updates do not affect backward.
            fp8_tensors = (
                qkv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )

        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                None,  # d_scale_qkv
                0,  # d_scale_qkv_offset
                None,  # d_scale_s
                0,  # d_scale_s_offset
                None,  # q_scale_s
                0,  # q_scale_s_offset
                None,  # q_scale_o
                0,  # q_scale_o_offset
                None,  # amax_s
                0,  # amax_s_offset
                None,  # amax_o
                0,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            fp8_tensors = (None, None, None, None)
            out_save = out_ret

        # ctx.fp8: run backward in FP8 only when forward was FP8 and NVTE_FP8_DPA_BWD != 0.
        ctx.fp8 = fp8 and fp8_dpa_bwd
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
        # FP8 backward reads the FP8 copies in fp8_tensors, so skip saving qkv / out.
        qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
        ctx.save_for_backward(
            *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        # Non-FP8 backward always uses the F16 arbitrary-seqlen backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret
5830
5831
5832

    @staticmethod
    def backward(ctx, d_out):
5833
        # pylint: disable=missing-function-docstring
5834
        if ctx.is_output_fp8:
5835
5836
5837
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
5838
5839
5840
            d_out_f8tensor = d_out
            d_out = d_out._data

5841
        d_out = d_out.contiguous()
5842
5843
5844
5845
        (
            qkv,
            out,
            cu_seqlens,
5846
            cu_seqlens_padded,
5847
5848
5849
5850
5851
5852
            qkv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
5853
        rest = [None]
5854
5855
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
5856
        if ctx.use_FAv2_bwd:
5857
            softmax_lse, rng_state = aux_ctx_tensors
5858
            dqkv = torch.empty_like(qkv)
5859
5860
5861
            d_out, q, k, v, out = [
                maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out)
            ]
5862
            flash_attn_cuda_bwd(
5863
5864
5865
5866
5867
5868
5869
5870
5871
5872
5873
5874
5875
5876
5877
5878
5879
5880
5881
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dqkv[:, 0],
                dqkv[:, 1],
                dqkv[:, 2],
                cu_seqlens,
                cu_seqlens,
                ctx.max_seqlen,
                ctx.max_seqlen,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
5882
            )
5883
            dqkv = dqkv[..., : d_out.shape[-1]]
5884
        else:
5885
5886
            with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                if ctx.fp8:
5887
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
5888
                    fp8_dtype_backward = get_fp8_te_dtype(
5889
5890
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
5891
                    if ctx.is_output_fp8:
5892
                        d_out_fp8 = d_out
5893
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
5894
5895
5896
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
5897
5898
5899
5900
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
5901
                    dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
5902
5903
5904
5905
5906
5907
5908
5909
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
5910
                        ctx.fused_attention_backend,
5911
                        cu_seqlens_padded,
5912
5913
5914
5915
5916
5917
5918
5919
5920
5921
5922
5923
5924
5925
5926
5927
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
5928
5929
                        ctx.window_size,
                        ctx.deterministic,
5930
                    )
5931
                    if ctx.is_input_fp8:
5932
5933
                        dqkv = Float8Tensor(
                            data=dqkv_fp8,
5934
5935
5936
5937
5938
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
5939
                        )
5940
                    else:
5941
5942
5943
5944
5945
5946
5947
5948
5949
5950
                        dqkv_c_fp8 = dqkv_fp8.view(
                            -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                        )
                        dqkv = cast_from_fp8(
                            dqkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dqkv_fp8.shape)
5951
5952
5953
5954
                else:
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(qkv.dtype)
                    dqkv, *rest = fused_attn_bwd_qkvpacked(
5955
5956
5957
5958
5959
5960
5961
5962
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
5963
                        ctx.fused_attention_backend,
5964
                        cu_seqlens_padded,
5965
5966
5967
5968
5969
5970
5971
5972
5973
5974
5975
5976
5977
5978
5979
5980
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
5981
5982
                        ctx.window_size,
                        ctx.deterministic,
5983
                    )
5984

5985
5986
        # if no_bias or alibi, return dqkv
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
5987
5988
5989
5990
5991
5992
5993
5994
5995
5996
5997
5998
5999
6000
6001
6002
6003
6004
6005
6006
6007
            return (
                None,
                None,
                None,
                None,
                dqkv,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
6008
6009
                None,
                None,
6010
            )
6011
        # else, return (dqkv, dbias)
6012
6013
6014
6015
6016
6017
6018
6019
6020
6021
6022
6023
6024
6025
6026
6027
6028
6029
6030
6031
6032
        return (
            None,
            None,
            None,
            None,
            dqkv,
            None,
            rest[0],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
6033
6034
            None,
            None,
6035
        )
6036

6037

6038
6039
6040
6041
class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input"""

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        q,
        kv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run the fused attention forward pass on a separate Q and a packed KV tensor.

        When ``fp8`` is set, Q and KV are cast to FP8 (unless they already are
        ``Float8Tensor``s) and the FP8 backend is used. The output is returned as a
        ``Float8Tensor`` when ``fp8_meta["recipe"].fp8_mha`` is set; otherwise it is
        dequantized to ``qkv_dtype``. Without ``fp8``, the high-precision
        ``fused_attn_fwd_kvpacked`` path runs with ``fused_attention_backend``.

        The environment variable ``NVTE_FP8_DPA_BWD`` (default "1") decides
        whether the backward pass also runs in FP8. When it is "0", high-precision
        copies of q/kv/out are saved and backward uses the F16 arbitrary-seqlen backend.

        Returns:
            The attention output (``out_ret``).
        """
        # Read once so the forward-side dequantization and ctx.fp8 always agree.
        fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))

        # "fp8_mha" decides whether outputs are FP8; inputs are inferred from their real type.
        # Only honour it when FP8 is enabled: the non-FP8 path below always returns a
        # high-precision tensor, and backward() uses this flag to decide whether d_out
        # must be a Float8Tensor.
        is_input_fp8 = False
        is_output_fp8 = bool(fp8) and fp8_meta["recipe"].fp8_mha

        if fp8:
            assert isinstance(kv, q.__class__), "q and kv must have the same type."
            is_input_fp8 = isinstance(q, Float8Tensor)
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if is_input_fp8:
                # Inputs are already FP8: reuse their scale_inv and raw data.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, kv_fp8 = q._data, kv._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                assert qkv_group == 2, (
                    "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
                    f"but found {qkv_layout}."
                )
                q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view(
                    q.shape
                )
                # KV is flattened to 2D for the cast, then restored to its packed shape.
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv_fp8 = cast_to_fp8(
                    kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(kv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                kv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
                META_QKV,  # d_scale_qkv_offset
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
                META_S,  # d_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_s
                META_S,  # q_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_o
                META_O,  # q_scale_o_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_s
                META_S,  # amax_s_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_o
                META_O,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if is_output_fp8:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret

            if not fp8_dpa_bwd:
                # FP8 forward only: keep high-precision q/kv/out for the non-FP8 backward.
                if is_input_fp8:
                    q = cast_from_fp8(
                        q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[q.dtype],
                    ).view(q.shape)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv = cast_from_fp8(
                        kv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[kv.dtype],
                    ).view(kv.shape)
                if is_output_fp8:
                    out_save = cast_from_fp8(
                        out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                        fp8_meta["scaling_fwd"],
                        META_O,
                        fp8_dtype_forward,
                        qkv_dtype,
                    ).view(out_fp8.shape)

            # Snapshot the forward scales so backward sees the values used here.
            fp8_tensors = (
                q_fp8,
                kv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                kv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                None,  # d_scale_qkv
                0,  # d_scale_qkv_offset
                None,  # d_scale_s
                0,  # d_scale_s_offset
                None,  # q_scale_s
                0,  # q_scale_s_offset
                None,  # q_scale_o
                0,  # q_scale_o_offset
                None,  # amax_s
                0,  # amax_s_offset
                None,  # amax_o
                0,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None)

        ctx.fp8 = bool(fp8) and fp8_dpa_bwd
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
        # High-precision q/kv/out are only needed when backward does not run in FP8.
        qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for q, kv and (for learnable bias types) attn_bias.

        Uses ``flash_attn_cuda_bwd`` when ``ctx.use_FAv2_bwd`` is set. Otherwise it
        uses ``fused_attn_bwd_kvpacked``, in FP8 when ``ctx.fp8`` is set and in
        ``ctx.qkv_dtype`` otherwise.

        Returns:
            One entry per ``forward`` input (24 total): ``dq`` and ``dkv`` at the
            positions of ``q`` and ``kv``, and the bias gradient at the position of
            ``attn_bias`` (``None`` for "no_bias"/"alibi"). All other entries are ``None``.
        """
        # Capture the dtype before any unwrapping. It is then defined whether or not the
        # forward output was FP8, and is used for the FP8 dq/dkv wrappers below.
        d_out_nominal_dtype = d_out.dtype
        if ctx.is_output_fp8:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            q,
            kv,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            kv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        # rest[0] holds the bias gradient when the fused backward produces one.
        rest = [None]
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dkv[:, 0],
                dkv[:, 1],
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            dq = dq[..., : d_out.shape[-1]]
            dkv = dkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.is_output_fp8:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        kv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.is_input_fp8:
                        # FP8 inputs get FP8 gradients back.
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_nominal_dtype,
                        )
                        dkv = Float8Tensor(
                            data=dkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_nominal_dtype,
                        )
                    else:
                        dq = cast_from_fp8(
                            dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dq_fp8.shape)
                        dkv_c_fp8 = dkv_fp8.view(
                            -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                        )
                        dkv = cast_from_fp8(
                            dkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dkv_fp8.shape)
                else:
                    # FP8 output with a non-FP8 backward: dequantize the incoming gradient.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dkv, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        kv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # "no_bias" and "alibi" return no attn_bias gradient; otherwise rest[0] is dbias.
        d_bias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        # Exactly one entry per forward input (24): q -> 8th, kv -> 9th, attn_bias -> 11th.
        return (
            None,  # is_training
            None,  # max_seqlen_q
            None,  # max_seqlen_kv
            None,  # cu_seqlens_q
            None,  # cu_seqlens_kv
            None,  # cu_seqlens_q_padded
            None,  # cu_seqlens_kv_padded
            dq,  # q
            dkv,  # kv
            None,  # qkv_dtype
            d_bias,  # attn_bias
            None,  # attn_scale
            None,  # dropout_p
            None,  # fast_zero_fill
            None,  # qkv_layout
            None,  # attn_bias_type
            None,  # attn_mask_type
            None,  # window_size
            None,  # rng_gen
            None,  # fused_attention_backend
            None,  # use_FAv2_bwd
            None,  # fp8
            None,  # fp8_meta
            None,  # deterministic
        )

6491

6492
6493
6494
6495
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
6496
6497
6498
6499
6500
6501
6502
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
6503
6504
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
6505
6506
6507
6508
6509
6510
6511
6512
6513
6514
6515
        q,
        k,
        v,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
6516
        window_size,
6517
6518
6519
6520
6521
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
6522
        deterministic,
6523
    ):
6524
        # pylint: disable=missing-function-docstring
6525
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
6526
6527
        is_input_fp8 = False
        is_output_fp8 = fp8_meta["recipe"].fp8_mha
6528
6529
6530
        if fp8:
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
6531
6532
6533
6534
6535
            assert isinstance(k, q.__class__) and isinstance(
                v, q.__class__
            ), "q, k, and v must have the same type."
            is_input_fp8 = isinstance(q, Float8Tensor)
            if is_input_fp8:
6536
6537
6538
6539
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
6540
                qkv_group = len(qkv_layout.split("_"))
6541
                if qkv_group == 1:
6542
6543
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
6544
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
6545
6546
6547
6548
                    qkv_fp8 = cast_to_fp8(
                        qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(qkv.shape)
                    q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1])
6549
6550
                    q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
                if qkv_group == 2:
6551
6552
6553
6554
6555
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    dim = qkv_layout.split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
6556
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
6557
6558
6559
6560
                    kv_fp8 = cast_to_fp8(
                        kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(kv.shape)
                    k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1])
6561
6562
                    k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
                if qkv_group == 3:
6563
6564
6565
6566
6567
6568
6569
6570
6571
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    k_fp8 = cast_to_fp8(
                        k, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(k.shape)
                    v_fp8 = cast_to_fp8(
                        v, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(v.shape)
6572
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
6573
6574
6575
6576
6577
6578
6579
6580
6581
6582
6583
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
6584
6585
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
6586
6587
6588
6589
6590
6591
6592
6593
6594
6595
6596
6597
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
                META_QKV,  # d_scale_qkv_offset
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
                META_S,  # d_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_s
                META_S,  # q_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_o
                META_O,  # q_scale_o_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_s
                META_S,  # amax_s_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_o
                META_O,  # amax_o_offset
6598
6599
6600
6601
6602
6603
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
6604
                window_size,
6605
6606
                rng_gen,
            )
6607
            if is_output_fp8:
6608
6609
                out_ret = Float8Tensor(
                    data=out_fp8,
6610
6611
6612
6613
6614
6615
6616
6617
6618
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
6619
6620
6621
6622
6623
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
6624
6625
            out_save = out_ret

6626
            if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
6627
                # 1: qkv packed, 2: kv packed, 3: qkv separate
6628
6629
6630
6631
6632
6633
6634
6635
6636
6637
6638
6639
6640
6641
6642
6643
6644
6645
6646
6647
6648
6649
6650
6651
6652
6653
6654
6655
6656
6657
6658
6659
6660
6661
6662
6663
6664
6665
6666
6667
6668
6669
6670
6671
6672
6673
6674
6675
6676
6677
6678
6679
6680
6681
6682
6683
6684
6685
6686
6687
                if is_input_fp8:
                    qkv_group = len(qkv_layout.split("_"))
                    if qkv_group == 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                        qkv_no_fp8 = cast_from_fp8(
                            qkv_c._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[qkv.dtype],
                        ).view(qkv.shape)
                        q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1])
                        q, k, v = [x.squeeze(dim) for x in [q, k, v]]
                    if qkv_group == 2:
                        q = cast_from_fp8(
                            q._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[q.dtype],
                        ).view(q.shape)
                        dim = qkv_layout.split("_")[1].find("2")
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                        kv_no_fp8 = cast_from_fp8(
                            kv_c._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[kv.dtype],
                        ).view(kv.shape)
                        k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1])
                        k, v = [x.squeeze(dim) for x in [k, v]]
                    if qkv_group == 3:
                        q = cast_from_fp8(
                            q._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[q.dtype],
                        ).view(q.shape)
                        k = cast_from_fp8(
                            k._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[k.dtype],
                        ).view(k.shape)
                        v = cast_from_fp8(
                            v._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[v.dtype],
                        ).view(v.shape)
                if is_output_fp8:
                    out_save = cast_from_fp8(
                        out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
6688
                        fp8_meta["scaling_fwd"],
6689
                        META_O,
6690
                        fp8_dtype_forward,
6691
6692
                        qkv_dtype,
                    ).view(out_fp8.shape)
6693
6694
6695
6696
6697
6698

            fp8_tensors = (
                q_fp8,
                k_fp8,
                v_fp8,
                out_fp8,
6699
                fp8_meta["scaling_fwd"].scale.clone(),
6700
6701
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
6702
6703
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd(
6704
6705
6706
6707
6708
6709
6710
6711
6712
6713
6714
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
6715
6716
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
6717
6718
6719
6720
6721
6722
6723
6724
6725
6726
6727
6728
                None,  # d_scale_qkv
                0,  # d_scale_qkv_offset
                None,  # d_scale_s
                0,  # d_scale_s_offset
                None,  # q_scale_s
                0,  # q_scale_s_offset
                None,  # q_scale_o
                0,  # q_scale_o_offset
                None,  # amax_s
                0,  # amax_s_offset
                None,  # amax_o
                0,  # amax_o_offset
6729
6730
6731
6732
6733
6734
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
6735
                window_size,
6736
6737
                rng_gen,
            )
6738
6739
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None, None)
6740

6741
6742
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

6743
        from .cpu_offload import CPUOffloadEnabled
6744

6745
        if CPUOffloadEnabled:
6746
6747
6748
6749
6750
6751
6752
            if ctx.fp8:
                tensor_list = fp8_tensors
            else:
                tensor_list = [q, k, v, out_save]

            tensor_list.extend(aux_ctx_tensors)

6753
            qkv_layout = "sbhd_sbhd_sbhd"
6754
6755
6756
6757
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

6758
6759
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
6760
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
6761
6762
6763
6764
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
6765
6766
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
6767
6768
6769
            *fp8_tensors,
            *aux_ctx_tensors,
        )
6770
        ctx.fp8_meta = fp8_meta
6771
6772
6773
6774
6775
6776
6777
6778
6779
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
6780
        ctx.window_size = window_size
6781
        ctx.fused_attention_backend = (
6782
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
6783
        )
6784
        ctx.use_FAv2_bwd = use_FAv2_bwd
6785
        ctx.deterministic = deterministic
6786

6787
        return out_ret
6788
6789
6790

    @staticmethod
    def backward(ctx, d_out):
        """Backward pass of fused attention.

        Computes gradients with respect to q, k and v (and the attention bias,
        for bias types other than ``no_bias``/``alibi``) from the tensors saved
        by ``forward``. One of three paths is taken:

        * ``ctx.use_FAv2_bwd``: the ``flash_attn_cuda_bwd`` kernel.
        * ``ctx.fp8``: cuDNN FP8 backward via ``fused_attn_bwd``. The gradients
          are either wrapped as ``Float8Tensor`` (when the inputs were FP8) or
          cast back to ``ctx.qkv_dtype``.
        * otherwise: cuDNN FP16/BF16 backward via ``fused_attn_bwd``.

        Returns one gradient slot per positional ``forward`` input. Only the
        q/k/v slots (indices 7-9) and the bias slot (index 11) can be
        non-``None``.
        """
        # Record the nominal (high-precision) dtype of the incoming gradient
        # before any unwrapping. A Float8Tensor reports the dtype it was
        # constructed with here, so this equals ``d_out_f8tensor.dtype`` in the
        # FP8-output case. It is also well defined when only the *inputs* were
        # FP8. The FP8 dq/dk/dv wrappers below need it in both cases.
        d_out_dtype = d_out.dtype
        if ctx.is_output_fp8:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        # Layout mirrors ctx.save_for_backward in forward:
        # q/k/v/out (None in FP8 mode), cu_seqlens, padded cu_seqlens,
        # q/k/v/out FP8 copies plus forward scale/scale_inv (None outside FP8 mode),
        # then the auxiliary tensors returned by the forward kernel.
        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        # fused_attn_bwd returns (dq, dk, dv, *rest), and rest[0] fills the bias
        # gradient slot. Paths that do not produce it leave it as None.
        rest = [None]
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Keep only the first d_out.shape[-1] entries of the last dim
            # (presumably dropping head-dim padding added for the kernel).
            dq = dq[..., : d_out.shape[-1]]
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.is_output_fp8:
                        # Gradient already arrived in FP8; register its scale_inv
                        # so the kernel dequantizes it correctly.
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

                    if ctx.is_input_fp8:
                        # FP8 inputs get FP8 gradients. Use the nominal dtype
                        # captured at entry: d_out_f8tensor only exists when
                        # the output was FP8 too.
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_dtype,
                        )
                        dk = Float8Tensor(
                            data=dk_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_dtype,
                        )
                        dv = Float8Tensor(
                            data=dv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_dtype,
                        )
                    else:
                        # Cast the FP8 gradients back to ctx.qkv_dtype. The
                        # number of "_"-separated groups in qkv_layout tells
                        # whether q/k/v were packed (1), kv-packed (2) or
                        # separate (3).
                        qkv_group = len(ctx.qkv_layout.split("_"))
                        if qkv_group == 1:
                            dim = ctx.qkv_layout.find("3")
                            dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim)
                            dqkv_c_fp8 = dqkv_fp8.view(
                                -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                            )
                            dqkv = cast_from_fp8(
                                dqkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dqkv_fp8.shape)
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1])
                            dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
                        if qkv_group == 2:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dq_fp8.shape)
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(
                                -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                            )
                            dkv = cast_from_fp8(
                                dkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dkv_fp8.shape)
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1])
                            dk, dv = [x.squeeze(dim) for x in [dk, dv]]
                        if qkv_group == 3:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dq_fp8.shape)
                            dk = cast_from_fp8(
                                dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dk_fp8.shape)
                            dv = cast_from_fp8(
                                dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dv_fp8.shape)
                else:
                    # d_out is uint8 only when it was unwrapped from a
                    # Float8Tensor above while FP8 backward is disabled.
                    # Dequantize it for the FP16/BF16 kernel.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # For no_bias / alibi there is no bias gradient to return. Otherwise
        # return the dbias produced by the kernel.
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        # Slots follow the positional inputs of forward: 7 leading non-tensor
        # args, then q/k/v, then qkv_dtype (None) and attn_bias (dbias),
        # followed by 15 trailing None slots (same layout as before).
        return (None,) * 7 + (dq, dk, dv, None, dbias) + (None,) * 15
7089

7090

7091
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |

    When ``forward`` is called with ``fp8=True``, the ``NVTE_FP8`` cuDNN
    sub-backend is required instead.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # Opt-in (NVTE_FUSED_ATTN_USE_FAv2_BWD=1) FlashAttention-2 backward,
        # only honoured on devices with compute capability 9.0.
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading Transformer Engine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in Transformer Engine 2.0.
            """
            # Iterate over snapshots: calling list.remove() on the list being
            # iterated skips the element right after each removed one, which
            # left consecutive matching keys (e.g. several layers) in place.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Fused attention fprop.

        Parameters
        ----------
        query_layer, key_layer, value_layer : torch.Tensor
            CUDA tensors in FP16/BF16, or ``Float8Tensor``.
        qkv_layout : str
            One of ``QKVLayouts``. Its first ``_``-separated group, with
            non-letters removed, gives the tensor format (e.g. ``sbhd``,
            ``bshd``, ``thd``).
        cu_seqlens_q, cu_seqlens_kv : Optional[torch.Tensor]
            Cumulative sequence lengths. For ``sbhd``/``bshd`` they are
            derived from ``attention_mask`` (padding masks) or built for
            full-length sequences when missing. Required for ``thd``.
        cu_seqlens_q_padded, cu_seqlens_kv_padded : Optional[torch.Tensor]
            Padded cumulative sequence lengths. For ``thd``, each one that is
            omitted falls back to its unpadded counterpart.
        max_seqlen_q, max_seqlen_kv : Optional[int]
            Required for ``thd``. For ``sbhd``/``bshd`` they are recomputed
            from the tensor shapes and multiplied by the context-parallel size.
        attn_mask_type : str
            Mask type. Any value containing ``"padding"`` enables the
            cu_seqlens derivation from ``attention_mask``.
        attention_mask : optional
            Only used to derive cu_seqlens for padding masks: a single tensor
            for self-attention, a ``(q_mask, kv_mask)`` pair otherwise.
        window_size : Optional[Tuple[int, int]]
            Forwarded to the attention implementation.
        fused_attention_backend : tex.NVTE_Fused_Attn_Backend
            cuDNN sub-backend. Must not be ``NVTE_No_Backend``, and must be
            ``NVTE_FP8`` when ``fp8`` is True.
        core_attention_bias_type, core_attention_bias
            Bias type and bias tensor. ``alibi`` is rejected under context
            parallelism.
        fast_zero_fill : bool
            Forwarded to ``FusedAttnFunc``.
        cp_group, cp_global_ranks, cp_stream, cp_comm_type
            Context-parallel configuration. Context parallelism is active when
            the product of the group world sizes exceeds 1.
        fp8 : bool
            Use the FP8 path. Requires ``fp8_meta``.
        fp8_meta : Optional[Dict[str, Any]]
            FP8 scaling metadata.

        Returns
        -------
        torch.Tensor
            The attention output with its last two dimensions merged.

        Raises
        ------
        RuntimeError
            If a padding mask is requested, a cu_seqlens tensor is missing and
            no ``attention_mask`` is given.

        Unsupported input combinations are rejected with ``assert``.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        # Total context-parallel size: a single group, or the product over a
        # list of groups.
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format in ["sbhd", "bshd"]:
            if qkv_format == "sbhd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1],
                    query_layer.shape[0],
                    key_layer.shape[0],
                )
            if qkv_format == "bshd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0],
                    query_layer.shape[1],
                    key_layer.shape[1],
                )
            # Local shards hold 1/cp_size of the sequence; scale back to the
            # global length.
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size
            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )

        if qkv_format == "thd":
            assert (
                max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
            ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
            # Fall back to the unpadded offsets for whichever padded tensor is
            # missing. Filling each independently keeps a caller-provided one
            # instead of silently discarding it.
            if cu_seqlens_q_padded is None:
                cu_seqlens_q_padded = cu_seqlens_q
            if cu_seqlens_kv_padded is None:
                cu_seqlens_kv_padded = cu_seqlens_kv

        qkv_dtype = TE_DType[query_layer.dtype]

        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if fp8:
            assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                " is required for FP8 attention!"
            )
            assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
            assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
                "Amax reduction across TP+CP group is necessary when using context parallelism with"
                " FP8!"
            )

        if context_parallel:
            assert (
                fp8
                or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    deterministic=self.deterministic,
                    use_fused_attention=True,
                    window_size=window_size,
                    fp8=fp8,
                    fp8_meta=fp8_meta,
                )
        else:
            with self.attention_dropout_ctx():
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_dtype,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                    self.deterministic,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
7357
7358


7359
class DotProductAttention(TransformerEngineBaseModule):
7360
7361
7362
7363
7364
7365
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

7366
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
7368
7369
7370

    .. warning::

7371
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
7375

7376
7377
7378
7379
7380
7381
7382
    .. note::

        Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing.
        As the FP8 attention support expands from one backend to multiple backends, the location
        of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).


7383
7384
7385
7386
    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
7387
7388
7389
    kv_channels : Union[int, Tuple[int, int]]
                the head size in key and value tensors. If the same, :attr:`kv_channels` can be
                an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
7390
7391
7392
7393
7394
7395
7396
7397
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
7398
7399
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
7400
    attn_mask_type: str, default = `causal`
7401
                   type of attention mask passed into softmax operation, options are "`no_mask`",
7402
7403
7404
7405
7406
7407
7408
7409
7410
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
                   "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
                   "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
                   are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
                   `forward` method. It is useful for cases involving compilation/tracing, e.g.
                   ONNX export, and the forward arg is useful for dynamically changing mask types,
                   e.g. a different mask for training and inference.
                   1. For "`no_mask`", no attention mask is applied.
                   2. For "`causal`", "`causal_bottom_right`", or the causal mask in
7411
                   "`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
7412
7413
7414
7415
7416
7417
7418
7419
7420
7421
7422
7423
7424
7425
                   calculates and applies an upper triangular mask to the softmax input.
                   No user input is needed. Causal masks without the "`bottom_right`" appendix align
                   the diagonal line to the top left corner of the softmax matrix. With
                   "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
                   often used in inference/KV caching.
                   3. For "`padding`", or the padding mask in "`padding_causal`" and
                   "`padding_causal_bottom_right`", users need to provide the locations of padded
                   tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
                   [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
                   in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
                   cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
                   [batch_size, 1, 1, max_seqlen_kv]).
                   4. For "`arbitrary`", users need to provide a mask that is broadcastable to
                   the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
7426
7427
7428
7429
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
7430
7431
7432
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
7433
                be overridden by :attr:`window_size` in `forward` as well.
7434
7435
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
7436
7437
7438
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
7439
7440
7441
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
7442
               `h` the number of heads, `d` head size, and `t` the total number of tokens
7443
7444
7445
7446
7447
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
7448
               For that, please use `get_qkv_layout` to gain the layout information.
7449
7450
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
7451
                `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
7452
7453
7454
7455
7456
7457
7458
7459
7460

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
7461
    cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None`
7462
              context parallel process group.
7463
7464
7465
              ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
              List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
              and cp_group[1] are for a2a and p2p communications respectively.
7466
7467
7468
7469
7470
7471
7472
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
7473
    cp_comm_type : str, default = `p2p`
7474
                  inter-gpu communication type for context parallelism.
7475
                  Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
7476
7477
7478
7479
7480
7481
                  "p2p": Exchange KV chunks with P2P communications in ring topology.
                         P2P is async and can be overlapped with attention compute.
                  "all_gather": All-gather to get full sequence of KV before attention.
                                The all-gather is not async, and cannot be overlapped.
                  "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                         group, and gather to get full sequence of QKV.
7482
7483
7484
                  "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                  across each CP sub-group (e.g., via NVLink), then exchanging KV with
                  p2p between sub-groups (e.g., via IBLink).
7485
7486
7487
7488
7489
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: Union[int, Tuple[int, int]],
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
    ) -> None:
        # Record the module configuration and build the three attention backends
        # (FlashAttention, FusedAttention, UnfusedDotProductAttention). All arguments
        # are documented in the class docstring.
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.logger.setLevel(_log_level)
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)
        self.qkv_format = qkv_format

        # Normalize the accepted mask spellings: "padding,causal" -> "padding_causal",
        # and "causal,padding" (-> "causal_padding") is mapped onto "padding_causal" too.
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)

        # Without an explicit group, trust the user-provided tp_size; only register the
        # (None) group when running without tensor parallelism.
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number

        # Context-parallel settings; can be changed later via set_context_parallel_group().
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
        self.cp_comm_type = cp_comm_type

        # kv_channels is either a shared head size or a (key, value) pair of head sizes.
        self.hidden_size_per_attention_head_k = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[0]
        )
        self.hidden_size_per_attention_head_v = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[1]
        )

        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"

        # Dropout runs under the RNG tracker's fork context when a tracker is provided
        # (and sequence parallelism is off); otherwise it uses a no-op context.
        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        # Default scale is 1/sqrt(key head size).
        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(
                kv_channels if isinstance(kv_channels, int) else kv_channels[0]
            )

        # Deterministic mode is requested either through the environment variable or
        # through PyTorch's global deterministic-algorithms switch.
        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )

        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = get_cudnn_version()
        if (8, 9, 5) <= cudnn_version < (9, 0, 0):
            if self.deterministic:
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        self.flash_attention = FlashAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        self.fused_attention = FusedAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale,
            attention_type=attention_type,
            **attn_kwargs,
            layer_number=layer_number,
        )

        def remove_extra_states_check(module, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
            when loading older Transformer Engine checkpoints. Will phase out
            this hook in Transformer Engine 2.0.
            """
            # Filter in place via slice assignment. Calling ``remove()`` on the list
            # while iterating it skips the element after every removed key, which can
            # leave matching keys behind (e.g. adjacent layers' entries).
            incompatible_keys.missing_keys[:] = [
                key
                for key in incompatible_keys.missing_keys
                if "core_attention._extra_state" not in key
            ]

        self.register_load_state_dict_post_hook(remove_extra_states_check)

7633
7634
7635
7636
7637
7638
7639
7640
7641
7642
7643
7644
7645
7646
7647
7648
7649
7650
7651
7652
7653
7654
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load this module's state, redirecting to the layout used by Transformer Engine 1.6
        and 1.7 checkpoints when needed. Those versions stored the FP8 attention metadata
        under the `core_attention.fused_attention._extra_state` key rather than the
        `core_attention._extra_state` key. Please see `FP8 checkpoint compatibility
        <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_ for more details.
        """
        legacy_marker = "core_attention.fused_attention._extra_state"
        current_marker = "core_attention._extra_state"
        has_legacy_entry = any(legacy_marker in name for name in state_dict.keys())
        has_current_entry = any(current_marker in name for name in state_dict.keys())

        # Only an old-style checkpoint (legacy entry present, current entry absent)
        # is read from the ``fused_attention.`` sub-prefix.
        use_legacy_layout = has_legacy_entry and not has_current_entry
        load_prefix = prefix + "fused_attention." if use_legacy_layout else prefix

        super()._load_from_state_dict(
            state_dict,
            load_prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

7655
7656
7657
7658
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Forward method with activation checkpointing.

        Calls ``attention_func(*forward_args, **forward_kwargs)`` through ``checkpoint``,
        using this module's RNG state tracker getter and tensor-parallel group, with
        ``distribute_saved_activations=False``.
        """

        # ``checkpoint`` receives this thin closure rather than ``attention_func`` itself,
        # matching the original call shape.
        def _run_attention(*call_args, **call_kwargs):
            return attention_func(*call_args, **call_kwargs)

        return checkpoint(
            _run_attention,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

7677
7678
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  context parallel process group.
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-gpu communication type for context parallelism.
                      Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
        """
        # Targets are assigned left to right, in the same order as the
        # individual attribute assignments this replaces.
        (
            self.cp_group,
            self.cp_global_ranks,
            self.cp_stream,
            self.cp_comm_type,
        ) = (cp_group, cp_global_ranks, cp_stream, cp_comm_type)
7716

7717
    @no_torch_dynamo(recursive=False)
7718
7719
7720
7721
7722
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
7723
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
7724
7725
7726
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
7727
7728
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
7729
7730
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
7731
        attn_mask_type: Optional[str] = None,
7732
        window_size: Optional[Tuple[int, int]] = None,
7733
        checkpoint_core_attention: bool = False,
7734
7735
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
7736
        alibi_slopes: Optional[torch.Tensor] = None,
7737
        fast_zero_fill: bool = True,
7738
        inference_params: Optional[InferenceParams] = None,
7739
        is_first_microbatch: Optional[bool] = None,
7740
7741
7742
7743
7744
7745
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

7746
7747
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes '"padding"' or `"arbitrary"`.
7748

7749
7750
        .. note::

7751
7752
7753
7754
7755
7756
7757
7758
7759
7760
7761
7762
7763
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://github.com/Dao-AILab/flash-attention>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
7764
            and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
7765
7766
7767
7768
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
7769
7770
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
7771
            optimizations in FusedAttention. When unset, Transformer Engine determines the code path
7772
7773
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
7774

7775
7776
7777
7778
7779
7780
7781
7782
7783
7784
7785
7786
7787
7788
7789
7790
7791
7792
7793
7794
7795
7796
7797
7798
7799
7800
7801
7802
7803
7804
7805
7806
7807
7808
7809
7810
7811
7812
7813
7814
7815
7816
7817
7818
7819
7820
7821
7822
7823
7824
7825
7826
7827
7828
        .. note::
            .. _cu_seqlens note:

            When training data has variable sequence lengths, users have two options.

            1. Manipulate the data and pad all sequences to the same length. Use
               :attr:`qkv_format` = {"bshd", "sbhd"} and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`
               (which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide
               the real sequence length information. For example, a batch of 3 sequences
               [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative
               sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

            2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`,
               as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed
               without any padding, and the sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

               In certain use cases, a varying number of identifier tokens are inserted between
               sequences. These tokens do not participate in the attention calculation.
               :attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified
               in such cases to correctly identify the start and end of each sequence in a batch.
               For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and
               :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13]
               for self-attention.

        .. note::
            .. _max_seqlen note:

            When :attr:`qkv_format` = {"bshd", "sbhd"}, sequences are of equal length in a batch.
            :attr:`max_seqlen_q` and :attr:`max_seqlen_kv` should be the same as the "s" dimension of
            :attr:`query_layer` and :attr:`key_layer` tensors. When unset, Transformer Engine will
            infer them as such.

            When :attr:`qkv_format` = "thd", sequences have varying lengths. :attr:`max_seqlen_q` and
            :attr:`max_seqlen_kv` should be the maximum query and key/value sequence length in a batch.
            When unset, Transformer Engine deduces them from :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`.
            This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this
            overhead, users are recommended to obtain the maximum sequence lengths from the data loaders
            and pass them in.

            - As the maximum sequence lengths, batch size, and number of tokens change from batch to batch,
              dynamic shapes need to be supported for tensor construction. FlashAttention and
              UnfusedDotProductAttention naturally do so, while FusedAttention requires parameters to be static
              to create graphs before performance heuristics analysis. To reduce the number of graphs created
              per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size,
              :attr:`max_seqlen_q`, :attr:`max_seqlen_kv`}, and for cuDNN >= 9.6, {"t" dimension of
              :attr:`query_layer`, "t" dimension of :attr:`key_layer`}.

7829
7830
7831
7832
7833
7834
7835
7836
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
7837
7838
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
7839
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
7840
7841
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
7842
7843
7844
7845
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
7846
7847
7848
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
7849
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
7850
                   with shape [batch_size + 1] and dtype torch.int32.
7851
                   See :ref:`note<cu_seqlens note>` for more details.
7852
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
7853
7854
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
7855
                   See :ref:`note<cu_seqlens note>` for more details.
7856
7857
7858
7859
7860
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for
                   `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_q_padded = cu_seqlens_q`.
7861
                   See :ref:`note<cu_seqlens note>` for more details.
7862
7863
7864
7865
7866
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_kv_padded = cu_seqlens_kv`.
7867
                   See :ref:`note<cu_seqlens note>` for more details.
7868
7869
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
7870
                      See :ref:`note<max_seqlen note>` for more details.
7871
7872
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
7873
                       See :ref:`note<max_seqlen note>` for more details.
7874
7875
7876
7877
7878
7879
7880
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
                       'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
                       'arbitrary'}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
                       are equivalent. By default, causal masks are aligned to the top left corner
                       of the softmax matrix. When "`bottom_right`" is specified in the mask type,
                       causal masks are aligned to the bottom right corner.
7881
        window_size: Optional[Tuple[int, int]], default = `None`
7882
                    Sliding window size for local attention.
7883
7884
7885
7886
7887
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
7888
        core_attention_bias_type: str, default = `no_bias`
7889
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
7890
        core_attention_bias: Optional[torch.Tensor], default = `None`
7891
7892
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
7893
7894
7895
7896
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
7897
        fast_zero_fill: bool, default = `True`
7898
                    Whether to use the fast path to set output tensors to 0 or not.
7899
7900
7901
7902
7903
7904
7905
7906
7907
7908
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
7909
7910
7911
7912
7913
7914
7915
7916
7917
7918
7919
7920
7921
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
7922
        """
7923
7924
7925
7926
7927
7928
7929
7930
7931
7932
7933
        with self.prepare_forward(
            query_layer,
            is_first_microbatch,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:

            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
7934
                        self.logger.warning(
7935
7936
7937
                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                            """fp8_meta["recipe"].fp8_mha=True"""
                        )
7938
7939
7940
7941
7942
7943
7944
7945
7946
7947
7948

            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
7949

7950
7951
7952
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
7953
7954
7955
            assert (
                query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
7956
7957
7958
            assert (
                key_layer.shape[:-1] == value_layer.shape[:-1]
            ), "Keys and values must have the same batch size, sequence length and number of heads!"
7959
7960
7961
7962
7963
7964
7965
7966
            assert (
                key_layer.shape[-1] == self.hidden_size_per_attention_head_k
            ), f"Keys have head_dim = {key_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_k}!"
            assert (
                value_layer.shape[-1] == self.hidden_size_per_attention_head_v
            ), f"Values have head_dim = {value_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_v}!"
7967

7968
7969
7970
            if qkv_format is None:
                qkv_format = self.qkv_format

7971
7972
7973
7974
7975
7976
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
7977
            assert (
7978
7979
7980
7981
7982
7983
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
7984

7985
7986
7987
7988
            if window_size is None:
                window_size = self.window_size
            window_size = check_set_window_size(attn_mask_type, window_size)

7989
7990
7991
7992
7993
7994
7995
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
7996

7997
7998
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
7999

8000
8001
8002
8003
8004
                # convert causal to causal_bottom_right in inference when KV-caching is in use
                # so users can run with the same attn_mask_type for training and inference
                if attn_mask_type in ["causal", "padding_causal"]:
                    attn_mask_type = attn_mask_type + "_bottom_right"

8005
8006
8007
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
8008

8009
8010
8011
8012
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]
8013

8014
8015
8016
                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)
8017

8018
8019
8020
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)
8021

8022
8023
8024
8025
8026
8027
8028
8029
8030
                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
8031

8032
8033
8034
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
8035

8036
8037
                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()
8038
8039

            assert (
8040
8041
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
8042
8043
8044
8045
            ), (
                "Keys and values must have num_gqa_group ="
                f" {self.num_gqa_groups_per_partition} heads!"
            )
8046
8047
8048
8049
8050
8051
8052
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
8053
                assert all(
8054
8055
8056
8057
8058
8059
8060
8061
8062
8063
8064
8065
8066
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
8067
                batch_size = len(cu_seqlens_q) - 1
8068
                if max_seqlen_q is None:
8069
8070
8071
8072
                    if cu_seqlens_q_padded is not None:
                        seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
                    else:
                        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
8073
                    max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
8074
                if max_seqlen_kv is None:
8075
8076
8077
8078
                    if cu_seqlens_kv_padded is not None:
                        seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
                    else:
                        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
8079
                    max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
8080

8081
8082
8083
8084
8085
8086
            cp_size = 1
            if isinstance(self.cp_group, dist_group_type):
                cp_size = get_distributed_world_size(self.cp_group)
            elif isinstance(self.cp_group, list):
                for group in self.cp_group:
                    cp_size *= get_distributed_world_size(group)
8087
8088
            context_parallel = cp_size > 1

8089
            if qkv_format in ["sbhd", "bshd"]:
8090
                assert all(
8091
8092
8093
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
8094
8095
                    max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv
8096
                    batch_size = query_layer.shape[1]
8097
                else:
8098
8099
                    max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv
8100
                    batch_size = query_layer.shape[0]
8101
8102
                max_seqlen_q *= cp_size
                max_seqlen_kv *= cp_size
8103
8104
8105
8106
8107
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
8108
                        the sequence dimension in 'query_layer'!"""
8109
8110
8111
8112
8113
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
8114
                        the sequence dimension in 'key_layer' and 'value_layer'!"""
8115
8116
8117
8118
8119
                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
8120
                        if self.attention_type == "self":
8121
8122
8123
8124
8125
8126
8127
8128
8129
8130
8131
8132
8133
8134
8135
8136
                            cu_seqlens_q = get_cu_seqlens(attention_mask)
                            cu_seqlens_kv = cu_seqlens_q
                        else:
                            cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                            cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
                    else:
                        cu_seqlens_q = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
                        cu_seqlens_kv = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
8137

8138
8139
8140
8141
8142
            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
8143
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
8144
8145
8146
                    query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                )
            else:
8147
                qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
8148
8149
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )
8150

8151
8152
8153
8154
8155
8156
8157
8158
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
8159
            bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
8160
8161
8162
8163
8164
8165
8166
8167
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
8168
                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
8169
8170
8171
8172
8173
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

8174
8175
            core_attention_bias_shape = None
            if core_attention_bias is not None:
8176
                if (
8177
8178
                    core_attention_bias.shape[0] == batch_size
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
8179
                ):
8180
8181
8182
8183
8184
8185
8186
8187
8188
8189
8190
8191
8192
8193
8194
8195
8196
8197
8198
                    core_attention_bias_shape = "bhss"
                elif (
                    core_attention_bias.shape[0] == 1
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
                ):
                    core_attention_bias_shape = "1hss"
                elif (
                    core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
                ):
                    core_attention_bias_shape = "b1ss"
                elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
                    core_attention_bias_shape = "11ss"
                else:
                    assert (
                        False
                    ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

            pad_between_seqs = (
                cu_seqlens_q_padded is not None
8199
                and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
8200
8201
            ) or (
                cu_seqlens_kv_padded is not None
8202
                and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
8203
            )
8204

8205
            attention_params = AttentionParams(
8206
8207
8208
8209
8210
8211
8212
8213
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
                num_heads=query_layer.shape[-2],
                num_gqa_groups=key_layer.shape[-2],
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
8214
8215
                head_dim_qk=query_layer.shape[-1],
                head_dim_v=value_layer.shape[-1],
8216
8217
8218
8219
8220
8221
8222
8223
8224
8225
8226
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias_shape=core_attention_bias_shape,
                core_attention_bias_requires_grad=(
                    core_attention_bias.requires_grad if core_attention_bias is not None else False
                ),
                pad_between_seqs=pad_between_seqs,
                attention_dropout=self.attention_dropout,
                context_parallel=context_parallel,
8227
8228
                deterministic=self.deterministic,
                is_training=self.training,
8229
8230
8231
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
            )
8232
            global _attention_backends, _use_flash_attn_3
8233
8234
8235
8236
8237
8238
8239
            if (
                _attention_backends["attention_params"] is None
                or attention_params != _attention_backends["attention_params"]
            ):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
8240
                _use_flash_attn_3 = _flash_attn_3_is_installed
8241
8242
8243
8244
8245
8246
8247
8248
                (
                    use_flash_attention,
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
                    _,
                ) = get_attention_backend(attention_params)
                if use_flash_attention:
8249
8250
                    self.logger.info(
                        "Running with FlashAttention backend (version %s)",
8251
                        _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version,
8252
                    )
8253
8254
8255
8256
                elif use_fused_attention:
                    self.logger.info(
                        "Running with FusedAttention backend (sub-backend %s)",
                        int(fused_attention_backend),
8257
                    )
8258
8259
8260
8261
8262
8263
8264
                elif use_unfused_attention:
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]
8265

8266
8267
8268
8269
8270
8271
8272
8273
8274
8275
8276
8277
8278
8279
8280
8281
8282
8283
8284
8285
8286
8287
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
8288
                    cp_comm_type=self.cp_comm_type,
8289
8290
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
8291
8292
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
8293
                )
8294

8295
            if use_fused_attention:
8296
8297
                fu_core_attention_bias_type = core_attention_bias_type
                fu_core_attention_bias = core_attention_bias
8298
8299
8300
                if core_attention_bias_type == "alibi" and (
                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
                ):
8301
8302
8303
8304
8305
8306
8307
                    fu_core_attention_bias_type = "post_scale_bias"
                    _, fu_core_attention_bias = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                        bias_dtype=query_layer.dtype,
8308
                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
8309
                    )
8310
8311
8312
8313
8314
8315
8316
8317
8318
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
8319
8320
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
8321
8322
8323
8324
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
8325
                        window_size=window_size,
8326
8327
8328
8329
8330
8331
8332
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
8333
                        cp_comm_type=self.cp_comm_type,
8334
8335
8336
8337
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
                    )
                return self.fused_attention(
8338
8339
8340
8341
8342
8343
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
8344
8345
                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
8346
8347
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
8348
8349
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
8350
                    window_size=window_size,
8351
                    fused_attention_backend=fused_attention_backend,
8352
8353
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
8354
8355
8356
8357
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
8358
                    cp_comm_type=self.cp_comm_type,
8359
8360
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
8361
                )
8362

8363
            from .cpu_offload import CPUOffloadEnabled
8364

8365
8366
8367
8368
8369
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
8370

8371
8372
8373
8374
8375
8376
8377
8378
8379
8380
8381
8382
            if use_unfused_attention:
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
8383
                        window_size=window_size,
8384
8385
8386
8387
8388
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
8389
8390
8391
                    query_layer,
                    key_layer,
                    value_layer,
8392
8393
8394
8395
8396
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
8397
                    window_size=window_size,
8398
8399
8400
8401
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )
8402

8403
            raise ValueError("No dot product attention support for the provided inputs!")
8404
8405


8406
8407
8408
8409
8410
8411
8412
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

8413
8414
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
8415

8416
8417
8418
8419
8420
8421
8422
8423
8424
8425
8426
8427
8428
8429
8430
8431
8432
8433
8434
8435
8436
8437
8438
8439
8440
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing the QKV projection weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing the output projection (PROJ) weights in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
8441
8442
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                   'padding_causal_bottom_right','arbitrary'},
8443
                   default = `causal`
8444
8445
8446
8447
8448
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
8449
8450
8451
8452
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
8453
8454
8455
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
8456
                be overridden by :attr:`window_size` in `forward` as well.
8457
8458
8459
8460
8461
8462
8463
8464
8465
8466
8467
8468
8469
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
8470
8471
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
8472
8473
8474
8475
8476
8477
8478
8479
8480
8481
8482
8483
8484
8485
8486
8487
8488
8489
8490
8491
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
8492
          The device on which the parameters of the model will be allocated. It is the user's
8493
8494
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
8495
8496
8497
8498
8499
8500
8501
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
8502
            For that, please use `get_qkv_layout` to gain the layout information.
8503
8504
8505
8506
8507
8508
8509
8510
8511
8512
8513
8514
8515
8516
8517
8518
8519
8520
8521
8522
8523
8524
8525
8526
8527
8528
8529
8530
8531
8532
8533
8534
8535
8536
8537
8538
8539
8540
8541
8542

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, the QKV projection layers are used as Column Parallel
                      whereas PROJ is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, this module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
8543
8544
8545
8546
8547
8548
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """Create the Q/K/V projection(s), the core attention op and the output projection.

        Every argument is described in the class docstring. Submodules are created
        in a fixed order (QKV projection(s), then `core_attention`, then `proj`).
        """
        super().__init__()

        # ---- Plain configuration -------------------------------------------------
        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        # Window size is resolved against the mask type via `check_set_window_size`.
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias
        # Context parallelism stays disabled until `set_context_parallel_group` runs.
        self.cp_size, self.cp_rank = 1, 0

        # A falsy `kv_channels` (None or 0) falls back to an even split of hidden_size.
        if not kv_channels:
            kv_channels = hidden_size // num_attention_heads

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Interleaving is only meaningful for a single fused QKV weight.
        self.qkv_weight_interleaved = qkv_weight_interleaved if fuse_qkv_params else False

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        assert layer_number is None or layer_number > 0, "layer_number must be a positive integer"

        # ---- Tensor-parallel partitioning ------------------------------------------
        if tp_group is not None:
            tp_size = get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
        if num_gqa_groups is None:
            num_gqa_groups = num_attention_heads
        self.num_gqa_groups = num_gqa_groups
        assert (
            num_attention_heads % num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(num_gqa_groups // tp_size)

        head_dim = kv_channels
        self.hidden_size_per_attention_head = head_dim
        self.hidden_size_q = head_dim * num_attention_heads
        self.hidden_size_kv = head_dim * self.num_gqa_groups

        # ---- Shared keyword arguments for the GEMM submodules ----------------------
        common_gemm_kwargs = dict(
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            tp_group=tp_group,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            params_dtype=self.params_dtype,
            device=device,
        )
        # Used by every Q/K/V projection, with or without a fused normalization.
        qkv_proj_kwargs = dict(
            init_method=init_method,
            bias=bias,
            return_bias=False,
            parallel_mode="column" if set_parallel_mode else None,
            **common_gemm_kwargs,
        )
        # Only used by the normalization-fused (LayerNormLinear) projections.
        layernorm_kwargs = dict(
            eps=layernorm_epsilon,
            return_layernorm_output=return_layernorm_output,
            zero_centered_gamma=zero_centered_gamma,
            ub_bulk_wgrad=ub_bulk_wgrad,
            ub_bulk_dgrad=ub_bulk_dgrad,
            ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
            ub_overlap_ag=ub_overlap_ag,
            normalization=normalization,
            ub_name="qkv",
        )

        # ---- Q/K/V projection(s) ---------------------------------------------------
        if self.attention_type == "self":
            if fuse_qkv_params:
                qkv_split = None
            else:
                qkv_split = collections.OrderedDict(
                    (
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    )
                )
            qkv_out_features = self.hidden_size_q + 2 * self.hidden_size_kv
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    qkv_out_features,
                    parameters_split=qkv_split,
                    **qkv_proj_kwargs,
                    **layernorm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    qkv_out_features,
                    parameters_split=qkv_split,
                    **qkv_proj_kwargs,
                )
        elif self.attention_type == "cross":
            # Queries come from `hidden_states`; keys/values get their own projection.
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q,
                    parameters_split=None if fuse_qkv_params else ("query",),
                    **qkv_proj_kwargs,
                    **layernorm_kwargs,
                )
            else:
                self.query_layer = Linear(hidden_size, self.hidden_size_q, **qkv_proj_kwargs)
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                parameters_split=None if fuse_qkv_params else ("key", "value"),
                **qkv_proj_kwargs,
            )

        # ---- Core attention --------------------------------------------------------
        self.core_attention = DotProductAttention(
            num_attention_heads,
            self.hidden_size_per_attention_head,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # ---- Output projection (row-parallel when set_parallel_mode) ---------------
        self.proj = Linear(
            self.hidden_size_q,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            ub_name="proj",
            **common_gemm_kwargs,
        )

    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return an uninitialized KV-cache buffer on the current CUDA device.

        The buffer shape is
        ``[inference_max_sequence_len, batch_size, num_gqa_groups_per_partition,
        hidden_size_per_attention_head]``; its contents are whatever `torch.empty`
        leaves in memory, so callers must write before reading.
        """
        cache_shape = (
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
        )
        return torch.empty(cache_shape, dtype=dtype, device=torch.cuda.current_device())

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        The group is stored on this module and also forwarded to every
        submodule that exposes `set_tensor_parallel_group`, in the same way
        `set_context_parallel_group` propagates its arguments. Without the
        propagation, submodules built with `tp_group=None` would never see the
        group supplied here.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group

        # `self.modules()` yields `self` first; skip it to avoid infinite recursion.
        for child in self.modules():
            if child is not self and hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

8781
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup], None]
                  context parallel process group.
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
                  `None` resets this module's `cp_size`/`cp_rank` to 1/0 (the
                  constructor defaults); it is still forwarded to submodules as-is.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-gpu communication type for context parallelism.
                      Can be "p2p", "all_gather", "a2a", or "a2a+p2p".
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via InfiniBand).
        """
        if cp_group is None:
            # Restore the constructor defaults so that values from an earlier call
            # do not linger after context parallelism is switched off.
            self.cp_size = 1
            self.cp_rank = 0
        elif isinstance(cp_group, dist_group_type):
            self.cp_size = get_distributed_world_size(cp_group)
            self.cp_rank = get_distributed_rank(cp_group)
        elif isinstance(cp_group, list):
            assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
            assert (
                cp_comm_type == "a2a+p2p"
            ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
            cp_group_a2a, cp_group_p2p = cp_group
            cp_size_a2a = get_distributed_world_size(cp_group_a2a)
            cp_rank_a2a = get_distributed_rank(cp_group_a2a)
            cp_size_p2p = get_distributed_world_size(cp_group_p2p)
            cp_rank_p2p = get_distributed_rank(cp_group_p2p)
            self.cp_size = cp_size_a2a * cp_size_p2p
            # The a2a rank varies fastest within the combined CP rank.
            self.cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a

        # Deep iterate but skip self (always yielded first by `modules()`)
        # to avoid infinite recursion.
        for child in self.modules():
            if child is self:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)
8837

8838
8839
8840
    def forward(
        self,
        hidden_states: torch.Tensor,
8841
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
8842
        encoder_output: Optional[torch.Tensor] = None,
8843
        attn_mask_type: Optional[str] = None,
8844
        window_size: Optional[Tuple[int, int]] = None,
8845
8846
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
8847
        inference_params: Optional[InferenceParams] = None,
8848
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
8849
8850
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
8851
        alibi_slopes: Optional[torch.Tensor] = None,
8852
8853
8854
8855
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
8856
        fast_zero_fill: bool = True,
8857
    ) -> Tuple[Union[torch.Tensor, None], ...]:
8858
8859
8860
8861
8862
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

8863
8864
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
8865
8866
8867
8868
8869

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
8870
8871
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
8872
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
8873
8874
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
8875
8876
8877
8878
8879
8880
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                       'padding_causal_bottom_right','arbitrary'},
8881
                       default = `None`
8882
8883
8884
8885
                       type of attention mask passed into softmax operation. By default,
                       causal masks are aligned to the top left corner of the softmax matrix.
                       When "`bottom_right`" is specified in the mask type, causal masks are
                       aligned to the bottom right corner.
8886
8887
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
8888
8889
8890
8891
8892
8893
8894
8895
8896
8897
8898
8899
8900
8901
8902
8903
8904
8905
8906
8907
8908
8909
8910
8911
8912
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
8913
                    Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`, `alibi`}
8914
        core_attention_bias: Optional[torch.Tensor], default = `None`
8915
8916
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
8917
8918
8919
8920
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
8921
8922
8923
8924
8925
8926
8927
8928
8929
8930
8931
8932
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
8933
8934
8935
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        """
8936
8937
        # hidden_states: [sq, b, h]

8938
        if attn_mask_type is None:
8939
            attn_mask_type = self.attn_mask_type
8940
8941
        if window_size is None:
            window_size = self.window_size
8942
        window_size = check_set_window_size(attn_mask_type, window_size)
8943

8944
        if "padding" in attn_mask_type and attention_mask is not None:
8945
8946
            for mask in attention_mask:
                assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"
8947

8948
8949
8950
        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
8951

8952
        # =================================================
8953
        # Pre-allocate memory for key-values for inference
8954
8955
8956
        # =================================================

        if inference_params and self.layer_number is not None:
8957
8958
8959
            assert (
                self.qkv_format != "thd"
            ), "qkv_format == thd is not supported for an inference with KV-cache!"
8960
            if self.layer_number not in inference_params.key_value_memory_dict:
8961
                inf_max_seq_len = inference_params.max_sequence_length
8962
8963
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
8964
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
8965
8966
                )
                inference_value_memory = self._allocate_memory(
8967
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
8968
8969
8970
8971
8972
8973
8974
8975
8976
8977
8978
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

8979
        # ======================
8980
        # Query, Key, and Value
8981
        # ======================
8982

8983
8984
8985
8986
8987
        fp8_mha = (
            FP8GlobalStateManager.is_fp8_enabled()
            and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
        )

8988
        layernorm_output = None
cyanguwa's avatar
cyanguwa committed
8989
8990
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
8991
8992
8993
8994
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8995
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8996
8997
8998
8999
9000
9001
9002
9003
9004
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
9005
                    fp8_output=fp8_mha and rotary_pos_emb is None,
9006
9007
                )

9008
9009
9010
            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
9011
            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
9012
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
9013
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
cyanguwa's avatar
cyanguwa committed
9014
9015
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
9016
9017
9018
9019
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
cyanguwa's avatar
cyanguwa committed
9020
9021
9022
9023
9024
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
9025
                    self.hidden_size_per_attention_head,
cyanguwa's avatar
cyanguwa committed
9026
9027
9028
                )
                # split along third last dimension
                split_dim = -3
9029
9030
9031

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
9032
9033
9034
9035
9036
9037
9038
9039
9040
            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
9041
                )
9042
            else:
cyanguwa's avatar
cyanguwa committed
9043
                query_layer, key_layer, value_layer = torch.split(
9044
9045
9046
9047
                    mixed_x_layer,
                    (num_queries_per_key_value, 1, 1),
                    dim=split_dim,
                )
cyanguwa's avatar
cyanguwa committed
9048

9049
9050
9051
9052
9053
9054
9055
9056
9057
9058
9059
9060
            if self.qkv_format == "thd":
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
            else:
                # query: -> [sq, b, np, hn]
                # key, value: -> [sq, b, ng, hn]
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
cyanguwa's avatar
cyanguwa committed
9061
9062
        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
9063
            mixed_kv_layer = self.key_value(
cyanguwa's avatar
cyanguwa committed
9064
                encoder_output,
9065
                is_first_microbatch=is_first_microbatch,
9066
                fp8_output=fp8_mha and rotary_pos_emb is None,
9067
9068
9069
            )

            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
9070
                # [sq, b, (ng * 2 * hn)] --> [sq, b, ng, 2 * hn]
9071
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
9072
                    self.num_gqa_groups_per_partition,
9073
9074
9075
9076
9077
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
cyanguwa's avatar
cyanguwa committed
9078
                # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn]
9079
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
9080
                    2 * self.num_gqa_groups_per_partition,
9081
9082
9083
9084
9085
9086
9087
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
9088
9089
9090
            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
9091
9092
9093
                    mixed_kv_layer,
                    split_dim,
                    mixed_kv_layer.shape[split_dim] // 2,
cyanguwa's avatar
cyanguwa committed
9094
                )
9095
            else:
cyanguwa's avatar
cyanguwa committed
9096
                key_layer, value_layer = torch.split(
9097
9098
9099
                    mixed_kv_layer,
                    mixed_kv_layer.shape[split_dim] // 2,
                    dim=split_dim,
cyanguwa's avatar
cyanguwa committed
9100
                )
9101
9102
9103
9104
9105
9106
9107
9108
9109
            key_layer, value_layer = (
                x.reshape(
                    x.size(0),
                    x.size(1),
                    -1,
                    self.hidden_size_per_attention_head,
                )
                for x in (key_layer, value_layer)
            )
9110
9111
9112
9113
9114
9115

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
9116
                    fp8_output=fp8_mha and rotary_pos_emb is None,
9117
9118
9119
9120
9121
9122
9123
9124
9125
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
9126
                    fp8_output=fp8_mha and rotary_pos_emb is None,
9127
9128
9129
9130
9131
9132
9133
9134
9135
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

9136
9137
9138
        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================
9139

9140
        if rotary_pos_emb is not None:
9141
9142
9143
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
9144
            # duplicate the pos_emb for self attention
9145
            if not isinstance(rotary_pos_emb, tuple):
9146
                rotary_pos_emb = (rotary_pos_emb,) * 2
9147
9148

            q_pos_emb, k_pos_emb = rotary_pos_emb
9149
9150
9151
9152
9153
9154
9155

            # adjust key and value for inference
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)
9156
9157
                else:
                    raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.")
9158
9159
9160
9161
9162
9163
9164

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

9165
9166
9167
9168
9169
9170
9171
9172
9173
9174
9175
9176
9177
9178
9179
9180
9181
9182
            query_layer = apply_rotary_pos_emb(
                query_layer,
                q_pos_emb,
                self.qkv_format,
                fused=True,
                cu_seqlens=cu_seqlens_q,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
            )
            key_layer = apply_rotary_pos_emb(
                key_layer,
                k_pos_emb,
                self.qkv_format,
                fused=True,
                cu_seqlens=cu_seqlens_kv,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
            )
9183

9184
9185
9186
9187
        # ===========================
        # Core attention computation
        # ===========================

9188
9189
9190
9191
        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
9192
            qkv_format=self.qkv_format,
9193
9194
9195
9196
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
9197
9198
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
9199
            window_size=window_size,
9200
9201
9202
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
9203
            alibi_slopes=alibi_slopes,
9204
            fast_zero_fill=fast_zero_fill,
9205
            inference_params=inference_params,
9206
9207
        )

9208
        # ===================
9209
        # Output. [sq, b, h]
9210
        # ===================
9211

9212
        projection_output = self.proj(
9213
9214
            context_layer,
            is_first_microbatch=is_first_microbatch,
9215
9216
        )

9217
9218
9219
9220
9221
9222
9223
9224
        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
9225
        if self.input_layernorm and self.return_layernorm_output:
9226
9227
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]