attention.py 378 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
from importlib.metadata import PackageNotFoundError
10
import math
11
import os
12
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
import warnings
14
import logging
15
import functools
16

17
from dataclasses import dataclass, fields
cyanguwa's avatar
cyanguwa committed
18
import numpy as np
19
from packaging.version import Version as PkgVersion
20
21

import torch
22
import torch.nn.functional as F
23

24
import transformer_engine_torch as tex
25
import transformer_engine as te
26
27
28
29
30
from transformer_engine.pytorch.utils import (
    get_cudnn_version,
    nvtx_range_pop,
    nvtx_range_push,
)
31
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
32
33
    fused_attn_fwd,
    fused_attn_bwd,
34
35
36
37
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
38
39
40
41
42
43
44
45
46
47
48
49
50
    META_QKV,
    META_DQKV,
    META_O,
    META_DO,
    META_S,
    META_DP,
    META_O_CP,
    META_DQKV_CP,
)
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    get_fp8_te_dtype,
    get_fp8_torch_dtype,
51
)
52
from transformer_engine.pytorch.float8_tensor import Float8Tensor
53
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
54
from transformer_engine.pytorch.module import LayerNormLinear, Linear
55
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
56
57
58
59
60
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
61
    get_default_init_method,
62
63
64
65
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
66
    AttnBiasTypes,
67
    QKVLayouts,
68
    dist_group_type,
69
    TE_DType,
70
71
72
73
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
74
    get_distributed_rank,
75
    checkpoint,
76
77
78
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
79
80
    gather_along_first_dim,
    reduce_scatter_along_first_dim,
81
)
82
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
83
from transformer_engine.pytorch.graph import is_graph_capturing
84
85
86
87
88
from transformer_engine.pytorch.tensor.quantized_tensor import (
    QuantizedTensor,
    prepare_for_saving,
    restore_from_saved,
)
89

90

91
92
93
94
95
96
97
98
99
100
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
# Effective verbosity is 0 (WARNING) unless debug mode is on; any product outside
# {0, 1, 2} falls back to the most verbose level (DEBUG).
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)
# Use a module-scoped logger instead of the root logger: a library must not change
# the root logger's level or attach handlers to it, because that alters logging
# for every other package in the process. hasHandlers() also inspects ancestors,
# so if the application already configured logging, records simply propagate to it.
fa_logger = logging.getLogger(__name__)
fa_logger.setLevel(_log_level)
if not fa_logger.hasHandlers():
    fa_logger.addHandler(_stream_handler)


@functools.lru_cache(maxsize=None)
def _get_supported_versions(version_min, version_max):
    return ">= " + str(version_min) + ", " + "<= " + str(version_max)

111

112
113
114
# Backend on/off switches, read once at import time from the environment;
# each defaults to 1 when the variable is unset.
_NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN = (
    int(os.getenv(_env_name, "1"))
    for _env_name in ("NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN")
)
115
116
117
118
119

# Detect flash-attn v2 in the environment
_flash_attn_is_installed = False
_flash_attn_version = PkgVersion("0")
_flash_attn_version_required = PkgVersion("2.1.1")
_flash_attn_version_required_blackwell = PkgVersion("2.7.3")
_flash_attn_max_version = PkgVersion("2.7.3")
# Version feature flags; they stay False unless a supported flash-attn v2 is found below.
_flash_attn_2_plus = False
_flash_attn_2_1_plus = False
_flash_attn_2_3_plus = False
_flash_attn_2_4_plus = False
_flash_attn_2_4_1_plus = False
_flash_attn_2_5_7_plus = False
_flash_attn_2_6_0_plus = False
_flash_attn_2_7_0_plus = False

# flash-attn entry points; they stay None unless a supported version is imported below.
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
_flash_attn_fwd = None
_flash_attn_bwd = None
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None

try:
    _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
    # Only mention the missing package where FlashAttention could actually be used.
    if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
        fa_logger.debug(
            "flash-attn v2 is not installed. To use, please install it by"
            """ "pip install flash-attn".""",
        )
else:
    # sm100+ (Blackwell) has a higher minimum flash-attn version than earlier
    # architectures; both are capped at _flash_attn_max_version.
    if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
        if _flash_attn_version_required_blackwell <= _flash_attn_version <= _flash_attn_max_version:
            _flash_attn_is_installed = True
    elif _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
        _flash_attn_is_installed = True

    if _flash_attn_is_installed:
        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
        from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
        from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_forward as _flash_attn_varlen_fwd,
        )
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_backward as _flash_attn_varlen_bwd,
        )

        _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
        _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
        _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
        _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
        _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
        _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
        _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
        _flash_attn_2_7_0_plus = _flash_attn_version >= PkgVersion("2.7.0")
    elif (
        torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN
    ):
        # flash-attn is present but outside the supported range for this device.
        fa_logger.warning(
            "Supported flash-attn versions are %s. Found flash-attn %s.",
            _get_supported_versions(
                (
                    _flash_attn_version_required
                    if get_device_compute_capability() < (10, 0)
                    else _flash_attn_version_required_blackwell
                ),
                _flash_attn_max_version,
            ),
            _flash_attn_version,
        )

# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
_flash_attn_3_is_installed = False
_flash_attn_3_version = PkgVersion("0")
_flash_attn_3_0_0_beta = False
_use_flash_attn_3 = False
# FA3 entry points; they stay None unless flashattn-hopper is imported below
# (mirrors the flash-attn v2 detection, so the names always exist).
flash_attn_func_v3 = None
flash_attn_varlen_func_v3 = None
_flash_attn_fwd_v3 = None
_flash_attn_bwd_v3 = None
_flash_attn_varlen_fwd_v3 = None
_flash_attn_varlen_bwd_v3 = None
# TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved
# https://github.com/Dao-AILab/flash-attention/issues/1452
_flash_attn_3_installation_steps = """\
(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py"""
try:
    _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
except PackageNotFoundError:
    # Only mention the missing package where FA3 could actually be used (sm90+).
    if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN:
        fa_logger.debug(
            "flash-attn v3 is not installed. To use, please install it by \n%s",
            _flash_attn_3_installation_steps,
        )
else:
    from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flashattn_hopper.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )
    from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
    from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
    from flashattn_hopper.flash_attn_interface import (
        _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3,
    )
    from flashattn_hopper.flash_attn_interface import (
        _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
    )

    _flash_attn_3_is_installed = True
    # Flag any 3.0.0 pre-release from beta 0 onward. "3.0.0b" normalizes to 3.0.0b0,
    # so the lower bound must be inclusive for 3.0.0b0 itself to count as a beta.
    _flash_attn_3_0_0_beta = PkgVersion("3.0.0b0") <= _flash_attn_3_version < PkgVersion("3.0.0")
    # Device/feature filters in get_attention_backend may turn this off later.
    _use_flash_attn_3 = True
229

230
231
232
233
234
235
236
# Module-level record of a backend selection: the AttentionParams it was made for
# plus the selected backend flags. NOTE(review): readers/writers are outside this
# chunk; presumably used to skip re-selection when parameters are unchanged.
_attention_backends = {
    "attention_params": None,
    "use_flash_attention": None,
    "use_fused_attention": None,
    "fused_attention_backend": None,
    "use_unfused_attention": None,
    "backend_selection_requires_update": False,
}
238
239


240
241
@dataclass(eq=True)
class AttentionParams:
    """
    Attention parameters used to determine which backend to be used.

    Parameters
    ----------
    qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
        Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
    qkv_dtype: torch.dtype, default = `torch.bfloat16`
        Data type of query/key/value tensors.
    qkv_layout: str, default = "sbh3d"
        Query/key/value tensor memory layout.
    batch_size: int, default = 1
        Batch size.
    num_heads: int, default = 16
        Number of attention heads in the query tensor.
    num_gqa_groups: int, default = 16
        Number of attention heads in key and value tensors.
    max_seqlen_q: int, default = 128
        Maximum sequence length of the query tensor.
    max_seqlen_kv: int, default = 128
        Maximum sequence length of the key and value tensors.
    head_dim_qk: int, default = 64
        The size of each attention head in query and key tensors.
    head_dim_v: int, default = 64
        The size of each attention head in the value tensor.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
        `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
    window_size: Optional[Tuple[int, int]], default = `None`
        Sliding window attention size.
    alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
        Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
    core_attention_bias_type: str, default = `no_bias`
        Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
    core_attention_bias_shape: str, default = `1hss`
        Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
    core_attention_bias_requires_grad: bool, default = `True`
        Whether attention bias requires gradient.
    pad_between_seqs: bool, default = `False`
        Whether there is padding between sequences in a batch.
        This only applies to `qkv_format=thd`.
    attention_dropout: float, default = 0.0
        Attention dropout.
    context_parallel: bool, default = `False`
        Whether context parallelism is used or not.
    deterministic: bool, default = `False`
        Whether to run `DotProductAttention` with determinism or not.
    is_training: bool, default = `True`
        Whether in training mode (`True`) or inference mode (`False`)
    fp8: bool, default = `False`
        Whether `DotProductAttention` is in an `fp8_autocast` region.
    fp8_meta: Optional[Dict[str, Any]], default = `None`
        The FP8 metadata dictionary of `DotProductAttention`. Only its "recipe"
        entry takes part in equality comparisons.
    """

    qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
    qkv_dtype: torch.dtype = torch.bfloat16
    qkv_layout: str = "sbh3d"
    batch_size: int = 1
    num_heads: int = 16
    num_gqa_groups: int = 16
    max_seqlen_q: int = 128
    max_seqlen_kv: int = 128
    head_dim_qk: int = 64
    head_dim_v: int = 64
    attn_mask_type: str = "no_mask"
    window_size: Union[Tuple[int, int], None] = None
    alibi_slopes_shape: Union[torch.Size, List, None] = None
    core_attention_bias_type: str = "no_bias"
    core_attention_bias_shape: str = "1hss"
    core_attention_bias_requires_grad: bool = True
    pad_between_seqs: bool = False
    attention_dropout: float = 0.0
    context_parallel: bool = False
    deterministic: bool = False
    is_training: bool = True
    fp8: bool = False
    fp8_meta: Union[Dict[str, Any], None] = None

    def __eq__(self, other):
        """
        Overwrite dataclass.__eq__ so that only fp8_meta["recipe"] is compared,
        since all other entries of fp8_meta are unused in get_attention_backend.

        A `None` fp8_meta (the field default) is treated like an empty dict, i.e.
        as having no recipe, instead of raising AttributeError on `None.get`.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        for field in fields(self):
            fname = field.name
            sf = getattr(self, fname)
            of = getattr(other, fname)
            if fname != "fp8_meta":
                if sf != of:
                    return False
            elif (sf or {}).get("recipe", None) != (of or {}).get("recipe", None):
                return False
        return True

339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354

# Module-level ALiBi state: cached slopes/bias plus the parameters they were
# built for. Both "*_require_update" flags start out False.
_alibi_cache = dict(
    _num_heads=None,
    _alibi_slopes=None,
    _max_seqlen_q=None,
    _max_seqlen_kv=None,
    _bottom_right_alignment=True,
    _alibi_bias=None,
    _alibi_slopes_require_update=False,
    _alibi_bias_require_update=False,
)


# Names exported by `from ... import *`.
__all__ = [
    "DotProductAttention",
    "InferenceParams",
    "MultiheadAttention",
]


355
356
357
358
359
def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    """Return `tensor` itself when its innermost stride is 1, else a contiguous copy.

    Only the last dimension's stride is inspected, so a tensor that is strided
    along outer dimensions but unit-stride in the last one is passed through
    unchanged (it may still be non-contiguous overall).
    """
    if tensor.stride(-1) == 1:
        return tensor
    return tensor.contiguous()


360
361
362
363
364
365
366
367
368
def get_attention_backend(
    attention_params: AttentionParams = None,
):
    """
    Select the appropriate attention backend/sub-backend based on user input and runtime environment.

    Parameters
    ----------
    See `AttentionParams`.
369
370
371
372
373
374
375

    Returns
    ----------
    use_flash_attention: bool
        Whether the `FlashAttention` backend has been selected.
    use_fused_attention: bool
        Whether the `FusedAttention` backend has been selected.
376
377
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend
        If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
378
379
380
381
382
383
    use_unfused_attention: bool
        Whether the `UnfusedDotProductAttention` backend has been selected.
    available_backends: List[bool]
        All available backends that could support the provided input. A list of Booleans
        in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
    """
384
385
386
387
388
389
390
391
    qkv_type = attention_params.qkv_type
    qkv_dtype = attention_params.qkv_dtype
    qkv_layout = attention_params.qkv_layout
    batch_size = attention_params.batch_size
    num_heads = attention_params.num_heads
    num_gqa_groups = attention_params.num_gqa_groups
    max_seqlen_q = attention_params.max_seqlen_q
    max_seqlen_kv = attention_params.max_seqlen_kv
392
393
    head_dim_qk = attention_params.head_dim_qk
    head_dim_v = attention_params.head_dim_v
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
    attn_mask_type = attention_params.attn_mask_type
    window_size = attention_params.window_size
    alibi_slopes_shape = attention_params.alibi_slopes_shape
    core_attention_bias_type = attention_params.core_attention_bias_type
    core_attention_bias_shape = attention_params.core_attention_bias_shape
    core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
    pad_between_seqs = attention_params.pad_between_seqs
    attention_dropout = attention_params.attention_dropout
    context_parallel = attention_params.context_parallel
    deterministic = attention_params.deterministic
    is_training = attention_params.is_training
    fp8 = attention_params.fp8
    fp8_meta = attention_params.fp8_meta

    # Run config
409
    logger = logging.getLogger("DotProductAttention")
410
411
412
    logger.setLevel(_log_level)
    if not logger.hasHandlers():
        logger.addHandler(_stream_handler)
413
414
415
416
417
    device_compute_capability = get_device_compute_capability()
    cudnn_version = get_cudnn_version()
    run_config = {
        "transformer_engine_version": te.__version__,
        "compute_capability": "sm"
418
        + str(10 * device_compute_capability[0] + device_compute_capability[1]),
419
420
421
422
423
424
        "flash_attn_version": (
            str(_flash_attn_version) if _flash_attn_is_installed else "not installed"
        ),
        "flash_attn_3_version": (
            str(_flash_attn_3_version) if _flash_attn_3_is_installed else "not installed"
        ),
425
426
427
428
429
430
431
432
433
        "cudnn_version": ".".join([str(i) for i in cudnn_version]),
    }
    attention_params_dict = {
        field.name: getattr(attention_params, field.name) for field in fields(attention_params)
    }
    run_config.update(attention_params_dict)
    if fp8:
        run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
    logger.debug("Running with config=%s", run_config)
434

435
436
437
438
439
440
    # The following sections check if `FlashAttention` supports the provided attention params,
    # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
    # necessary for performance/functionality, a warning will be issued to prompt users to
    # install an appropriate FA version.
    global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3

441
    # Filter: Environment variables
442
443
444
445
    use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
    use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
    use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
    if not use_flash_attention and _flash_attn_is_installed:
446
447
448
449
450
451
452
453
        logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
    if not use_fused_attention:
        logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
    if not use_unfused_attention:
        logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")

    # Filter: Compute capability
    if device_compute_capability < (8, 0):
454
        if use_flash_attention and _flash_attn_is_installed:
455
            logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
456
        use_flash_attention = False
457
458
459
        if use_fused_attention:
            logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
            use_fused_attention = False
460
    if device_compute_capability < (9, 0):
461
        if use_flash_attention and _flash_attn_3_is_installed:
462
            logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
463
        _use_flash_attn_3 = False
464
465

    # Filter: Data type
466
467
468
469
    if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
        torch.Tensor,
        Float8Tensor,
    ]:
470
        if use_flash_attention and _flash_attn_is_installed:
471
472
473
474
475
476
            logger.debug(
                "Disabling FlashAttention due to unsupported QKV data type. "
                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                "Found: qkv_dtype = %s.",
                qkv_dtype,
            )
477
        use_flash_attention = False
478
479
480
481
482
483
484
485
        if use_fused_attention:
            logger.debug(
                "Disabling FusedAttention due to unsupported QKV data type. "
                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                "Found: qkv_dtype = %s.",
                qkv_dtype,
            )
            use_fused_attention = False
486
487
488

    # Filter: Execution type
    if fp8 and fp8_meta["recipe"].fp8_dpa:
489
        if use_flash_attention and not _use_flash_attn_3:
490
491
            if _flash_attn_is_installed:
                logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
492
493
494
495
496
            use_flash_attention = False
        if use_flash_attention and _use_flash_attn_3 and is_training:
            logger.debug(
                "Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
            )
497
498
499
500
501
502
            use_flash_attention = False
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
            use_unfused_attention = False

    # Filter: Head dimension
503
    if use_flash_attention and head_dim_qk != head_dim_v:
504
505
        if _flash_attn_is_installed:
            logger.debug("Disabling FlashAttention as it does not support MLA.")
506
        use_flash_attention = False
507
    if use_flash_attention and (
508
509
510
        head_dim_qk > 256
        or head_dim_qk % 8 != 0
        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
511
    ):
512
513
514
515
516
517
518
519
520
521
        if _flash_attn_is_installed:
            logger.debug(
                "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
                "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
                "head_dim_qk <= 256 (>192 requires sm80/90). "
                "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
                head_dim_qk,
                head_dim_v,
                ".".join([str(i) for i in device_compute_capability]),
            )
522
        use_flash_attention = False
523
524
525
526
527
528
529
    qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
    if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
        logger.debug(
            "Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
            qkv_layout,
        )
        use_fused_attention = False
530
531
532
533
534
535
536
537

    # Filter: QKV layout
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
            use_unfused_attention = False
        if use_flash_attention and pad_between_seqs:
538
539
540
541
542
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention for qkv_format = thd when there is "
                    "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
                )
543
544
            use_flash_attention = False

545
    # Filter: Dropout
546
547
548
    if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
        logger.debug("Disabling FlashAttention 3 for dropout")
        _use_flash_attn_3 = False
549

550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
    # Filter: Context parallelism
    # qkv_format | attn_mask_type              | attn_bias_type           | supported backends
    # ----------------------------------------------------------------------------------------------------
    # bshd, sbhd | self-attention:             | no_bias, post_scale_bias | FlashAttention, FusedAttention
    #            |     no_mask, causal         |                          |
    #            | cross-attention:            |                          |
    #            |     no_mask                 |                          |
    # thd        | self-attention:             | no_bias                  | FlashAttention, FusedAttention
    #            |     padding, padding_causal |                          | if no padding between sequences,
    #            | cross-attention:            |                          | FusedAttention
    #            |     padding                 |                          | if there is padding between sequences
    # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v.
    if context_parallel and use_unfused_attention:
        logger.debug(
            "Disabling UnfusedDotProductAttention as it does not support context parallelism"
        )
        use_unfused_attention = False
    if context_parallel and use_flash_attention:
568
        if fp8 and fp8_meta["recipe"].fp8_dpa:
569
570
571
572
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with FP8"
                )
573
            use_flash_attention = False
574
        if "bottom_right" in attn_mask_type:
575
576
577
578
579
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " causal_bottom_right masking"
                )
580
581
            use_flash_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
582
583
584
585
586
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " causal masking for cross-attention"
                )
587
588
            use_flash_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
589
590
591
592
593
594
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with bias"
                    " type of %s",
                    core_attention_bias_type,
                )
595
596
            use_flash_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
597
598
599
600
601
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " attention bias for THD format"
                )
602
            use_flash_attention = False
603

604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
    if context_parallel and use_fused_attention:
        if "bottom_right" in attn_mask_type:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with"
                " causal_bottom_right masking"
            )
            use_fused_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with causal"
                " masking for cross-attention"
            )
            use_fused_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with bias type"
                " of %s",
                core_attention_bias_type,
            )
            use_fused_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with attention"
                " bias for THD format"
            )
            use_fused_attention = False
        elif head_dim_qk != head_dim_v:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with MLA"
            )
            use_fused_attention = False

636
    # Filter: Attention mask
637
638
639
640
641
642
643
644
645
646
647
648
649
650
    # attn_mask_type              | attention_mask                       | supported backends
    # ----------------------------------------------------------------------------------------
    # no_mask                     | None                                 | All
    # padding                     |                                      | All
    #     self-attention          | One tensor in shape [b, 1, 1, sq]    |
    #     cross-attention         | Tuple of two tensors in shapes       |
    #                             | [b, 1, 1, sq] and [b, 1, 1, skv]     |
    # causal                      | None                                 |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # padding_causal              | Same as "padding"                    |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # causal_bottom_right         | None                                 | All
651
    # padding_causal_bottom_right | Same as "padding"                    | All
652
653
    # arbitrary                   | One tensor in shape broadcastable to | UnfusedDotProductAttention
    #                             | [b, h, sq, skv]                      |
654
    if attn_mask_type == "arbitrary":
655
        if use_flash_attention and _flash_attn_is_installed:
656
657
658
659
660
            logger.debug("Disabling FlashAttention for arbitrary mask")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention for arbitrary mask")
        use_fused_attention = False
661
662
    if (
        use_flash_attention
663
        and _use_flash_attn_3
664
665
666
667
668
669
670
671
672
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
            "causal mask since flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        _use_flash_attn_3 = False
673
674
675
676
677
    if (
        use_flash_attention
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
678
679
680
681
682
683
684
685
686
        if _flash_attn_2_1_plus:
            logger.warning(
                "Disabling FlashAttention as it only supports bottom-right-diagonal "
                "causal mask since flash-attn 2.1. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False
        if not _flash_attn_is_installed:
            _flash_attn_max_version = PkgVersion("2.1")
687
688
689
690
691
    if (
        use_flash_attention
        and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
        and max_seqlen_q != max_seqlen_kv
    ):
692
693
694
695
696
697
698
699
700
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.1")
        elif not _flash_attn_2_1_plus and not _use_flash_attn_3:
            logger.warning(
                "Disabling FlashAttention as it only supports top-left-diagonal "
                "causal mask before flash-attn 2.1. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False
701
702
703
704
705
706
707
708
709
    if (
        use_flash_attention
        and _use_flash_attn_3
        and fp8
        and fp8_meta["recipe"].fp8_dpa
        and "padding" in attn_mask_type
    ):
        logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
        _use_flash_attn_3 = False
710
711

    # Filter: Sliding window attention
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
    #    backend                 |      window_size       | diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | (-1, -1) or (>=0, >=0) | bottom right
    # FusedAttention             | (-1,  0) or (>=0, 0)   | top left
    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
    #                            |                        | converts window_size to an 'arbitrary' mask
    if window_size is None:
        window_size = check_set_window_size(attn_mask_type, window_size)
    else:
        if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention"
                    " for FP8"
                )
                use_fused_attention = False
728
            elif window_size[1] != 0 or attention_dropout != 0.0:
729
730
                logger.debug(
                    "Disabling FusedAttention as it only supports sliding window attention "
731
                    "with (left, 0) and no dropout"
732
733
                )
                use_fused_attention = False
734
            elif max_seqlen_q > max_seqlen_kv:
735
736
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
737
                    "with s_q > s_kv for cross-attention"
738
739
                )
                use_fused_attention = False
740
741
742
743
744
745
746
747
748
749
750
751
752
        if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if _use_flash_attn_3:
                logger.debug(
                    "Disabling FlashAttention 3 as it does not support sliding window attention"
                )
                _use_flash_attn_3 = False
            if not _flash_attn_is_installed:
                _flash_attn_version_required = PkgVersion("2.3")
            elif not _flash_attn_2_3_plus:
                logger.debug(
                    "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
                )
                use_flash_attention = False
753
754

    # Filter: Attention bias
755
756
757
758
759
760
761
762
    #    backend                 |      bias types              | ALiBi diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | no_bias, alibi/alibi_slopes  | bottom right
    # FusedAttention             | no_bias, post_scale_bias     |
    #                            | alibi/alibi_slopes           | top left,
    #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
    # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
    #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
763
    if use_flash_attention and core_attention_bias_type == "alibi":
764
        if _use_flash_attn_3:
765
766
            logger.debug("Disabling FlashAttention 3 for ALiBi")
            _use_flash_attn_3 = False
767
768
769
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.4")
        elif not _flash_attn_2_4_plus:
770
771
            logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
            use_flash_attention = False
772

773
774
775
776
    if use_flash_attention and (
        core_attention_bias_type not in ["no_bias", "alibi"]
        or core_attention_bias_shape is not None
    ):
777
778
        if _flash_attn_is_installed:
            logger.debug("Disabling FlashAttention for pre/post_scale_bias")
779
780
781
782
783
784
785
786
        use_flash_attention = False

    fu_core_attention_bias_type = core_attention_bias_type
    fu_core_attention_bias_shape = core_attention_bias_shape
    fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
    if (
        use_fused_attention
        and core_attention_bias_type == "alibi"
787
        and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
788
789
790
    ):
        fu_core_attention_bias_type = "post_scale_bias"
        fu_core_attention_bias_requires_grad = False
791
792
793
794
795
        if alibi_slopes_shape is None:
            fu_core_attention_bias_shape = "1hss"
        elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
            fu_core_attention_bias_shape = "1hss"
        elif (
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
            len(alibi_slopes_shape) == 2
            and alibi_slopes_shape[0] == batch_size
            and alibi_slopes_shape[1] == num_heads
        ):
            fu_core_attention_bias_shape = "bhss"

    if (
        use_fused_attention
        and fu_core_attention_bias_type == "post_scale_bias"
        and fu_core_attention_bias_shape != "1hss"
    ):
        if fu_core_attention_bias_requires_grad:
            # remove this line when cuDNN adds bwd support for
            # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
            logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
            use_fused_attention = False
        else:
            # max512 backend will only support [1, h, s, s]
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

    # Filter: cuDNN support
    fused_attention_backend = None
    if use_fused_attention:
        q_type = TE_DType[qkv_dtype]
        kv_type = q_type
        if fp8 and fp8_meta["recipe"].fp8_dpa:
            q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            kv_type = q_type
        fused_attention_backend = tex.get_fused_attn_backend(
            q_type,
            kv_type,
            QKVLayout[qkv_layout],
            AttnBiasType[fu_core_attention_bias_type],
            AttnMaskType[attn_mask_type],
            attention_dropout,
            num_heads,
            num_gqa_groups,
            max_seqlen_q,
            max_seqlen_kv,
835
836
            head_dim_qk,
            head_dim_v,
837
838
            window_size[0],
            window_size[1],
839
        )
840
        if fused_attention_backend == FusedAttnBackend["No_Backend"]:
841
842
            logger.debug("Disabling FusedAttention as no backend supports the provided input")
            use_fused_attention = False
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
            fused_attention_backend = None
        if (
            use_fused_attention
            and window_size is not None
            and window_size[0] != -1
            and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
        ):
            logger.debug(
                "Disabling FusedAttention as only sub-backend %s does not support "
                "slidng window attention",
                int(fused_attention_backend),
            )
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
860
861
862
863
864
865
866
867
            and fu_core_attention_bias_type == "post_scale_bias"
            and fu_core_attention_bias_shape != "1hss"
        ):
            logger.debug(
                "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
                " [1, H, S, S] shape"
            )
            use_fused_attention = False
868
            fused_attention_backend = None
869
870
871
872
873
874
875
876
877
878
879
880
881

    # Filter: Determinism
    # backend                      | deterministic
    # ---------------------------------------------
    # FlashAttention               |
    #     flash-attn >=2.0, <2.4.1 | no
    #     flash-attn >=2.4.1       | yes
    # FusedAttention               |
    #     sub-backend 0            | yes
    #     sub-backend 1            | workspace optimization path and sm90+: yes;
    #                              | otherwise: no
    #     sub-backend 2            | no
    # UnfusedDotProductAttention   | yes
882
883
884
885
886
887
888
889
890
891
    if use_flash_attention and deterministic:
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.4.1")
        elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3:
            logger.warning(
                "Disabling FlashAttention as version <2.4.1 does not support deterministic "
                "execution. To use FlashAttention with deterministic behavior, "
                "please install flash-attn >= 2.4.1."
            )
            use_flash_attention = False
892
893
894
895
896
897
898
899
900
901
902
    if use_fused_attention and deterministic:
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
            and (
                device_compute_capability < (9, 0)
                or core_attention_bias_requires_grad
                or cudnn_version < (8, 9, 5)
903
            )
904
905
906
        ):
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
907
908
909

    # All available backends
    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926

    # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
    # When `FusedAttention` does not support the provided attention params, and `FlashAttention`
    # does, we recommend users to install flash-attn if not installed already.
    if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed:
        logger.warning(
            "flash-attn may provide important feature support or performance improvement."
            " Please install flash-attn %s.",
            _get_supported_versions(
                _flash_attn_version_required,
                _flash_attn_max_version,
            ),
        )
    if use_flash_attention and not _flash_attn_is_installed:
        use_flash_attention = False
        available_backends[0] = False

927
928
929
930
931
932
933
934
935
936
937
938
    logger.debug(
        "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
        " UnfusedDotProductAttention=%s}",
        bool(available_backends[0]),
        bool(available_backends[1]),
        (
            f" (sub-backend {int(fused_attention_backend)})"
            if fused_attention_backend is not None
            else ""
        ),
        bool(available_backends[2]),
    )
939
940
941
942
943
944
945

    # Select FusedAttention for performance
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    ):
946
        if device_compute_capability >= (9, 0):
947
948
949
950
951
            logger.debug(
                "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                "for performance reasons"
            )
            use_flash_attention = False
952
953
954
955
956
957
958
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["FP8"]
        and _use_flash_attn_3
    ):
        logger.debug(
959
960
            "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
            "in FP8 execution"
961
962
963
        )
        use_flash_attention = False

964
965
966
967
968
969
    # Selected backend
    if use_flash_attention:
        use_fused_attention = False
        use_unfused_attention = False
    elif use_fused_attention:
        use_unfused_attention = False
970
    selected_backend = "NoBackend"
971
972
973
974
975
976
    if use_flash_attention:
        selected_backend = "FlashAttention"
    elif use_fused_attention:
        selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
    elif use_unfused_attention:
        selected_backend = "UnfusedDotProductAttention"
977
    logger.debug("Selected backend = %s", selected_backend)
978

979
980
981
982
983
984
    global _attention_backends
    _attention_backends["use_flash_attention"] = use_flash_attention
    _attention_backends["use_fused_attention"] = use_fused_attention
    _attention_backends["fused_attention_backend"] = fused_attention_backend
    _attention_backends["use_unfused_attention"] = use_unfused_attention
    _attention_backends["backend_selection_requires_update"] = False
985
986
987
988

    return (
        use_flash_attention,
        use_fused_attention,
989
        fused_attention_backend,
990
991
992
993
994
        use_unfused_attention,
        available_backends,
    )


995
class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # Maps a layer number to its (key_memory, value_memory) cache pair; the batch
        # dimension of each cached tensor is dim 1 (see swap_key_value_dict).
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Every cached (key, value) pair is re-indexed along dim 1 with `batch_indices`.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty, or if `len(batch_indices)` differs from the batch
            size (dim 1) of any cached key tensor. No layer is modified in that case.
        """
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        # Validate all layers up front so a mismatch cannot leave the cache
        # partially reordered (and so the check survives `python -O`).
        num_indices = len(batch_indices)
        for layer_number, (inference_key_memory, _) in self.key_value_memory_dict.items():
            if num_indices != inference_key_memory.shape[1]:
                raise ValueError(
                    f"batch_indices has length {num_indices}, but the KV cache of layer "
                    f"{layer_number} has batch size {inference_key_memory.shape[1]}"
                )

        # Replacing values of existing keys while iterating is safe (the dict size is unchanged).
        for layer_number, (
            inference_key_memory,
            inference_value_memory,
        ) in self.key_value_memory_dict.items():
            self.key_value_memory_dict[layer_number] = (
                inference_key_memory[:, batch_indices],
                inference_value_memory[:, batch_indices],
            )
1039

1040

1041
@torch.no_grad()
def get_full_mask(
    max_seqlen_q: int,
    max_seqlen_kv: int,
    attn_mask_type: str = "no_mask",
    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
    window_size: Tuple[int, int] = None,
    attention_type: str = "self",
    bottom_right_alignment: bool = True,
) -> Tuple[str, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
    `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
    on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::

       attn_mask_type              output shape                                 diagonal alignment
       --------------------------------------------------------------------------------------------
       no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]          follow bottom_right_alignment
       causal                      [1, 1, max_seqlen_q, max_seqlen_kv]          always top left
       causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]          always bottom right
       padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
       padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left
       padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right
       arbitrary                   same as attention_mask                       follow bottom_right_alignment

    .. note::

    For "padding_causal_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the
    bottom right diagonal comes from the bottom right corner of the [actual_seqlens_q[i],
    actual_seqlens_kv[i]] matrix, i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix.
    For example, with max_seqlen_q = 4, max_seqlen_kv = 4, attn_mask_type = "padding",
    attention_type = "cross", and attention_mask = (
    [[False, False,  True, True], [False, False, False, False]],
    [[False, False, False, True], [False,  True,  True,  True]]) (shown without their singleton
    dimensions), the returned full attention mask has [2, 1, 4, 4] shape and is (without the
    singleton head dimension),::

      [[[False, False, False, True],
        [False, False, False, True],
        [ True,  True,  True, True],
        [ True,  True,  True, True]],
       [[False,  True,  True, True],
        [False,  True,  True, True],
        [False,  True,  True, True],
        [False,  True,  True, True]]]

    Parameters
    ----------
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
        "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        default = `None`
        Boolean tensor(s) used to mask out attention softmax input (`True` = masked out).
        Please see DotProductAttention for the requirements of `attention_mask` for different
        `attn_mask_type`s. For padding masks with `attention_type = "self"` this is one tensor of
        shape [batch_size, 1, 1, max_seqlen]; with `attention_type = "cross"` it is a tuple of
        [batch_size, 1, 1, max_seqlen_q] and [batch_size, 1, 1, max_seqlen_kv] tensors.
        All returned tensors are created on this mask's device (CUDA if no mask is given).
    window_size: Tuple[int, int], default = `None`
        Sliding window size for local attention. With bottom-right alignment, query at position i
        attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive (top-left alignment drops the seqlen_k - seqlen_q shift).
        -1 means unbounded on that side, so (-1, -1) and (-1, 0) mean no sliding window and
        causal mask specifically. Both `causal` and `causal_bottom_right` masks map to
        `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
        `attn_mask_type`.
    attention_type: str, default = "self"
        Attention type, {"self", "cross"}
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
        or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
        specifies "causal" or "causal_bottom_right".

    Returns
    -------
    attn_mask_type: str
        "arbitrary" if `window_size` is given and is neither (-1, -1) nor (-1, 0);
        otherwise, the same as input `attn_mask_type`.
    attention_mask: torch.Tensor
        The full boolean attention mask (`True` = masked out) based on `attn_mask_type`,
        `attention_mask` and `window_size`.
    actual_seqlens_q: Optional[torch.Tensor]
        For padding masks, the actual sequence lengths for queries, in shape [batch_size].
        For other masks, `None`.
    actual_seqlens_kv: Optional[torch.Tensor]
        For padding masks, the actual sequence lengths for keys and values, in shape [batch_size].
        For other masks, `None`.

    Raises
    ------
    ValueError
        If `attn_mask_type` is not one of the supported types.
    """
    # A window other than "none" (-1, -1) or "causal" (-1, 0) cannot be expressed by the
    # named mask types, so the returned mask type becomes "arbitrary".
    change_type = window_size is not None and (
        window_size[0] != -1 or window_size[1] not in [-1, 0]
    )
    if window_size is None:
        window_size = (-1, -1)
    if "causal" in attn_mask_type:
        window_size = (window_size[0], 0)
    # Replace -1 ("unbounded") by a width that no query/key offset can reach: with any of the
    # alignments below, the shifted offset i - j lies strictly inside
    # (-(max_seqlen_q + max_seqlen_kv), max_seqlen_q + max_seqlen_kv). Using only max_seqlen_kv /
    # max_seqlen_q here would wrongly mask entries under top-left alignment when the lengths differ.
    unbounded = max_seqlen_q + max_seqlen_kv
    window_size = (
        unbounded if window_size[0] == -1 else window_size[0],
        unbounded if window_size[1] == -1 else window_size[1],
    )

    # Build helper tensors on the same device as the user-provided mask so that the
    # final logical_or never mixes devices; fall back to the current CUDA device.
    if isinstance(attention_mask, torch.Tensor):
        device = attention_mask.device
    elif (
        isinstance(attention_mask, (tuple, list))
        and len(attention_mask) > 0
        and isinstance(attention_mask[0], torch.Tensor)
    ):
        device = attention_mask[0].device
    else:
        device = "cuda"

    # apply padding mask
    actual_seqlens_q = None
    actual_seqlens_kv = None
    if "padding" in attn_mask_type:
        if attention_type == "self":
            # [b, 1, 1, s] -> [b, 1, s, 1], OR-ed with [b, 1, 1, s] -> [b, 1, s, s]
            attention_mask = torch.logical_or(
                attention_mask.squeeze(1).unsqueeze(3), attention_mask
            )
        else:
            # [b, 1, 1, sq] -> [b, 1, sq, 1], OR-ed with [b, 1, 1, skv] -> [b, 1, sq, skv]
            attention_mask = torch.logical_or(
                attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
            )
        valid = attention_mask.logical_not()
        # Lengths are read from the first key column / first query row, which presumes
        # right-padded sequences (valid tokens first).
        actual_seqlens_q = valid[:, 0, :, 0].sum(dim=1)
        actual_seqlens_kv = valid[:, 0, 0, :].sum(dim=1)

    # apply SWA mask; offsets[..., i, j] = i - j
    offsets = torch.arange(max_seqlen_q, dtype=torch.int32, device=device).view(
        1, 1, max_seqlen_q, 1
    ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device=device).view(1, 1, 1, max_seqlen_kv)
    if attn_mask_type == "causal_bottom_right" or (
        attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment
    ):
        swa_left = offsets + max_seqlen_kv - max_seqlen_q - window_size[0]
        swa_right = offsets + max_seqlen_kv - max_seqlen_q + window_size[1]
    elif attn_mask_type in ["causal", "padding_causal"] or (
        attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment
    ):
        swa_left = offsets - window_size[0]
        swa_right = offsets + window_size[1]
    elif attn_mask_type == "padding_causal_bottom_right" or (
        attn_mask_type == "padding" and bottom_right_alignment
    ):
        # Per-sample bottom-right alignment based on the actual (unpadded) lengths.
        batch_size = attention_mask.shape[0]
        shift = (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
        expanded = offsets.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
        swa_left = expanded + shift - window_size[0]
        swa_right = expanded + shift + window_size[1]
    else:
        raise ValueError(f"Unsupported attn_mask_type: {attn_mask_type}")
    # Key j is kept for query i only inside the window, i.e. swa_left <= 0 <= swa_right.
    swa_mask = (swa_left > 0) | (swa_right < 0)

    if attention_mask is not None:
        attention_mask = torch.logical_or(swa_mask, attention_mask)
    else:
        attention_mask = swa_mask

    # change mask type
    if change_type:
        attn_mask_type = "arbitrary"

    return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv
1193
1194


1195
1196
1197
1198
1199
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    actual_seqlens_q: Optional[torch.Tensor] = None,
    actual_seqlens_kv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
    bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return the ALiBi slopes and ALiBi bias, recomputing each only when the corresponding
    `_alibi_slopes_require_update` / `_alibi_bias_require_update` flag in the module-level
    `_alibi_cache` is set (this function clears the flags after recomputing; setting them is
    presumably done by callers elsewhere).

    When `alibi_slopes` is `None`, the default slopes are 2^(-8k/n) for k = 1..n, where n is the
    largest power of two <= `num_heads`; any remaining heads get 2^(-4k/n) for odd k = 1, 3, ...

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    actual_seqlens_q: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for queries, in shape [batch_size].
    actual_seqlens_kv: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for keys and values, in shape [batch_size].
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the ALiBi bias to the bottom right corner of
        the matrix (`True`) or top left (`False`).

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. Its shape is
        (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
        and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
        (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
        [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
        `actual_seqlens_q` and `actual_seqlens_kv` are not `None`.

    Raises
    ------
    ValueError
        When the bias is recomputed and either the cached slopes have more than 2 dimensions,
        or exactly one of `actual_seqlens_q` / `actual_seqlens_kv` is given.
    """
    global _alibi_cache
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                # Extra heads for non-power-of-two head counts: odd powers of 2^(-4/n).
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!"
        if _alibi_cache["_alibi_slopes"].dim() == 1:
            slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
        elif _alibi_cache["_alibi_slopes"].dim() == 2:
            slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
        else:
            raise ValueError("ALiBi slopes cannot exceed 2 dimensions.")

        # bias[..., i, j] = i - j (query position minus key position)
        bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
            1, 1, max_seqlen_q, 1
        ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
            1, 1, 1, max_seqlen_kv
        )
        if actual_seqlens_q is None and actual_seqlens_kv is None:
            if bottom_right_alignment:
                bias = bias + max_seqlen_kv - max_seqlen_q
        elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
            batch_size = actual_seqlens_q.shape[0]
            bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
            if bottom_right_alignment:
                # Align to each sample's own (unpadded) bottom-right corner.
                bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
        else:
            # An explicit raise (not `assert`) so the check is not stripped under `python -O`,
            # which would otherwise silently ignore the given sequence lengths.
            raise ValueError(
                "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
            )
        # Penalty is -|distance| scaled per head (and per sample for 2-D slopes).
        bias = bias.abs().mul(-1)
        bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
1292
1293
1294
1295
1296
1297
1298
1299
1300


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen] (`True` = padded token),
    returns an int32 tensor of shape [batch_size + 1] containing the cumulative sequence
    lengths of the samples in a batch, starting with 0. The result lives on `mask`'s device.
    """
    mask = mask.squeeze(1).squeeze(1)
    # Number of non-padded tokens per sample.
    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Allocate the leading zero on the mask's own device so torch.cat never mixes devices.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    return cu_seqlens

1308

1309
1310
1311
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the indices for the valid tokens.

    ``True`` entries of ``mask`` mark padded positions. The indices refer to the
    flattened ``[batch_size * max_seqlen]`` token layout. The tail of the index
    tensor is filled with the out-of-range value ``batch_size * max_seqlen``. Both
    outputs are created on ``mask.device``.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape
    total_slots = bs * seqlen
    valid = mask.logical_not()

    cu_seqlens = valid.sum(dim=1).cumsum(dim=0).to(torch.int32)
    # Leading zero on the same device as the data (not a hard-coded one).
    zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # nonzero() on the flattened mask yields int64 positions of shape [num_valid, 1].
    indices = valid.reshape(-1).nonzero().unsqueeze(-1)
    pad_amount = total_slots - indices.shape[0]
    indices = F.pad(
        input=indices,
        pad=(0, 0, 0, 0, 0, pad_amount),
        mode="constant",
        value=float(total_slots),
    )

    return cu_seqlens, indices


1337
1338
1339
1340
1341
1342
1343
1344
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch.

    Token ``t`` of sequence ``b`` maps to index ``b * max_seqlen + t``. The tail of
    the result is filled with the out-of-range index ``batch_size * max_seqlen``.
    The result is created on ``cu_seqlens.device``.
    """
    device = cu_seqlens.device
    bs = len(cu_seqlens) - 1
    total_slots = bs * max_seqlen

    # Stay in int64 throughout: a float32 round trip would round indices >= 2**24.
    cu_seqlens = cu_seqlens.to(torch.int64)
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

    # Sequence id of every valid token, in packed order.
    seq_ids = torch.repeat_interleave(
        torch.arange(bs, dtype=torch.int64, device=device), seqlens
    )
    # Packed-order offset of each sequence's first token; subtracting it from the
    # packed position gives the position of each token within its own sequence.
    seq_starts = torch.cumsum(seqlens, dim=0) - seqlens
    positions = torch.arange(
        seq_ids.shape[0], dtype=torch.int64, device=device
    ) - torch.repeat_interleave(seq_starts, seqlens)

    indices = (seq_ids * max_seqlen + positions).view(-1, 1, 1)

    pad_amount = total_slots - indices.shape[0]
    indices = F.pad(
        input=indices,
        pad=(0, 0, 0, 0, 0, pad_amount),
        mode="constant",
        value=float(total_slots),
    )

    return indices

1359

1360
_cu_seqlens_cache = {}
1361
1362


1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length.

    """
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
    global _cu_seqlens_cache
    if (batch_size, max_seqlen) not in _cu_seqlens_cache:
        _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[(batch_size, max_seqlen)]
1383
1384


1385
@torch.compile
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Packs the given tensor using the `indices`.

    An all-zero row is appended to ``tensor`` along dim 0, so index
    ``tensor.shape[0]`` gathers zeros. Rows are then gathered at ``indices``, which
    is expected to have shape [N, 1, 1] and is broadcast over dims 1 and 2. For a
    `Float8Tensor`, the gather runs on the raw ``_data`` storage and the result is
    re-wrapped with `Float8Tensor.make_like`.
    """
    is_fp8 = isinstance(tensor, Float8Tensor)
    data = tensor._data if is_fp8 else tensor
    # Build the padding row from the storage actually being concatenated. For a
    # Float8Tensor, ``tensor.dtype`` may differ from ``tensor._data.dtype``; a mismatch
    # would make torch.cat type-promote the raw bytes.
    padding_row = torch.zeros(
        1, data.shape[1], data.shape[2], dtype=data.dtype, device=data.device
    )
    indices = indices.repeat(1, data.shape[1], data.shape[2])
    gathered = torch.gather(torch.cat((data, padding_row), dim=0), 0, indices)

    if is_fp8:
        return Float8Tensor.make_like(tensor, data=gathered, shape=gathered.shape)
    return gathered


1409
@torch.compile
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Packs the given 2 tensors using the `indices`.

    Each tensor is packed independently with `pack_tensor` using the same indices.
    """
    packed = [pack_tensor(indices, t) for t in (t1, t2)]
    return tuple(packed)


1423
@torch.compile
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Packs the given 3 tensors using the `indices`.

    Each tensor is packed independently with `pack_tensor` using the same indices.
    """
    packed = [pack_tensor(indices, t) for t in (t1, t2, t3)]
    return tuple(packed)


1439
@torch.compile
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`.

    Rows of ``tensor`` are scattered into a zero buffer of ``dim0 + 1`` rows at
    ``indices``, which is broadcast over dims 1 and 2. The last, spare row absorbs
    scatters from padding indices equal to ``dim0`` and is dropped. For a
    `Float8Tensor`, the scatter runs on its raw ``_data`` storage and the result is
    re-wrapped with `Float8Tensor.make_like`.
    """
    is_fp8 = isinstance(tensor, Float8Tensor)
    data = tensor._data if is_fp8 else tensor
    indices = indices.repeat(1, data.shape[1], data.shape[2])
    # The buffer must have the dtype of the scattered source: scatter_ requires
    # matching dtypes. For a Float8Tensor, ``tensor.dtype`` may differ from
    # ``tensor._data.dtype``.
    unpacked = torch.zeros(
        dim0 + 1, data.shape[1], data.shape[2], dtype=data.dtype, device=data.device
    )
    unpacked.scatter_(0, indices, data)
    unpacked = unpacked[0:-1, :, :]

    if is_fp8:
        return Float8Tensor.make_like(tensor, data=unpacked, shape=unpacked.shape)
    return unpacked


1462
@torch.compile
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`.

    Each tensor is scattered back to ``dim0`` rows independently with `unpack_tensor`.
    """
    unpacked = [unpack_tensor(indices, dim0, t) for t in (t1, t2)]
    return tuple(unpacked)


1477
@torch.compile
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`.

    Each tensor is scattered back to ``dim0`` rows independently with `unpack_tensor`.
    """
    unpacked = [unpack_tensor(indices, dim0, t) for t in (t1, t2, t3)]
    return tuple(unpacked)


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    The forward pass packs one to three tensors with a shared ``indices`` tensor.
    The backward pass scatters the incoming gradients back to the row count (dim 0)
    of the first forward input. ``indices`` receives no gradient.
    """

    @staticmethod
    def forward(
        ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # pylint: disable=missing-function-docstring
        num_tensors = len(tensors)
        assert 1 <= num_tensors <= 3, f"Packing {len(tensors)} tensors not supported."
        ctx.save_for_backward(indices)
        # Unpacked row count, needed by backward to rebuild the original layout.
        ctx.dim0 = tensors[0].shape[0]
        if num_tensors == 3:
            return pack_3_tensors(indices, *tensors)
        if num_tensors == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_tensor(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        # pylint: disable=missing-function-docstring
        saved_indices = ctx.saved_tensors[0]
        dim0 = ctx.dim0
        num_grads = len(grad_outputs)
        if num_grads == 1:
            return None, unpack_tensor(saved_indices, dim0, *grad_outputs)
        if num_grads == 2:
            return (None, *unpack_2_tensors(saved_indices, dim0, *grad_outputs))
        return (None, *unpack_3_tensors(saved_indices, dim0, *grad_outputs))
1522
1523
1524
1525
1526
1527


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    The forward pass scatters a packed tensor back to ``dim0`` rows with
    `unpack_tensor`. The backward pass re-packs the incoming gradient with the same
    ``indices``. ``indices`` and ``dim0`` receive no gradient.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        ctx.save_for_backward(indices)
        return unpack_tensor(indices, dim0, tensor)

    @staticmethod
    def backward(ctx, grad_output):
        # pylint: disable=missing-function-docstring
        saved_indices = ctx.saved_tensors[0]
        repacked_grad = pack_tensor(saved_indices, grad_output)
        return None, None, repacked_grad
1545
1546


1547
1548
1549
def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Sends ``send_tensor`` to ``send_dst`` and receives into ``recv_tensor`` from
    ``recv_src`` over ``cp_group``. Even ranks order the send before the receive;
    odd ranks order the receive first. Presumably this alternation keeps
    neighbouring ranks from blocking on each other.

    With ``batch_p2p_comm`` both operations are submitted together through
    ``torch.distributed.batch_isend_irecv`` and its result is returned. Otherwise
    ``isend``/``irecv`` are issued one after the other and their two handles are
    returned in issue order.
    """
    dist = torch.distributed
    send_first = rank % 2 == 0

    if batch_p2p_comm:
        send_op = dist.P2POp(dist.isend, send_tensor, send_dst, cp_group)
        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_src, cp_group)
        ordered_ops = [send_op, recv_op] if send_first else [recv_op, send_op]
        return dist.batch_isend_irecv(ordered_ops)

    # Unbatched: isend/irecv start immediately, so issue order follows the parity rule.
    if send_first:
        send_req = dist.isend(send_tensor, send_dst, cp_group)
        recv_req = dist.irecv(recv_tensor, recv_src, cp_group)
        return [send_req, recv_req]
    recv_req = dist.irecv(recv_tensor, recv_src, cp_group)
    send_req = dist.isend(send_tensor, send_dst, cp_group)
    return [recv_req, send_req]


1589
@jit_fuser
def flash_attn_fwd_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    movedim_src: int,
    movedim_dst: int,
):
    """Merge partial outputs of each step in Attention with context parallelism

    Accumulates ``out_per_step * exp(softmax_lse_per_step - softmax_lse)`` into
    ``out`` in place. The scaling factor has dimension ``movedim_src`` moved to
    ``movedim_dst`` and gains a trailing singleton axis before it is broadcast
    against ``out_per_step``.
    """
    rescale = torch.exp(softmax_lse_per_step - softmax_lse)
    rescale = rescale.movedim(movedim_src, movedim_dst).unsqueeze(-1)
    out.add_(out_per_step * rescale)


1607
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge softmax stats of each step in Attention with context parallelism

    Overwrites ``softmax_lse`` in place with the log-sum-exp of itself and
    ``softmax_lse_per_step``. It is computed as ``max + log1p(exp(min - max))`` so
    that the exponent is never positive.
    """
    larger = torch.max(softmax_lse, softmax_lse_per_step)
    smaller = torch.min(softmax_lse, softmax_lse_per_step)
    softmax_lse.copy_(larger + torch.log1p(torch.exp(smaller - larger)))
1617
1618


1619
1620
@jit_fuser
def get_cu_seqlens_on_cp_rank(
    cu_seqlens: torch.Tensor,
    cu_seqlens_padded_on_cp_rank: torch.Tensor,
    cp_size: int,
    cp_rank: int,
    first_half: bool,
    second_half: bool,
):
    """Compute cu_seqlens of a context parallelism rank

    Each sequence is viewed as ``2 * cp_size`` chunks of length ``chunk_len``, which
    is half of the per-rank padded length taken from ``cu_seqlens_padded_on_cp_rank``.
    This rank's first chunk is chunk ``cp_rank``; its second chunk is chunk
    ``2 * cp_size - cp_rank - 1``. ``first_half`` / ``second_half`` choose which
    chunks are counted.

    For a counted chunk ``c``, its valid (unpadded) token count is
    ``clamp(seqlen - c * chunk_len, 0, chunk_len)``. Returns the cumulative sum of
    those counts with a leading zero, shaped like ``cu_seqlens``.
    """
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    chunk_lens = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
    lower_bound = torch.zeros_like(seqlens)
    counts = torch.zeros_like(cu_seqlens)
    if first_half:
        counts[1:].add_((seqlens - cp_rank * chunk_lens).clamp(lower_bound, chunk_lens))
    if second_half:
        second_chunk_id = 2 * cp_size - cp_rank - 1
        counts[1:].add_(
            (seqlens - second_chunk_id * chunk_lens).clamp(lower_bound, chunk_lens)
        )
    counts.cumsum_(dim=0)
    return counts


1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
@torch.compile
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
    """
    Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
    To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks
    before or after CP communications (e.g., all-gather, all-to-all). This function is to compute
    sequence chunk ids for reordering.
    """
    chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
    if to_contiguous:
        for rank in range(cp_size):
            chunk_ids[rank] = 2 * rank
            chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
    else:
        for rank in range(cp_size):
            chunk_ids[2 * rank] = rank
            chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
    return chunk_ids


@torch.compile
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
    """Reorder sequence chunk for A2A communication."""
    if before_attn:
        # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
        # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
        x = x.movedim(0, seq_dim).contiguous()
        # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
        # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
        x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :])
        # reorder the sequence chunks
        x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a)
    else:
        # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
        # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
        x = x.movedim(seq_dim, 0).contiguous()
        # reorder the sequence chunks
        x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a)
        # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn]
        # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
        x = x.view(cp_size, 2, *x.shape[1:])
    return x


def flash_attn_a2a_communicate(
    a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
    chunk_ids_for_a2a: torch.Tensor,
    seq_dim: int,
    cp_size: int,
    cp_group: dist_group_type,
    cp_stream: torch.cuda.Stream,
    before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """A2A communication for context parallelism.

    Runs one asynchronous ``torch.distributed.all_to_all_single`` per input tensor
    over ``cp_group`` and pipelines the per-tensor layout work around it. At loop
    step ``i``, the A2A for input ``i - 1`` is launched. The A2A for input ``i - 2``
    is waited on and its output is post-processed inside ``cp_stream``. Input ``i``
    is pre-processed so it can be sent in the next step.

    With ``before_attn=True``, each input's second-to-last (head) dimension is split
    into ``cp_size`` groups that are moved to the front before the exchange. After
    the exchange, each output's sequence chunks are reordered with
    ``chunk_ids_for_a2a`` and merged along ``seq_dim``.

    With ``before_attn=False``, the inverse is done: sequence chunks are reordered
    before the exchange and the head groups are merged after it. See the inline
    shape comments.

    Args:
        a2a_inputs: a tensor or a list of tensors to exchange.
        chunk_ids_for_a2a: sequence-chunk permutation passed to
            `reorder_seq_chunks_for_a2a`.
        seq_dim: index of the sequence dimension in the inputs.
        cp_size: number of ranks taking part in the A2A (presumably the size of
            ``cp_group`` -- confirm against callers).
        cp_group: process group for the all-to-all.
        cp_stream: stream on which outputs are post-processed; the current stream
            is made to wait on it before returning.
        before_attn: selects the pre-attention (True) or post-attention (False) path.

    Returns:
        A single tensor if there was exactly one input (a bare tensor or a
        one-element list), otherwise a list of outputs in input order.
    """
    # NOTE(review): when a list is passed it is used as-is, and its elements are
    # overwritten in place with the pre-processed tensors (``a2a_inputs[i] = ...``),
    # so the caller's list is modified.
    a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs
    a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
    if before_attn:
        # len + 2 steps: input n-1 is pre-processed at step n-1, launched at step n
        # and post-processed at step n+1 (n = len(a2a_inputs)).
        for i in range(len(a2a_inputs) + 2):
            # Launch the async A2A for input i-1 (pre-processed in the previous step).
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            # Wait for the A2A of input i-2 and post-process its output on cp_stream.
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # reorder the sequence chunks
                    x = reorder_seq_chunks_for_a2a(
                        x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
                    )
                    # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                    # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
                    a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
            # Pre-process input i so it is ready to be sent in the next step.
            if i < len(a2a_inputs):
                x = a2a_inputs[i]
                # [b, s, np, hn] -> [b, s, cp, np//cp, hn]
                # or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
                x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
                # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn]
                # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
                a2a_inputs[i] = x.movedim(-3, 0).contiguous()
    else:
        # Same len + 2 step pipeline as above, with the inverse layout changes.
        for i in range(len(a2a_inputs) + 2):
            # Launch the async A2A for input i-1 (pre-processed in the previous step).
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            # Pre-process input i: split the sequence into chunks and reorder them.
            if i < len(a2a_inputs):
                x = a2a_inputs[i]
                # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
                # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
                x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
                # reorder the sequence chunks
                a2a_inputs[i] = reorder_seq_chunks_for_a2a(
                    x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
                )
            # Wait for the A2A of input i-2 and merge its head groups on cp_stream.
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn]
                    # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
                    x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
                    # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn]
                    # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
                    a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
    # The post-processing above was queued on cp_stream; make the caller's current
    # stream wait for it before the outputs are used.
    torch.cuda.current_stream().wait_stream(cp_stream)
    return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs


1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
    """Get the list of quantizers used in attention from the quantizers list.

    Args:
        fp8: when False, nothing is looked up and a list of ``None`` is returned
            (8 entries if ``cp_specific_quantizers`` else 6).
        quantizers: mapping with ``"scaling_fwd"`` and ``"scaling_bwd"`` entries, each
            indexed by the ``META_*`` constants.
        cp_specific_quantizers: also return the context-parallel ``O_CP`` and
            ``dQKV_CP`` quantizers.

    Returns:
        ``(QKV, O, S, dQKV, dO, dP)``, or ``(QKV, O, O_CP, S, dQKV, dQKV_CP, dO, dP)``
        when ``cp_specific_quantizers`` is set.

    Side effects:
        The quantizer objects inside ``quantizers`` are configured in place. Every
        returned quantizer is set to row-wise-only usage. ``QKV``, ``S``, ``dO`` and
        ``dQKV_CP`` get ``internal = True``. ``O`` and ``O_CP`` are left non-internal.
    """
    if not fp8:
        num_of_nones = 8 if cp_specific_quantizers else 6
        return [None] * num_of_nones
    QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
    QKV_quantizer.internal = True
    QKV_quantizer.set_usage(rowwise=True, columnwise=False)
    O_quantizer = quantizers["scaling_fwd"][META_O]
    O_quantizer.set_usage(rowwise=True, columnwise=False)
    S_quantizer = quantizers["scaling_fwd"][META_S]
    S_quantizer.internal = True
    S_quantizer.set_usage(rowwise=True, columnwise=False)
    dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
    # NOTE(review): "interal" looks like a typo for "internal". As written, this sets
    # an attribute named `interal` and leaves the quantizer's `internal` flag
    # unchanged. Confirm the intent before correcting it: marking this quantizer
    # internal may change the tensor type of the produced dQKV gradients.
    dQKV_quantizer.interal = True
    dQKV_quantizer.set_usage(rowwise=True, columnwise=False)
    dO_quantizer = quantizers["scaling_bwd"][META_DO]
    dO_quantizer.set_usage(rowwise=True, columnwise=False)
    dO_quantizer.internal = True
    dP_quantizer = quantizers["scaling_bwd"][META_DP]
    dP_quantizer.set_usage(rowwise=True, columnwise=False)
    # NOTE(review): same "interal" spelling as above -- `internal` is not set here.
    dP_quantizer.interal = True
    dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP]
    dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False)
    dQKV_CP_quantizer.internal = True
    O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP]
    O_CP_quantizer.set_usage(rowwise=True, columnwise=False)

    if cp_specific_quantizers:
        return (
            QKV_quantizer,
            O_quantizer,
            O_CP_quantizer,
            S_quantizer,
            dQKV_quantizer,
            dQKV_CP_quantizer,
            dO_quantizer,
            dP_quantizer,
        )

    return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer


1800
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
1801
    """
1802
1803
1804
    Attention implementation with context parallelism. Exchange KV between CP ranks
    with P2P in ring topology. Split attention compute into multiple steps, and overlap
    current-step compute with next-step communication.
1805
1806
1807
1808
1809

    This implementation also supports hierarchical CP, which parallelizes attention
    heads in low-level CP groups and parallelizes sequence dimension in high-level CP
    groups. For more details, please refer to `LongVILA <https://arxiv.org/abs/2408.10188>`_
    and `USP <https://arxiv.org/abs/2405.07719>`_.
1810
1811
1812
    """

    @staticmethod
1813
1814
1815
1816
1817
1818
1819
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
1820
        cu_seqlens_kv,
1821
        max_seqlen_q,
1822
        max_seqlen_kv,
1823
1824
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
1825
1826
1827
1828
1829
1830
1831
1832
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
1833
1834
        fp8,
        fp8_meta,
1835
1836
1837
        cp_group,
        cp_global_ranks,
        cp_stream,
1838
        quantizers,
1839
    ):
1840
        # pylint: disable=missing-function-docstring
1841
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
1842
1843
1844
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
        if isinstance(cp_group, list):
            assert (
                qkv_format != "thd"
            ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!"
            assert attn_bias_type == "no_bias", (
                f"{attn_bias_type} bias type is not supported with hierarchical CP implementation"
                " yet!"
            )
            cp_group_a2a = cp_group[0]
            cp_size_a2a = get_distributed_world_size(cp_group_a2a)
            rank_a2a = get_distributed_rank(cp_group_a2a)
            cp_group = cp_group[1]
        else:
            cp_group_a2a = None
            cp_size_a2a = 1
            rank_a2a = 0

1862
1863
        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
1864
1865
        send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
1866
1867
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1868
1869
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
1870

1871
        seq_dim = None
1872
        if qkv_format in ["bshd", "sbhd"]:
1873
            seq_dim = qkv_format.index("s")
1874
1875
1876
1877
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        else:
            qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

1878
1879
1880
1881
1882
1883
        pad_between_seqs_q = cu_seqlens_q_padded is not None and not torch.equal(
            cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]
        )
        pad_between_seqs_kv = cu_seqlens_kv_padded is not None and not torch.equal(
            cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]
        )
1884
1885
        max_seqlen_q = max_seqlen_q // cp_size
        max_seqlen_kv = max_seqlen_kv // cp_size
1886
1887
1888
1889
1890
1891
        cu_seqlens_q_padded = (
            None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // cp_size
        )
        cu_seqlens_kv_padded = (
            None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // cp_size
        )
1892
1893
        cu_seqlens_q_per_step = [None for _ in range(cp_size)]
        cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
1894

1895
        fused_attn_backend = None
1896
1897
1898
        qkv_dtype = q.dtype
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
        is_input_fp8 = False
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
        is_output_fp8 = False
        if fp8:
            is_output_fp8 = fp8_meta["recipe"].fp8_mha

        (
            QKV_quantizer,
            O_quantizer,
            O_CP_quantizer,
            S_quantizer,
            dQKV_quantizer,
            dQKV_CP_quantizer,
            dO_quantizer,
            dP_quantizer,
        ) = get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True)

1914
1915
1916
        if fp8:
            if use_fused_attention:
                fused_attn_backend = FusedAttnBackend["FP8"]
1917

1918
1919
1920
1921
                assert isinstance(k, q.__class__) and isinstance(
                    v, q.__class__
                ), "q, k, and v must have the same type."
                is_input_fp8 = isinstance(q, Float8Tensor)
1922
                if not is_input_fp8:
1923
1924
                    q_f16, k_f16, v_f16 = q, k, v
                    if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
1925
                        q = QKV_quantizer(q_f16)
1926
                    if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
1927
                        k, v = [QKV_quantizer(x) for x in [k_f16, v_f16]]
1928
                fp8_meta_kwargs = {}
1929
1930
                fp8_meta_kwargs["s_quantizer"] = S_quantizer
                fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer  # partial result quantizer
1931
1932
1933
1934
1935
1936
1937
1938
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            q_f16 = q
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

1939
1940
1941
1942
1943
        if fp8:
            q = q._data
            k = k._data
            v = v._data

1944
1945
        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True)
1946

1947
1948
1949
1950
1951
            q, k, v = flash_attn_a2a_communicate(
                [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
            )
            if not fp8:
                q_f16 = q
1952
            elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
1953
                q_f16 = q
1954
                q = QKV_quantizer(q_f16)._data
1955

1956
1957
1958
        assert qkv_format == "thd" or (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"
1959
        if causal:
1960
1961
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
1962
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
1963
1964
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
1965
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
1966
        if attn_bias is not None:
1967
            assert len(attn_bias.shape) == 4, (
1968
1969
1970
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
1971
1972
1973
            assert (
                attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0
            ), "Sequence length does not meet divisible requirements!"
1974
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
1975
1976
1977
1978
1979
1980
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
1981
1982
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
1983
1984
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
1985
            )
1986
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
1987

1988
1989
1990
1991
1992
1993
1994
        softmax_lse_in_packed_format = False
        if qkv_format == "thd":
            if use_fused_attention:
                softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
            else:
                softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3

1995
        flash_attn_fwd = None
1996
1997
1998
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
1999
2000
2001
2002
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
2003
2004
                fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
            else:
2005
2006
2007
2008
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
2009
2010
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
2011
                if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or _use_flash_attn_3:
2012
                    fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
2013
2014
2015
                elif _flash_attn_2_7_0_plus:
                    fa_forward_kwargs["window_size_left"] = -1
                    fa_forward_kwargs["window_size_right"] = 0 if causal else -1
2016
2017
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
2018
                if _flash_attn_2_5_7_plus and qkv_format == "thd":
2019
                    fa_forward_kwargs["block_table"] = None
2020
2021
                if _flash_attn_2_6_0_plus:
                    fa_forward_kwargs["softcap"] = 0.0
2022

2023
2024
2025
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
2026
        attn_bias_inputs = [None, None]
2027
2028
2029
2030
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
2031
        attn_biases = [None for _ in range(cp_size)]
2032
2033
2034
2035
2036
2037
2038

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
2039
        if qkv_format in ["bshd", "sbhd"]:
2040
2041
2042
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
        else:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
2043
2044
        send_recv_reqs = [[], []]

2045
2046
        softmax_lse_ = None
        out = None
2047
        for i in range(cp_size + 1):
2048
            if i < cp_size:
2049
                with torch.cuda.stream(flash_attn_streams[i % 2]):
2050
                    # wait until KV is received
2051
                    for req in send_recv_reqs[(i + 1) % 2]:
2052
2053
                        req.wait()

2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

2066
                    if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
2067
2068
2069
                        kv_inputs[i % 2] = p2p_comm_buffers[i]
                    else:
                        # KV exchange is in BF16/FP16, cast received KV in each step
2070
                        kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])
2071
2072
                    if causal:
                        if i == 0:
2073
2074
2075
2076
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
2077
                            elif use_fused_attention or qkv_format == "thd":
2078
2079
2080
2081
2082
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
                                )
2083
                            elif use_fused_attention or qkv_format == "thd":
2084
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    k.shape[0], -1, 2, *k.shape[-2:]
                                )
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                                # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    -1, k.shape[2], 2, *k.shape[-2:]
                                )
                            elif qkv_format == "thd":
                                q_inputs[i % 2] = q
2101
                            if use_fused_attention:
2102
2103
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2104
2105
2106
2107
2108
2109
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
2110
                                    ).contiguous()
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )

2134
2135
2136
2137
2138
2139
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
2140
2141
2142
2143
2144
                                    q_part,
                                    k_part,
                                    v_part,
                                    fake_dtype=qkv_dtype,
                                    fused_attention_backend=fused_attn_backend,
2145
2146
2147
2148
2149
2150
2151
2152
2153
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
2154
                                )
2155
2156
2157
2158
2159
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2160
                            else:
2161
2162
2163
2164
2165
2166
2167
2168
                                fa_forward_args_thd = []
                                if qkv_format == "thd":
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q,
                                        max_seqlen_kv,
                                    ]
2169
                                fa_outputs = flash_attn_fwd(
2170
                                    q_inputs[i % 2],
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
2182
                                    causal=True,
2183
                                    **fa_forward_kwargs,
2184
                                )
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
                                if not _flash_attn_2_7_0_plus:
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[3]
2195
                        elif i <= rank:
2196
2197
2198
2199
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
2200
                            elif use_fused_attention or qkv_format == "thd":
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    False,
                                )
2211
                            elif use_fused_attention or qkv_format == "thd":
2212
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...]
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                                # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2][0]
                            elif qkv_format == "thd":
                                q_inputs[i % 2] = q
                                # [2, t, np, hn] -> [2, t/2, np, hn]
                                kv_inputs[i % 2] = tex.thd_read_half_tensor(
                                    kv_inputs[i % 2], cu_seqlens_kv_padded, 0
                                )
2229
                            if use_fused_attention:
2230
                                kv_inputs[i % 2] = kv_inputs[i % 2].contiguous()
2231
2232
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2233
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
2256
2257
2258
2259
2260
2261
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv // 2,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
2262
2263
2264
2265
                                    q_part,
                                    k_part,
                                    v_part,
                                    qkv_dtype,
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=(
                                        None
                                        if cu_seqlens_kv_padded is None
                                        else cu_seqlens_kv_padded // 2
                                    ),
                                    **fp8_meta_kwargs,
2280
                                )
2281
2282
2283
2284
2285
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2286
                            else:
2287
                                fa_forward_args_thd = []
2288
                                if qkv_format == "thd":
2289
2290
2291
2292
2293
2294
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q,
                                        max_seqlen_kv // 2,
                                    ]
2295
2296
2297
                                if _use_flash_attn_3 or (
                                    _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                                ):
2298
                                    fa_forward_kwargs["window_size"] = (-1, -1)
2299
2300
2301
                                elif _flash_attn_2_7_0_plus:
                                    fa_forward_kwargs["window_size_left"] = -1
                                    fa_forward_kwargs["window_size_right"] = -1
2302
                                fa_outputs = flash_attn_fwd(
2303
                                    q_inputs[i % 2],
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
2315
                                    causal=False,
2316
                                    **fa_forward_kwargs,
2317
                                )
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
                                if not _flash_attn_2_7_0_plus:
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[3]
2328
                        else:
2329
2330
2331
2332
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
                                )
2333
                            elif use_fused_attention or qkv_format == "thd":
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    True,
                                )
2344
                            elif use_fused_attention or qkv_format == "thd":
2345
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                                q_inputs[i % 2] = q[:, 1, ...]
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    k.shape[0], -1, 2, *k.shape[-2:]
                                )
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                                q_inputs[i % 2] = q[1]
                                # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    -1, k.shape[2], 2, *k.shape[-2:]
                                )
                            elif qkv_format == "thd":
                                # [t, np, hn] -> [t/2, np, hn]
                                q_inputs[i % 2] = tex.thd_read_half_tensor(
                                    q, cu_seqlens_q_padded, 1
                                )
2365
                            if use_fused_attention:
2366
                                q_inputs[i % 2] = q_inputs[i % 2].contiguous()
2367
2368
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2369
2370
2371
2372
2373
2374
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
2375
                                    ).contiguous()
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
2398
2399
2400
2401
2402
2403
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q // 2,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
2404
2405
2406
2407
                                    q_part,
                                    k_part,
                                    v_part,
                                    qkv_dtype,
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=(
                                        None
                                        if cu_seqlens_q_padded is None
                                        else cu_seqlens_q_padded // 2
                                    ),
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
2422
                                )
2423
2424
2425
2426
2427
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2428
                            else:
2429
                                fa_forward_args_thd = []
2430
                                if qkv_format == "thd":
2431
2432
2433
2434
2435
2436
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q // 2,
                                        max_seqlen_kv,
                                    ]
2437
2438
2439
                                if _use_flash_attn_3 or (
                                    _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                                ):
2440
                                    fa_forward_kwargs["window_size"] = (-1, -1)
2441
2442
2443
                                elif _flash_attn_2_7_0_plus:
                                    fa_forward_kwargs["window_size_left"] = -1
                                    fa_forward_kwargs["window_size_right"] = -1
2444
                                fa_outputs = flash_attn_fwd(
2445
                                    q_inputs[i % 2],
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
2457
                                    causal=False,
2458
                                    **fa_forward_kwargs,
2459
                                )
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
                                if not _flash_attn_2_7_0_plus:
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[3]
2470
                    else:
2471
2472
2473
2474
                        if pad_between_seqs_q:
                            cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                            )
2475
                        elif use_fused_attention or qkv_format == "thd":
2476
2477
2478
2479
2480
2481
2482
2483
2484
2485
                            cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                        if pad_between_seqs_kv:
                            cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_kv,
                                cu_seqlens_kv_padded,
                                cp_size,
                                (rank - i) % cp_size,
                                True,
                                True,
                            )
2486
                        elif use_fused_attention or qkv_format == "thd":
2487
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2488
                        if use_fused_attention:
2489
2490
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
2491
2492
2493
2494
2495
2496
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
2497
                                ).contiguous()
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519

                            q_part = q
                            k_part = (
                                kv_inputs[i % 2][..., 0, :, :]
                                if qkv_format in ["bshd", "sbhd"]
                                else kv_inputs[i % 2][0]
                            )
                            v_part = (
                                kv_inputs[i % 2][..., 1, :, :]
                                if qkv_format in ["bshd", "sbhd"]
                                else kv_inputs[i % 2][1]
                            )
                            if fp8:
                                q_part = QKV_quantizer.create_tensor_from_data(
                                    q_part, fake_dtype=qkv_dtype, internal=True
                                )
                                k_part = QKV_quantizer.create_tensor_from_data(
                                    k_part, fake_dtype=qkv_dtype, internal=True
                                )
                                v_part = QKV_quantizer.create_tensor_from_data(
                                    v_part, fake_dtype=qkv_dtype, internal=True
                                )
2520
2521
2522
2523
2524
2525
                            out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                is_training,
                                max_seqlen_q,
                                max_seqlen_kv,
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
2526
2527
2528
2529
                                q_part,
                                k_part,
                                v_part,
                                qkv_dtype,
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
                                fused_attn_backend,
                                attn_scale=softmax_scale,
                                dropout=dropout_p,
                                qkv_layout=qkv_layout,
                                attn_mask_type=attn_mask_type,
                                attn_bias_type=attn_bias_type,
                                attn_bias=attn_bias_inputs[i % 2],
                                cu_seqlens_q_padded=cu_seqlens_q_padded,
                                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                **fp8_meta_kwargs,
2540
                            )
2541
2542
2543
2544
2545
                            if fp8:
                                softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                            else:
                                softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                attn_biases[i] = rest[0] if len(rest) > 0 else None
2546
                        else:
2547
2548
2549
2550
2551
2552
2553
2554
                            fa_forward_args_thd = []
                            if qkv_format == "thd":
                                fa_forward_args_thd = [
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                ]
2555
                            fa_outputs = flash_attn_fwd(
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
                                q,
                                (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                ),
                                (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                ),
                                *fa_forward_args_thd,
2568
                                causal=False,
2569
                                **fa_forward_kwargs,
2570
                            )
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
                            if not _flash_attn_2_7_0_plus:
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
                                if not _use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
                            else:
                                out_per_step[i] = fa_outputs[0]
                                softmax_lse_per_step[i] = fa_outputs[1]
                                if not _use_flash_attn_3:
                                    rng_states[i] = fa_outputs[3]
2581
2582
2583
2584

            if i > 0:
                # wait until fwd restuls correction of last step is done
                if i > 1:
2585
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
2586

2587
                if use_fused_attention:
2588
2589
                    # [b, np, sq, 1] -> [b, np, sq] or
                    # [t, np, 1] -> [t, np]
2590
                    softmax_lse_per_step[i - 1].squeeze_(-1)
2591
2592
2593
2594
                    if softmax_lse_in_packed_format:
                        softmax_lse_per_step[i - 1] = (
                            softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
                        )
2595

2596
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
2597
                    if fp8:
2598
                        out_per_step[i - 1] = out_per_step[i - 1].dequantize()
2599
                    if i == 1:
2600
                        out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
2601
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
2602
                        if causal and qkv_format != "thd":
2603
                            # [b, np, sq] -> [b, np, 2, sq//2]
2604
                            softmax_lse_ = softmax_lse.view(
2605
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
2606
                            )
2607
2608
2609
2610
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
2611
                    else:
2612
                        if qkv_format == "thd":
2613
                            tex.thd_second_half_lse_correction(
2614
2615
2616
                                softmax_lse,
                                softmax_lse_per_step[i - 1],
                                cu_seqlens_q_padded,
2617
                                softmax_lse_in_packed_format,
2618
                            )
2619
                        else:
2620
2621
2622
                            flash_attn_fwd_softmax_lse_correction(
                                softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1]
                            )
2623
2624

                if i < cp_size:
2625
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
2626
2627
2628

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

2629
2630
2631
2632
        second_half_lse_seqlen = None
        if causal and rank < (cp_size - 1):
            second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1]

2633
2634
        softmax_lse = softmax_lse.to(torch.float)
        for i in range(cp_size):
2635
            if i <= rank or not causal:
2636
                if qkv_format in ["bshd", "sbhd"]:
2637
2638
2639
2640
2641
                    flash_attn_fwd_out_correction(
                        out.view(*out_per_step[i].shape),
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2642
2643
                        0 if softmax_lse_in_packed_format else 2,
                        2 if softmax_lse_in_packed_format else seq_dim,
2644
                    )
2645
                elif qkv_format == "thd":
2646
2647
2648
2649
2650
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2651
                        cu_seqlens_q_padded,
2652
                        False,
2653
                        softmax_lse_in_packed_format,
2654
                    )
2655
            else:
2656
                if qkv_format in ["bshd", "sbhd"]:
2657
                    out_ = out.select(seq_dim, 1)
2658
2659
2660
2661
2662
                    flash_attn_fwd_out_correction(
                        out_,
                        out_per_step[i],
                        softmax_lse_[..., 1, :],
                        softmax_lse_per_step[i],
2663
2664
                        0 if softmax_lse_in_packed_format else 2,
                        2 if softmax_lse_in_packed_format else seq_dim,
2665
                    )
2666
                elif qkv_format == "thd":
2667
2668
2669
2670
2671
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2672
                        cu_seqlens_q_padded,
2673
                        True,
2674
                        softmax_lse_in_packed_format,
2675
                    )
2676
2677

        kv = p2p_comm_buffers[-1]
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
        if qkv_format == "bshd":
            out = out.view(out.shape[0], -1, *out.shape[-2:])
            ctx.batch_size = out.shape[0]
        elif qkv_format == "sbhd":
            out = out.view(-1, *out.shape[-3:])
            ctx.batch_size = out.shape[1]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
            out = flash_attn_a2a_communicate(
                out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
            )
            if use_fused_attention:
                if qkv_format == "bshd":
                    # [b*s, np, hn] -> [b, s, np, hn]
                    out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                elif qkv_format == "sbhd":
                    # [s*b, np, hn] -> [s, b, np, hn]
                    out = out.view(-1, ctx.batch_size, *out.shape[-2:])
        elif not use_fused_attention:
2698
            out = out.view(-1, *out.shape[-2:])
2699

2700
        out_fp8 = None
2701
        out_f16 = out.to(qkv_dtype)
2702

2703
        if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
2704
2705
2706
            out_fp8 = O_quantizer(out_f16)  # final result

        out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16
2707
2708

        if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
2709
            q_save, kv_save, out_save = q, kv, out_fp8._data
2710
        elif fp8 and is_input_fp8:
2711
            q_save, kv_save, out_save = q, k, out_f16
2712
        else:
2713
            q_f16 = q_f16.view(q.shape)
2714
2715
            q_save, kv_save, out_save = q_f16, kv, out_f16

2716
        tensors_to_save, tensor_objects = prepare_for_saving(
2717
2718
2719
            q_save,
            kv_save,
            out_save,
2720
            softmax_lse,
2721
2722
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
2723
2724
            *cu_seqlens_q_per_step,
            *cu_seqlens_kv_per_step,
2725
2726
            *rng_states,
            *attn_biases,
2727
        )
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects

        ctx.qkv_dtype = qkv_dtype
        ctx.QKV_quantizer = QKV_quantizer
        ctx.O_quantizer = O_quantizer
        ctx.O_CP_quantizer = O_CP_quantizer
        ctx.S_quantizer = S_quantizer
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.qkv_dtype = qkv_dtype

2742
2743
2744
        ctx.cp_group_a2a = cp_group_a2a
        ctx.cp_size_a2a = cp_size_a2a
        ctx.rank_a2a = rank_a2a
2745
2746
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
2747
        ctx.cp_stream = cp_stream
2748
2749
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
2750
        ctx.max_seqlen_kv = max_seqlen_kv
2751
        ctx.softmax_scale = softmax_scale
2752
        ctx.qkv_format = qkv_format
2753
        ctx.attn_mask_type = attn_mask_type
2754
2755
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
2756
        ctx.deterministic = deterministic
2757
        ctx.use_fused_attention = use_fused_attention
2758
        ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format
2759
        ctx.second_half_lse_seqlen = second_half_lse_seqlen
2760
2761
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
2762
2763
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
2764
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
2765

2766
        return out_ret
2767
2768
2769

    @staticmethod
    def backward(ctx, dout):
2770
        # pylint: disable=missing-function-docstring
2771
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
2772
2773
2774
        cp_size_a2a = ctx.cp_size_a2a
        rank_a2a = ctx.rank_a2a

2775
2776
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
2777
2778
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
2779
2780
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

2781
2782
2783
2784
2785
2786
2787
2788
2789
        saved_tensors = ctx.saved_tensors

        q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
            restore_from_saved(ctx.tensor_objects, saved_tensors)
        )
        cu_seqlens_q_per_step = other_tensors[:cp_size]
        cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
        rng_states = other_tensors[cp_size * 2 : cp_size * 3]
        attn_biases = other_tensors[cp_size * 3 : cp_size * 4]
2790

2791
2792
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
2793
2794

        seq_dim = None
2795
        if ctx.qkv_format in ["bshd", "sbhd"]:
2796
            seq_dim = ctx.qkv_format.index("s")
2797
2798
2799
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
        else:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
2800

2801
        if attn_biases[0] is not None:
2802
2803
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
2804
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
2805
2806
2807
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
2808
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
2809
2810
2811
            )
        else:
            attn_dbias = None
2812
            attn_dbias_ = None
2813

2814
2815
        softmax_lse_ = None
        if causal and ctx.second_half_lse_seqlen is not None:
2816
            if ctx.qkv_format == "thd":
2817
                softmax_lse_ = tex.thd_read_second_half_lse(
2818
2819
2820
2821
                    softmax_lse,
                    cu_seqlens_q_padded,
                    ctx.softmax_lse_in_packed_format,
                    ctx.second_half_lse_seqlen,
2822
                )
2823
2824
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
2825
2826
2827
                softmax_lse_ = softmax_lse.view(
                    *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
                )
2828
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
2829
2830
2831
2832
2833
2834
            if ctx.use_fused_attention:
                if ctx.softmax_lse_in_packed_format:
                    softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous()
                # [b, np, sq//2] -> [b, np, sq//2, 1] or
                # [t//2, np] -> [t//2, np, 1]
                softmax_lse_.unsqueeze_(-1)
2835
        if ctx.use_fused_attention:
2836
2837
2838
2839
            if ctx.softmax_lse_in_packed_format:
                softmax_lse = softmax_lse.transpose(0, 1).contiguous()
            # [b, np, sq] -> [b, np, sq, 1] or
            # [t, np] -> [t, np, 1]
2840
            softmax_lse.unsqueeze_(-1)
2841

2842
        dq = None
2843
        dout_dtype = dout.dtype
2844
2845
        fused_attn_backend = None
        fused_attn_dqkv_dtype = None
2846
2847
2848
        if ctx.fp8:
            if ctx.use_fused_attention:
                fused_attn_backend = FusedAttnBackend["FP8"]
2849

2850
2851
2852
                dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
                dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
                dkv_fp8_ = torch.empty_like(dkv_fp8)
2853
                if ctx.is_output_fp8:
2854
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
2855
                    fused_attn_dqkv_dtype = dout._fp8_dtype
2856
2857
                    dout = dout._data
                else:
2858
2859
2860
                    dout = ctx.dO_quantizer(dout)
                    fused_attn_dqkv_dtype = dout._fp8_dtype
                    dout = dout._data
2861
2862
                p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
                fp8_meta_kwargs = {}
2863
2864
2865
                fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
                fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer
                fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_CP_quantizer
2866
2867
2868
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
2869
            if ctx.fp8_meta is not None and ctx.is_input_fp8:
2870
2871
2872
2873
2874
2875
2876
                q = ctx.QKV_quantizer.create_tensor_from_data(
                    q, fake_dtype=ctx.qkv_dtype, internal=True
                )
                kv = ctx.QKV_quantizer.create_tensor_from_data(
                    kv, fake_dtype=ctx.qkv_dtype, internal=True
                )
                q, kv = q.dequantize(), kv.dequantize()
2877
                if cp_size_a2a == 1:
2878
                    dout = dout.dequantize()
2879
2880
2881
2882
2883
2884
2885
2886
            dq = torch.empty_like(q)
            p2p_comm_buffers = [
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            ]
            p2p_comm_buffers[0][0].copy_(kv)
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
2887
                fused_attn_dqkv_dtype = TE_DType[dout_dtype]
2888
2889
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

2890
2891
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
        if cp_size_a2a > 1:
            if not ctx.use_fused_attention:
                out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                dout = dout.view(*out.shape)
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True)
            out, dout = flash_attn_a2a_communicate(
                [out, dout],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                True,
            )
2904
            if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
2905
2906
2907
                dout = ctx.dO_quantizer.create_tensor_from_data(data=dout, internal=True)
                dout = dout.dequantize()
                dout = dout._data
2908

2909
2910
2911
2912
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        send_recv_reqs = []

2913
        flash_attn_bwd = None
2914
2915
2916
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
2917
2918
2919
2920
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
2921
2922
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
2923
2924
2925
2926
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
2927
2928
2929
2930
2931
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
2932
2933
                if _flash_attn_2_6_0_plus:
                    fa_backward_kwargs["softcap"] = 0.0
2934

2935
2936
2937
2938
2939
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

2940
2941
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
2968
2969
2970
            if ctx.fp8:
                if i < cp_size - 1:
                    send_recv_reqs = flash_attn_p2p_communicate(
                        rank,
                        send_tensor[0],
                        send_dst,
                        recv_tensor[0],
                        recv_src,
                        ctx.cp_group,
                        batch_p2p_comm,
                    )
                else:
                    dkv_a2a_req = torch.distributed.all_to_all_single(
                        dkv_fp8,
                        dkv_fp8_,
                        group=ctx.cp_group,
                        async_op=True,
                    )
                    send_recv_reqs = [dkv_a2a_req]
            else:
                if i == 0:
                    send_tensor = send_tensor[0]
                    recv_tensor = recv_tensor[0]
                if i == (cp_size - 1):
                    send_tensor = send_tensor[1]
                    recv_tensor = recv_tensor[1]
                send_recv_reqs = flash_attn_p2p_communicate(
                    rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
                )
2971

2972
            kv = p2p_comm_buffers[i % 2][0]
2973
2974
            q_, kv_, out_, dout_ = None, None, None, None
            dq_, dk_, dv_ = None, None, None
2975
            # In reversed order of fwd
2976
            if causal:
2977
                if i == (cp_size - 1):
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_, out_, dout_ = [
                            x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
                        ]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                        kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                        q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
                        # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                        kv_ = kv.view(-1, *kv.shape[-4:])
                    elif ctx.qkv_format == "thd":
                        q_, kv_, out_, dout_ = q, kv, out, dout
2992
                    if ctx.use_fused_attention:
2993
2994
2995
2996
2997
2998
2999
3000
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
3001
                        if attn_dbias is not None:
3002
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
                                dout_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
3025
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3026
                            ctx.max_seqlen_q,
3027
3028
3029
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
3030
3031
3032
3033
3034
3035
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
                            ctx.qkv_dtype,
3036
                            fused_attn_dqkv_dtype,
3037
                            aux_ctx_tensors,
3038
                            fused_attn_backend,
3039
3040
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
3041
3042
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
3043
                            qkv_layout=qkv_layout,
3044
                            attn_mask_type=ctx.attn_mask_type,
3045
                            attn_bias_type=ctx.attn_bias_type,
3046
3047
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
3048
                        )
3049
3050
3051
3052
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
3053
                    else:
3054
                        dq_ = torch.empty_like(q_)
3055
                        dkv_ = torch.empty_like(kv_)
3056
3057
3058
3059
3060
3061
3062
3063
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q,
                                ctx.max_seqlen_kv,
                            ]
3064
3065
3066
                        if _use_flash_attn_3 or (
                            _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                        ):
3067
                            fa_backward_kwargs["window_size"] = (-1, 0)
3068
3069
3070
                        elif _flash_attn_2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = 0
3071
3072
3073
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
3074
3075
                            dout_,
                            q_,
3076
3077
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
3078
3079
3080
                            out_,
                            softmax_lse,
                            dq_,
3081
3082
3083
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
3084
3085
                            causal=True,
                            **fa_backward_kwargs,
3086
                        )
3087
                elif i >= (cp_size - rank - 1):
3088
3089
3090
3091
3092
3093
3094
3095
3096
3097
3098
3099
3100
3101
3102
3103
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_, out_, dout_ = [
                            x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
                        ]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                        kv_ = kv[:, 0]
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                        q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
                        # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                        kv_ = kv[0]
                    elif ctx.qkv_format == "thd":
                        q_, out_, dout_ = q, out, dout
                        # [2, t, np, hn] -> [2, t/2, np, hn]
                        kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
3104
                    if ctx.use_fused_attention:
3105
                        kv_ = kv_.contiguous()
3106
3107
3108
3109
3110
3111
3112
3113
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
3114
                        if attn_dbias is not None:
3115
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3116
3117
3118
3119
3120
3121
3122
3123
3124
3125
3126
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
                                dout_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
3138
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3139
                            ctx.max_seqlen_q,
3140
3141
3142
                            ctx.max_seqlen_kv // 2,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
3143
3144
3145
3146
3147
3148
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
                            ctx.qkv_dtype,
3149
                            fused_attn_dqkv_dtype,
3150
                            aux_ctx_tensors,
3151
                            fused_attn_backend,
3152
3153
3154
3155
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=(
                                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
                            ),
3156
3157
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
3158
                            qkv_layout=qkv_layout,
3159
                            attn_mask_type="padding" if padding else "no_mask",
3160
                            attn_bias_type=ctx.attn_bias_type,
3161
3162
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
3163
                        )
3164
3165
3166
3167
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
3168
                    else:
3169
                        dq_ = torch.empty_like(q_)
3170
                        dkv_ = torch.empty_like(kv_)
3171
3172
3173
3174
3175
3176
3177
3178
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q,
                                ctx.max_seqlen_kv // 2,
                            ]
3179
3180
3181
                        if _use_flash_attn_3 or (
                            _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                        ):
3182
                            fa_backward_kwargs["window_size"] = (-1, -1)
3183
3184
3185
                        if _flash_attn_2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = -1
3186
3187
3188
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
3189
3190
                            dout_,
                            q_,
3191
3192
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
3193
3194
3195
                            out_,
                            softmax_lse,
                            dq_,
3196
3197
3198
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
3199
3200
                            causal=False,
                            **fa_backward_kwargs,
3201
3202
                        )
                else:
3203
3204
3205
3206
3207
3208
3209
3210
3211
3212
3213
3214
3215
3216
3217
3218
3219
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                        kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                        q_, out_, dout_ = q[1], out[1], dout[1]
                        # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                        kv_ = kv.view(-1, *kv.shape[-4:])
                    elif ctx.qkv_format == "thd":
                        # [t, np, hn] -> [t/2, np, hn]
                        q_, out_, dout_ = [
                            tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1)
                            for x in [q, out, dout]
                        ]
                        kv_ = kv
3220
                    if ctx.use_fused_attention:
3221
                        q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]]
3222
3223
3224
3225
3226
3227
3228
3229
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse_,
                                softmax_lse_,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
3230
                        if attn_dbias is not None:
3231
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250
3251
3252
3253
3254

                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
                                dout_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
3255
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3256
                            ctx.max_seqlen_q // 2,
3257
3258
3259
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
3260
3261
3262
3263
3264
3265
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
                            ctx.qkv_dtype,
3266
                            fused_attn_dqkv_dtype,
3267
                            aux_ctx_tensors,
3268
                            fused_attn_backend,
3269
3270
3271
3272
                            cu_seqlens_q_padded=(
                                None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
                            ),
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
3273
3274
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
3275
                            qkv_layout=qkv_layout,
3276
                            attn_mask_type="padding" if padding else "no_mask",
3277
                            attn_bias_type=ctx.attn_bias_type,
3278
3279
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
3280
                        )
3281
3282
3283
3284
3285
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data

3286
                    else:
3287
                        dq_ = torch.empty_like(q_)
3288
                        dkv_ = torch.empty_like(kv_)
3289
                        fa_backward_args_thd = []
3290
                        if ctx.qkv_format == "thd":
3291
3292
3293
3294
3295
3296
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q // 2,
                                ctx.max_seqlen_kv,
                            ]
3297
3298
3299
                        if _use_flash_attn_3 or (
                            _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                        ):
3300
                            fa_backward_kwargs["window_size"] = (-1, -1)
3301
3302
3303
                        elif _flash_attn_2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = -1
3304
3305
3306
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
3307
3308
                            dout_,
                            q_,
3309
3310
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
3311
3312
3313
                            out_,
                            softmax_lse_,
                            dq_,
3314
3315
3316
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
3317
3318
                            causal=False,
                            **fa_backward_kwargs,
3319
3320
3321
                        )
            else:
                if ctx.use_fused_attention:
3322
3323
3324
3325
                    if ctx.fp8:
                        aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]]
                    else:
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
3326
                    if attn_dbias is not None:
3327
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3328
3329
3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
3346
3347
3348
3349
                    q_part = q
                    k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0]
                    v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1]
                    out_part = out
                    dout_part = dout

                    if ctx.fp8:
                        q_part = ctx.QKV_quantizer.create_tensor_from_data(
                            q_part, fake_dtype=ctx.qkv_dtype
                        )
                        k_part = ctx.QKV_quantizer.create_tensor_from_data(
                            k_part, fake_dtype=ctx.qkv_dtype
                        )
                        v_part = ctx.QKV_quantizer.create_tensor_from_data(
                            v_part, fake_dtype=ctx.qkv_dtype
                        )
                        out_part = ctx.O_quantizer.create_tensor_from_data(
                            out_part, fake_dtype=ctx.qkv_dtype
                        )
                        dout_part = ctx.dO_quantizer.create_tensor_from_data(
                            dout_part, fake_dtype=ctx.qkv_dtype
                        )
3350
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3351
                        ctx.max_seqlen_q,
3352
3353
3354
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
3355
3356
3357
3358
3359
3360
                        q_part,
                        k_part,
                        v_part,
                        out_part,
                        dout_part,
                        ctx.qkv_dtype,
3361
                        fused_attn_dqkv_dtype,
3362
                        aux_ctx_tensors,
3363
                        fused_attn_backend,
3364
3365
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
3366
3367
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
3368
                        qkv_layout=qkv_layout,
3369
                        attn_mask_type=ctx.attn_mask_type,
3370
                        attn_bias_type=ctx.attn_bias_type,
3371
3372
                        deterministic=ctx.deterministic,
                        **fp8_meta_kwargs,
3373
                    )
3374
3375
3376
3377
3378
3379

                    if ctx.fp8:
                        dq_ = dq_._data
                        dk_ = dk_._data
                        dv_ = dv_._data

3380
                else:
3381
3382
3383
3384
3385
3386
3387
3388
3389
3390
                    dq_ = torch.empty_like(q)
                    dkv_ = torch.empty_like(kv)
                    fa_backward_args_thd = []
                    if ctx.qkv_format == "thd":
                        fa_backward_args_thd = [
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_kv,
                        ]
3391
                    if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus):
3392
                        fa_backward_kwargs["window_size"] = (-1, -1)
3393
3394
3395
                    elif _flash_attn_2_7_0_plus:
                        fa_backward_kwargs["window_size_left"] = -1
                        fa_backward_kwargs["window_size_right"] = -1
3396
3397
3398
                    if not _use_flash_attn_3:
                        fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                    flash_attn_bwd(
3399
3400
3401
3402
3403
                        dout,
                        q,
                        kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
                        out,
3404
3405
                        softmax_lse,
                        dq_,
3406
3407
3408
                        dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                        dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                        *fa_backward_args_thd,
3409
3410
                        causal=False,
                        **fa_backward_kwargs,
3411
3412
                    )

3413
3414
            if ctx.fp8:
                dq = dq_fp8[(rank + i + 1) % cp_size]
3415
3416
3417
            if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1):
                # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or
                # [sq, b, np, hn] -> [2, sq//2, b, np, hn]
3418
                dq_ = dq_.view(*dq.shape)
3419

3420
3421
3422
3423
3424
3425
3426
3427
3428
3429
3430
            if ctx.fp8:
                if i >= (cp_size - rank - 1) or not causal:
                    dq.copy_(dq_)
                else:
                    if ctx.qkv_format == "bshd":
                        dq[:, 0, ...].fill_(0)
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[0].fill_(0)
                        dq[1].copy_(dq_)
            elif causal:
3431
                if i > (cp_size - rank - 1):
3432
                    dq.add_(dq_)
3433
3434
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
3435
3436
                        dq.copy_(dq_)
                    else:
3437
3438
3439
3440
3441
3442
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
3443
                        elif ctx.qkv_format == "thd":
3444
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
3445
                elif i > 0:
3446
3447
3448
3449
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
3450
                    elif ctx.qkv_format == "thd":
3451
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
3452
                else:
3453
3454
3455
3456
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
3457
                    elif ctx.qkv_format == "thd":
3458
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
3459
3460
3461
3462
3463
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
3464

3465
            if attn_dbias is not None:
3466
                idx = (rank + i + 1) % cp_size
3467
                if i == (cp_size - 1) or not causal:
3468
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
3469
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
3470
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
3471
3472
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
3473
3474
3475
3476
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
3477
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
3478
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
3479
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
3480

3481
3482
3483
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
3484

3485
3486
3487
3488
3489
3490
3491
            if ctx.fp8:
                if i < cp_size - 1:
                    dkv = dkv_fp8_[(rank + i + 1) % cp_size]
                else:
                    dkv = dkv_fp8[(rank + i + 1) % cp_size]
            else:
                dkv = p2p_comm_buffers[(i + 1) % 2][1]
3492
            if ctx.use_fused_attention:
3493
                if ctx.qkv_format in ["bshd", "sbhd"]:
3494
3495
3496
3497
3498
3499
3500
3501
3502
3503
3504
3505
3506
3507
                    dkv_ = _combine_tensors([dk_, dv_], -2)
                elif ctx.qkv_format == "thd":
                    dkv_ = torch.cat(
                        (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
                    )  # pylint: disable=used-before-assignment
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
                # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
                dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
                dkv_ = dkv_.movedim(-3, 0)
                if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
                    # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or
                    # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(*dkv.shape)
3508

3509
3510
3511
3512
3513
3514
3515
3516
3517
3518
3519
            if ctx.fp8:
                if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
                    if ctx.qkv_format == "bshd":
                        dkv[:, :, 0, ...].copy_(dkv_)
                        dkv[:, :, 1, ...].fill_(0)
                    elif ctx.qkv_format == "sbhd":
                        dkv[:, 0, ...].copy_(dkv_)
                        dkv[:, 1, ...].fill_(0)
                else:
                    dkv.copy_(dkv_)
            elif causal:
3520
                if i == (cp_size - 1):
3521
                    if rank == 0:
3522
3523
3524
3525
3526
3527
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
3528
                        elif ctx.qkv_format == "thd":
3529
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
3530
3531
                    else:
                        dkv.add_(dkv_)
3532
3533
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
3534
3535
3536
3537
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
3538
                        elif ctx.qkv_format == "thd":
3539
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
3540
                    else:
3541
3542
3543
3544
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
3545
                        elif ctx.qkv_format == "thd":
3546
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
3547
3548
3549
3550
3551
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
3552
3553
3554
3555
3556
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

3557
3558
3559
3560
3561
        if ctx.fp8 and ctx.use_fused_attention:
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
                # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
                dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
3562
3563
3564
            dq = ctx.dQKV_quantizer.create_tensor_from_data(dq_fp8)
            dkv = ctx.dQKV_quantizer.create_tensor_from_data(dkv_fp8)
            dq, dkv = [x.dequantize() for x in [dq, dkv]]
3565
3566
            dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]

3567
        if causal:
3568
3569
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
3570
                dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
3571
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
3572
                dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
3573
3574
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
3575
                dq = dq.view(-1, *dq.shape[-3:])
3576
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
3577
3578
                dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])

3579
3580
3581
        if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
            dq[cu_seqlens_q_padded[-1] :].fill_(0)
            dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0)
3582

3583
        if ctx.fp8 and ctx.is_input_fp8:
3584
3585
            assert torch.uint8 not in [dq.dtype, dkv.dtype]
            dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]]
3586
3587
3588
3589
3590
3591
3592
3593
3594
3595
3596
3597
3598
3599
3600
3601
3602
3603
        dk, dv = dkv[0], dkv[1]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False)
            dq, dk, dv = flash_attn_a2a_communicate(
                [dq, dk, dv],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                False,
            )
            if ctx.qkv_format == "bshd":
                dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
            elif ctx.qkv_format == "sbhd":
                dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

3604
3605
3606
        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
3607
3608
3609
3610
3611
        # converting torch.uint8 to float8tensor
        if ctx.fp8 and ctx.is_input_fp8:
            dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype)
            dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype)
            dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype)
3612
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
3613

3614
3615
3616
        return (
            None,
            dq,
3617
3618
            dk,
            dv,
3619
3620
3621
3622
3623
3624
3625
3626
3627
3628
3629
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
3630
            attn_dbias,
3631
3632
3633
3634
3635
            None,
            None,
            None,
            None,
            None,
3636
3637
            None,
            None,
3638
            None,
3639
        )
3640
3641


3642
3643
def get_kv_seq_info_after_all_gather(
    local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal
):
    """Compute KV sequence index range and update window size after all-gather.

    After the KV all-gather the full KV sequence spans
    ``max_seqlen_kv * cp_size * 2`` tokens. For the local query chunk only a
    sub-range of that KV can be reached through the attention window, so this
    helper returns the ``[start, end)`` slice of KV to keep together with the
    window sizes rewritten relative to that slice. The rewrite assumes the
    mask is bottom-right aligned: the last local query token lines up with the
    last kept KV token.

    Args:
        local_chunk_id: index of the local chunk; its (exclusive) end position
            in the full sequence is ``(local_chunk_id + 1) * max_seqlen_kv``.
        cp_size: context-parallel world size.
        max_seqlen_q: length of the local query chunk.
        max_seqlen_kv: length of one KV chunk.
        window_size: ``(left, right)`` sliding-window sizes, where ``-1`` means
            unbounded on that side. ``None`` selects ``(-1, 0)`` when ``causal``
            is true and ``(-1, -1)`` otherwise.
        causal: whether the mask is causal; only consulted when
            ``window_size`` is ``None``.

    Returns:
        ``((seq_start_idx, seq_end_idx), (window_size_left, window_size_right))``
        where the first pair is the KV slice to keep and the second pair is the
        window size to use against that slice.
    """
    local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv
    full_seq_end_idx = max_seqlen_kv * cp_size * 2

    if window_size is None:
        window_size = (-1, 0) if causal else (-1, -1)

    # Right edge: an unbounded right window needs KV up to the end of the full
    # sequence; otherwise only up to the local chunk end plus the right window,
    # clamped to the sequence length.
    if window_size[1] == -1:
        seq_end_idx = full_seq_end_idx
        window_size_right = -1
    else:
        seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1])
        # With bottom-right alignment every query is shifted by
        # (seq_end_idx - local_chunk_end_idx); the right window shrinks by that shift.
        window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx

    # Left edge: the first local query sits at local_chunk_end_idx - max_seqlen_q
    # and can look back window_size[0] tokens.
    if window_size[0] == -1:
        seq_start_idx = 0
        window_size_left = -1
    else:
        seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0])
        # Same bottom-right shift as above; the left window grows by it.
        window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx

    return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right)
3667
3668
3669
3670


class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    """
    Attention implementation with context parallelism. KV all-gather between CP ranks is exposed.
    Refer section 3.3.2 of `The Llama 3 Herd of Models <https://arxiv.org/abs/2407.21783>`_.

    Each CP rank keeps its local Q and all-gathers K/V from the whole CP group. The local Q
    sequence is split into two halves, which this code treats as chunk ids ``rank`` and
    ``2 * cp_size - rank - 1`` of the full sequence. Each half attends to the slice of the
    gathered KV computed by ``get_kv_seq_info_after_all_gather``. The first half runs on the
    current CUDA stream and the second on ``cp_stream``.

    Only ``bshd``/``sbhd`` formats, non-padding masks and ``no_bias`` are supported (asserted
    in ``forward``).
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        cp_group,
        cp_stream,
    ):
        # pylint: disable=missing-function-docstring
        # Forward pass:
        #   1. All-gather K/V over cp_group.
        #   2. Reorder the gathered chunks with get_seq_chunk_ids_for_reordering.
        #   3. For each of the two local Q halves, run attention against its KV slice.
        # The two halves run on separate streams and are copied into `out`.
        # Return layout:
        #   - FusedAttention: `out` is returned in the input layout.
        #   - FlashAttention: `out` is flattened to [b*s, np, hn] (bshd) or [s*b, np, hn] (sbhd).
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)

        qkv_dtype = q.dtype

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        # Each step attends to a KV slice whose length differs from the Q chunk. The per-step
        # window sizes are expressed with the last query aligned to the last key of the slice,
        # so FusedAttention needs the bottom-right causal variant.
        if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
            attn_mask_type = attn_mask_type + "_bottom_right"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            use_fused_attention or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
            else:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_5_7_plus and qkv_format == "thd":
                    fa_forward_kwargs["block_table"] = None
                if _flash_attn_2_6_0_plus:
                    fa_forward_kwargs["softcap"] = 0.0

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        # Convert sequence lengths to per-chunk values (the full sequence is 2*cp_size chunks).
        # NOTE(review): assumes the incoming max_seqlen_* / cu_seqlens_* describe the full
        # (global) sequence — confirm against callers.
        max_seqlen_q = max_seqlen_q // (2 * cp_size)
        max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
        if use_fused_attention or qkv_format == "thd":
            cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
        cu_seqlens_q_padded = (
            None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // (2 * cp_size)
        )

        # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
        q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
        # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn]
        k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]]

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, cp_group)
        v_ag, _ = gather_along_first_dim(v, cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        # The second step runs on cp_stream and reads k_ag/v_ag produced on the current stream.
        cp_stream.wait_stream(torch.cuda.current_stream())

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]

        # Chunk ids (within the full sequence) of the two local Q halves.
        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
        kv_seq_range_per_step = [None, None]
        window_size_per_step = [None, None]
        cu_seqlens_kv_per_step = [None, None]
        out_per_step = [None, None]
        softmax_lse_per_step = [None, None]
        rng_states = [None, None]
        out = torch.empty_like(q)

        # Software pipeline: iteration i launches step i (if any) on stream i and copies the
        # result of step i-1 into `out` on stream i-1.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    kv_seq_range_per_step[i], window_size_per_step[i] = (
                        get_kv_seq_info_after_all_gather(
                            local_seq_chunk_ids[i],
                            cp_size,
                            max_seqlen_q,
                            max_seqlen_kv,
                            window_size,
                            causal,
                        )
                    )
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i][0],
                        kv_seq_range_per_step[i][1],
                    )
                    max_seqlen_kv_ = seq_end_idx - seq_start_idx
                    if use_fused_attention or qkv_format == "thd":
                        cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens(
                            k.shape[1], max_seqlen_kv_, k.device
                        )
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    if use_fused_attention:
                        out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
                            is_training,
                            max_seqlen_q,
                            max_seqlen_kv_,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            qkv_dtype,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=softmax_scale,
                            dropout=dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=attn_mask_type,
                            attn_bias_type=attn_bias_type,
                            attn_bias=attn_bias,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            window_size=window_size_per_step[i],
                        )
                    else:
                        fa_forward_args_thd = []
                        if qkv_format == "thd":
                            fa_forward_args_thd = [
                                cu_seqlens_q,
                                cu_seqlens_kv_per_step[i],
                                max_seqlen_q,
                                max_seqlen_kv_,
                            ]
                        # FA3 and FA2 in [2.3, 2.7) take a `window_size` tuple;
                        # FA2 >= 2.7 takes separate left/right sizes.
                        if _use_flash_attn_3 or (
                            _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                        ):
                            fa_forward_kwargs["window_size"] = window_size_per_step[i]
                        elif _flash_attn_2_7_0_plus:
                            fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
                            fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1]
                        fa_outputs = flash_attn_fwd(
                            q_,
                            k_,
                            v_,
                            *fa_forward_args_thd,
                            causal=causal,
                            **fa_forward_kwargs,
                        )
                        # The position of out/softmax_lse/rng_state in the returned tuple
                        # depends on the FlashAttention version; FA3 yields no rng_state here.
                        if not _flash_attn_2_7_0_plus:
                            out_per_step[i] = fa_outputs[4]
                            softmax_lse_per_step[i] = fa_outputs[5]
                            if not _use_flash_attn_3:
                                rng_states[i] = fa_outputs[7]
                        else:
                            out_per_step[i] = fa_outputs[0]
                            softmax_lse_per_step[i] = fa_outputs[1]
                            if not _use_flash_attn_3:
                                rng_states[i] = fa_outputs[3]

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if qkv_format == "bshd":
                        out[:, i - 1].copy_(out_per_step[i - 1])
                    elif qkv_format == "sbhd":
                        out[i - 1].copy_(out_per_step[i - 1])

        torch.cuda.current_stream().wait_stream(cp_stream)

        if use_fused_attention:
            if qkv_format == "bshd":
                # [b, 2, s//2, np, hn] -> [b, s, np, hn]
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                # [2, s//2, b, np, hn] -> [s, b, np, hn]
                out = out.view(-1, *out.shape[-3:])
        else:
            # flatten all leading dims: -> [b*s, np, hn] or [s*b, np, hn]
            out = out.view(-1, *out.shape[-2:])

        # Saved layout (indices are relied on by backward):
        #   [0:5] q, k, v, cu_seqlens_q, cu_seqlens_q_padded
        #   [5:7] cu_seqlens_kv_per_step
        #   [7:9] out_per_step
        #   [9:11] softmax_lse_per_step
        #   [11:13] rng_states
        ctx.save_for_backward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_q_padded,
            *cu_seqlens_kv_per_step,
            *out_per_step,
            *softmax_lse_per_step,
            *rng_states,
        )
        ctx.qkv_dtype = qkv_dtype
        ctx.kv_seq_range_per_step = kv_seq_range_per_step
        ctx.window_size_per_step = window_size_per_step
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")

        return out

    @staticmethod
    def backward(ctx, dout):
        # pylint: disable=missing-function-docstring
        # Backward pass:
        #   1. Re-gather K/V over cp_group.
        #   2. Run the two per-step attention backwards on the same streams as forward.
        #   3. Copy dq halves into `dq`.
        #   4. Accumulate per-step dK/dV into full-length buffers.
        #   5. Restore the original chunk order and reduce-scatter dK/dV back to this
        #      rank's shard.
        # Only q, k and v receive gradients.
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)

        (*saved_tensors,) = ctx.saved_tensors
        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
        cu_seqlens_kv_per_step = saved_tensors[5:7]
        out_per_step = saved_tensors[7:9]
        softmax_lse_per_step = saved_tensors[9:11]
        rng_states = saved_tensors[11:13]
        kv_seq_range_per_step = ctx.kv_seq_range_per_step
        window_size_per_step = ctx.window_size_per_step

        seq_dim = ctx.qkv_format.index("s")
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

        # q was saved with the [.., 2, s//2, ..] split; give dout the same shape.
        dout = dout.view(q.shape)
        dq = torch.empty_like(q)
        # dK/dV cover the whole gathered (seq-first) sequence; per-step grads are added in.
        dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device)
        dv = torch.zeros_like(dk)
        dq_per_step = [None, None]
        dk_per_step = [None, None]
        dv_per_step = [None, None]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream]
        # synchronize dkv update across steps
        dkv_update_done = torch.cuda.Event()

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
        v_ag, _ = gather_along_first_dim(v, ctx.cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        ctx.cp_stream.wait_stream(torch.cuda.current_stream())

        # Chunk ids (within the full sequence) of the two local Q halves.
        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]

        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
                if _flash_attn_2_6_0_plus:
                    fa_backward_kwargs["softcap"] = 0.0

        # Same pipeline shape as forward: iteration i computes step i on stream i and folds
        # the gradients of step i-1 into dq/dk/dv on stream i-1.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i][0],
                        kv_seq_range_per_step[i][1],
                    )
                    max_seqlen_kv = seq_end_idx - seq_start_idx
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    out_ = out_per_step[i]
                    dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
                    if ctx.use_fused_attention:
                        aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
                            ctx.max_seqlen_q,
                            max_seqlen_kv,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            out_,
                            dout_,
                            ctx.qkv_dtype,
                            TE_DType[dout.dtype],
                            aux_ctx_tensors,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=ctx.attn_mask_type,
                            attn_bias_type=ctx.attn_bias_type,
                            window_size=window_size_per_step[i],
                            deterministic=ctx.deterministic,
                        )
                    else:
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q,
                                cu_seqlens_kv_per_step[i],
                                ctx.max_seqlen_q,
                                max_seqlen_kv,
                            ]
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[i]
                        # Must match the window selection in forward exactly. FA3 (and FA2 in
                        # [2.3, 2.7)) take a `window_size` tuple; FA2 >= 2.7 takes separate
                        # left/right sizes. FA3 must not fall through to the FA2 >= 2.7 kwargs.
                        if _use_flash_attn_3 or (
                            _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                        ):
                            fa_backward_kwargs["window_size"] = window_size_per_step[i]
                        elif _flash_attn_2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
                            fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
                        flash_attn_bwd(
                            dout_,
                            q_,
                            k_,
                            v_,
                            out_,
                            softmax_lse_per_step[i],
                            dq_per_step[i],
                            dk_per_step[i],
                            dv_per_step[i],
                            *fa_backward_args_thd,
                            causal="causal" in ctx.attn_mask_type,
                            **fa_backward_kwargs,
                        )

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if ctx.qkv_format == "bshd":
                        dq[:, i - 1].copy_(dq_per_step[i - 1])
                    elif ctx.qkv_format == "sbhd":
                        dq[i - 1].copy_(dq_per_step[i - 1])
                    # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn]
                    dk_per_step[i - 1], dv_per_step[i - 1] = [
                        x.movedim(seq_dim, 0).contiguous()
                        for x in [dk_per_step[i - 1], dv_per_step[i - 1]]
                    ]
                    # The two steps' KV slices may overlap, and their accumulations run on
                    # different streams, so serialize them via an event.
                    # wait until dkv update of last step is done
                    if i > 1:
                        flash_attn_streams[i - 1].wait_event(dkv_update_done)
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i - 1][0],
                        kv_seq_range_per_step[i - 1][1],
                    )
                    dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1])
                    dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1])
                    if i < len(local_seq_chunk_ids):
                        flash_attn_streams[i - 1].record_event(dkv_update_done)

        torch.cuda.current_stream().wait_stream(ctx.cp_stream)

        # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
        dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
        # Undo the forward chunk reordering before scattering back to the owning ranks.
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False)
        dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
        dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        dk = dk.view(-1, *dk.shape[-3:])
        dv = dv.view(-1, *dv.shape[-3:])
        dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
        dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)

        # Merge the [2, s//2] split back into s and restore the caller's layout for dk/dv.
        dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
        dk = dk.movedim(0, seq_dim).contiguous()
        dv = dv.movedim(0, seq_dim).contiguous()
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")

        # One entry per forward input (19 total); only q, k, v get gradients.
        return (
            None,
            dq,
            dk,
            dv,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
    """
    Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO.
    Refer the paper `DeepSpeed Ulysses <https://arxiv.org/abs/2309.14509>`_.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        fp8,
        fp8_meta,
        cp_group,
        cp_stream,
4163
        quantizers,
4164
    ):
4165
        # pylint: disable=missing-function-docstring
4166
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
4167
4168
4169
4170
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
4171
        qkv_dtype = q.dtype
4172
4173
4174
4175
4176
4177
4178
4179
4180
4181
4182
4183

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            window_size == (-1, 0)
            or window_size == (-1, -1)
            or use_fused_attention
            or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
4184

4185
        flash_attn_fwd = None
4186
4187
4188
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
4189
4190
4191
4192
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
4193
4194
                fa_forward_kwargs["window_size"] = window_size
            else:
4195
4196
4197
4198
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
4199
4200
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
4201
                if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus):
4202
                    fa_forward_kwargs["window_size"] = window_size
4203
4204
4205
                elif _flash_attn_2_7_0_plus:
                    fa_forward_kwargs["window_size_left"] = window_size[0]
                    fa_forward_kwargs["window_size_right"] = window_size[1]
4206
4207
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
4208
                if _flash_attn_2_5_7_plus and qkv_format == "thd":
4209
                    fa_forward_kwargs["block_table"] = None
4210
4211
                if _flash_attn_2_6_0_plus:
                    fa_forward_kwargs["softcap"] = 0.0
4212
4213
4214
4215
4216
4217
4218
4219
4220
4221
4222
4223
4224
4225

        assert (
            q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
        ), "The number of attention heads needs to be divisible by CP size!"

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        batch_dim = qkv_format.index("b")
        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

4226
        fused_attn_backend = None
4227
4228
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
        is_input_fp8 = False
4229
        is_output_fp8 = False
4230
        if fp8:
4231
4232
4233
4234
4235
4236
4237
4238
            is_output_fp8 = fp8_meta["recipe"].fp8_mha

        QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
            get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
        )
        if fp8:
            if use_fused_attention:

4239
                fused_attn_backend = FusedAttnBackend["FP8"]
4240
4241
4242
4243
4244
                assert isinstance(k, q.__class__) and isinstance(
                    v, q.__class__
                ), "q, k, and v must have the same type."
                is_input_fp8 = isinstance(q, Float8Tensor)
                if is_input_fp8:
4245
4246
4247
4248
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                    q_f16, k_f16, v_f16 = q, k, v
4249
                    q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]]
4250
                fp8_meta_kwargs = {}
4251
4252
                fp8_meta_kwargs["s_quantizer"] = S_quantizer
                fp8_meta_kwargs["o_quantizer"] = O_quantizer  # partial result quantizer
4253
4254
4255
4256
4257
4258
4259
4260
4261
4262
4263
4264
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True)
        q, k, v = flash_attn_a2a_communicate(
            [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
        )

4265
        if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
4266
            q_f16, k_f16, v_f16 = q, k, v
4267
            q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]]
4268
4269
4270

        batch_size = q.shape[batch_dim]
        if use_fused_attention:
4271
4272
4273
4274
4275
4276
4277
4278
4279
4280
4281
            q_part, k_part, v_part = q, k, v
            if fp8:
                q_part = QKV_quantizer.create_tensor_from_data(
                    q, fake_dtype=qkv_dtype, internal=True
                )
                k_part = QKV_quantizer.create_tensor_from_data(
                    k, fake_dtype=qkv_dtype, internal=True
                )
                v_part = QKV_quantizer.create_tensor_from_data(
                    v, fake_dtype=qkv_dtype, internal=True
                )
4282
4283
4284
4285
4286
4287
            out, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
4288
4289
4290
4291
                q_part,
                k_part,
                v_part,
                qkv_dtype,
4292
4293
4294
4295
4296
4297
4298
4299
4300
4301
4302
4303
                fused_attn_backend,
                attn_scale=softmax_scale,
                dropout=dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                attn_bias=attn_bias,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                window_size=window_size,
                **fp8_meta_kwargs,
            )
4304
4305
            if fp8:
                out = out._data
4306
        else:
4307
4308
4309
4310
4311
4312
4313
4314
            fa_forward_args_thd = []
            if qkv_format == "thd":
                fa_forward_args_thd = [
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                ]
4315
            fa_outputs = flash_attn_fwd(
4316
4317
4318
                q,
                k,
                v,
4319
                *fa_forward_args_thd,
4320
                causal=causal,
4321
                **fa_forward_kwargs,
4322
            )
4323
4324
4325
4326
4327
4328
            if not _flash_attn_2_7_0_plus:
                out, softmax_lse = fa_outputs[4], fa_outputs[5]
                rng_state = fa_outputs[7] if not _use_flash_attn_3 else None
            else:
                out, softmax_lse = fa_outputs[0], fa_outputs[1]
                rng_state = fa_outputs[3] if not _use_flash_attn_3 else None
4329
4330
4331
4332
4333
4334
4335
4336
4337
4338
4339
4340
4341
4342
4343
4344
            aux_ctx_tensors = [softmax_lse, rng_state]

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False)
        out = flash_attn_a2a_communicate(
            out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
        )

        if use_fused_attention:
            if qkv_format == "bshd":
                # [b*s, np, hn] -> [b, s, np, hn]
                out = out.view(batch_size, -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                # [s*b, np, hn] -> [s, b, np, hn]
                out = out.view(-1, batch_size, *out.shape[-2:])

        if fp8:
4345
            if is_output_fp8:
4346
4347
                out_fp8 = O_quantizer.create_tensor_from_data(
                    out, fake_dtype=qkv_dtype, internal=False
4348
4349
                )
                out_ret = out_fp8
4350
                out = out_fp8._data
4351
            else:
4352
4353
                out_fp8 = O_quantizer.create_tensor_from_data(
                    out, fake_dtype=qkv_dtype, internal=False
4354
                )
4355
                out_f16 = out_fp8.dequantize()
4356
4357
4358
4359
4360
4361
4362
                out_ret = out_f16
        else:
            out_ret = out

        if fp8:
            if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                q_save, k_save, v_save, out_save = q, k, v, out
4363
            elif is_input_fp8:
4364
4365
4366
4367
4368
4369
4370
4371
4372
4373
                q_fp8 = QKV_quantizer.create_tensor_from_data(
                    q, fake_dtype=qkv_dtype, internal=False
                )
                k_fp8 = QKV_quantizer.create_tensor_from_data(
                    k, fake_dtype=qkv_dtype, internal=False
                )
                v_fp8 = QKV_quantizer.create_tensor_from_data(
                    v, fake_dtype=qkv_dtype, internal=False
                )
                q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out
4374
4375
4376
4377
4378
            else:
                q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16
        else:
            q_save, k_save, v_save, out_save = q, k, v, out

4379
        tensors_to_save, tensor_objects = prepare_for_saving(
4380
4381
4382
4383
4384
4385
4386
4387
4388
4389
            q_save,
            k_save,
            v_save,
            out_save,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *aux_ctx_tensors,
        )
4390
4391
4392
4393
4394
4395
4396
4397
4398
4399
4400
4401
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects

        ctx.qkv_dtype = qkv_dtype
        ctx.QKV_quantizer = QKV_quantizer
        ctx.O_quantizer = O_quantizer
        ctx.S_quantizer = S_quantizer
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.qkv_dtype = qkv_dtype

4402
4403
4404
4405
4406
4407
4408
4409
4410
4411
4412
4413
4414
4415
4416
        ctx.batch_size = batch_size
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_mask_type = attn_mask_type
        ctx.attn_bias_type = attn_bias_type
        ctx.deterministic = deterministic
        ctx.window_size = window_size
        ctx.use_fused_attention = use_fused_attention
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
4417
4418
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
4419
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
4420
4421
4422
4423
        return out_ret

    @staticmethod
    def backward(ctx, dout):
        # pylint: disable=missing-function-docstring
        # Backward of attention with all-to-all (A2A) context parallelism:
        #   1. all-to-all out/dout (flag True) before the local attention backward,
        #   2. run the local backward with FusedAttention or FlashAttention,
        #   3. all-to-all dq/dk/dv back (flag False).
        # NOTE(review): the trailing bool of flash_attn_a2a_communicate presumably
        # selects the communication direction -- confirm against its definition.
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
        cp_size = get_distributed_world_size(ctx.cp_group)

        # aux_ctx_tensors holds what the local forward kept for backward; for the
        # FlashAttention path it is [softmax_lse, rng_state] (unpacked below).
        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *aux_ctx_tensors,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        # Nominal dtype of the incoming gradient; used to build FP8 dq/dk/dv at the end.
        dout_dtype = dout.dtype

        # Same format for q, k and v, e.g. "bshd_bshd_bshd".
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
        causal = "causal" in ctx.attn_mask_type
        seq_dim = ctx.qkv_format.index("s")

        fused_attn_backend = None
        fused_attn_dqkv_dtype = None
        if ctx.fp8:
            fused_attn_dqkv_dtype = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            if ctx.use_fused_attention:
                fused_attn_backend = FusedAttnBackend["FP8"]
                # The fused FP8 kernel consumes raw FP8 data for dout: either unwrap
                # an incoming Float8Tensor or quantize a high-precision dout.
                if ctx.is_output_fp8:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    dout = dout._data
                else:
                    dout = ctx.dO_quantizer(dout)._data
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
                fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer
                fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer
            else:
                # Explicit raise (not `assert False`) so the guard survives `python -O`.
                raise AssertionError("FP8 is only supported with Fused Attention!")
        else:
            if ctx.fp8_meta is not None and ctx.is_output_fp8:
                # FP8 tensors saved/received without an FP8 backward: dequantize and
                # run the high-precision backward.
                assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                q, k, v, out, dout = [x.dequantize() for x in [q, k, v, out, dout]]
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_dqkv_dtype = TE_DType[dout.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if not ctx.use_fused_attention:
            # FlashAttention path: view out as (batch_size, -1, <last two dims>).
            out = out.view(ctx.batch_size, -1, *out.shape[-2:])
        dout = dout.view(*out.shape)

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True)
        out, dout = flash_attn_a2a_communicate(
            [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
        )

        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            # Pick the FlashAttention backward entry point and the kwargs supported
            # by the installed flash-attn version.
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
                fa_backward_kwargs["window_size"] = ctx.window_size
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                # (_use_flash_attn_3 is always False in this branch.)
                if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus:
                    fa_backward_kwargs["window_size"] = ctx.window_size
                elif _flash_attn_2_7_0_plus:
                    fa_backward_kwargs["window_size_left"] = ctx.window_size[0]
                    fa_backward_kwargs["window_size_right"] = ctx.window_size[1]
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
                if _flash_attn_2_6_0_plus:
                    fa_backward_kwargs["softcap"] = 0.0

        if ctx.use_fused_attention:
            q_part = q
            k_part = k
            v_part = v
            out_part = out
            dout_part = dout

            if ctx.fp8:
                # Re-wrap the raw FP8 buffers so fused_attn_bwd sees their scales.
                q_part = ctx.QKV_quantizer.create_tensor_from_data(
                    q_part, fake_dtype=ctx.qkv_dtype, internal=True
                )
                k_part = ctx.QKV_quantizer.create_tensor_from_data(
                    k_part, fake_dtype=ctx.qkv_dtype, internal=True
                )
                v_part = ctx.QKV_quantizer.create_tensor_from_data(
                    v_part, fake_dtype=ctx.qkv_dtype, internal=True
                )
                out_part = ctx.O_quantizer.create_tensor_from_data(
                    out_part, fake_dtype=ctx.qkv_dtype, internal=True
                )
                dout_part = ctx.dO_quantizer.create_tensor_from_data(
                    dout_part, fake_dtype=ctx.qkv_dtype, internal=True
                )

            dq, dk, dv, _ = fused_attn_bwd(
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_part,
                k_part,
                v_part,
                out_part,
                dout_part,
                ctx.qkv_dtype,
                fused_attn_dqkv_dtype,
                aux_ctx_tensors,
                fused_attn_backend,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                attn_scale=ctx.softmax_scale,
                dropout=ctx.dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=ctx.attn_mask_type,
                attn_bias_type=ctx.attn_bias_type,
                window_size=ctx.window_size,
                deterministic=ctx.deterministic,
                **fp8_meta_kwargs,
            )

            if ctx.fp8:
                # Communicate raw FP8 data; re-wrapped after the all-to-all below.
                dq = dq._data
                dk = dk._data
                dv = dv._data
        else:
            softmax_lse, rng_state = aux_ctx_tensors
            # flash_attn_bwd writes its results into these preallocated buffers.
            dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
            fa_backward_args_thd = []
            if ctx.qkv_format == "thd":
                fa_backward_args_thd = [
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    ctx.max_seqlen_q,
                    ctx.max_seqlen_kv,
                ]
            if not _use_flash_attn_3:
                fa_backward_kwargs["rng_state"] = rng_state
            flash_attn_bwd(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                *fa_backward_args_thd,
                causal=causal,
                **fa_backward_kwargs,
            )

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False)
        dq, dk, dv = flash_attn_a2a_communicate(
            [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
        )

        if ctx.qkv_format == "bshd":
            dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
        elif ctx.qkv_format == "sbhd":
            dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

        if ctx.fp8:
            dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype)
            dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype)
            dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
            if not ctx.is_input_fp8:
                dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]]
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")

        # Only q, k and v (forward inputs 1-3) receive gradients; every other input
        # gets None. The trailing Nones match the original tuple length exactly.
        return (None, dq, dk, dv) + (None,) * 21


def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_kv,
    max_seqlen_q,
    max_seqlen_kv,
    cu_seqlens_q_padded,
    cu_seqlens_kv_padded,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    cp_comm_type,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
    window_size=None,
    fp8=False,
    fp8_meta=None,
    quantizers=None,
) -> torch.Tensor:
    """
    Attention implementation with context parallelism.

    Validates the configuration, then dispatches on ``cp_comm_type``:

    * ``"p2p"`` / ``"a2a+p2p"`` -> ``AttnFuncWithCPAndKVP2P``
    * ``"all_gather"``          -> ``AttnFuncWithCPAndKVAllGather``
    * ``"a2a"``                 -> ``AttnFuncWithCPAndQKVOA2A``

    For ``"a2a+p2p"``, ``cp_group`` must be a list of exactly two process groups
    ``[a2a_group, p2p_group]``. If the first level has world size 1, plain ``"p2p"``
    over the second group is used. If the second level has world size 1, plain
    ``"a2a"`` over the first group is used. For every other ``cp_comm_type``,
    ``cp_group`` must be a single process group.

    Constraints checked here:

    * ``qkv_format`` is ``"bshd"``, ``"sbhd"`` or ``"thd"``.
    * ``"sbhd"`` requires ``use_fused_attention``.
    * ``attn_bias`` requires ``use_fused_attention`` and a mask type without ``"padding"``.
    * ``"thd"`` requires both ``cu_seqlens_q_padded`` and ``cu_seqlens_kv_padded``.
    * Sliding-window attention (``window_size`` not ``None``, ``(-1, 0)`` or ``(-1, -1)``)
      is only accepted for ``"a2a"`` and ``"all_gather"``.

    Returns
    -------
    torch.Tensor
        The output of the selected context-parallel attention function.

    Raises
    ------
    AssertionError
        If the configuration is not supported.
    ValueError
        If ``cp_comm_type`` is not one of the values above.
    """

    if cp_comm_type == "a2a+p2p":
        assert isinstance(
            cp_group, list
        ), "Hierarchical CP implementation needs multi-level CP groups!"
        assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
        # Collapse to a single-level scheme when one of the levels is trivial.
        if get_distributed_world_size(cp_group[0]) == 1:
            cp_group = cp_group[1]
            cp_comm_type = "p2p"
        elif get_distributed_world_size(cp_group[1]) == 1:
            cp_group = cp_group[0]
            cp_comm_type = "a2a"
    else:
        assert isinstance(
            cp_group, dist_group_type
        ), f"Unsupported process group for CP communication type {cp_comm_type}!"

    assert qkv_format in [
        "bshd",
        "sbhd",
        "thd",
    ], f"QKV format of {qkv_format} is not supported with context parallelism!"
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "FlashAttention does not support sbhd format!"
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        """Attention bias is only supported with FusedAttention and "causal" """
        """or "no_mask" mask types!"""
    )
    assert qkv_format != "thd" or (
        cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
    ), "cu_seqlens_padded cannot be None with context parallelism + THD format!"

    sliding_window_attn = (
        window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
    )
    assert not sliding_window_attn or cp_comm_type in [
        "a2a",
        "all_gather",
    ], "The context parallel running configs cannot support sliding window attention!"

    # Arguments shared by every implementation; each branch appends its own extras.
    args = [
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    ]

    if cp_comm_type in ["p2p", "a2a+p2p"]:
        args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers]
        out = AttnFuncWithCPAndKVP2P.apply(*args)
    elif cp_comm_type == "all_gather":
        # The all-gather implementation takes neither cu_seqlens_kv nor
        # cu_seqlens_kv_padded.
        args.pop(5)  # cu_seqlens_kv
        args.pop(8)  # cu_seqlens_kv_padded (index already shifted by the pop above)
        args += [window_size, cp_group, cp_stream]
        out = AttnFuncWithCPAndKVAllGather.apply(*args)
    elif cp_comm_type == "a2a":
        args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers]
        out = AttnFuncWithCPAndQKVOA2A.apply(*args)
    else:
        raise ValueError(f"Unsupported communication type: {cp_comm_type}!")

    return out


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
            Values below 1.0 shrink `dim` to `int(dim * rotary_percent)`.
        seq_len_interpolation_factor: int, default = None
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int, default = None
            pre-trained max_position_embeddings before position interpolation
        rotary_base: float, default = 10000.0
            Base of the inverse-frequency progression:
            `inv_freq[i] = 1 / rotary_base ** (2 * i / dim)`.
        """
        super().__init__()

        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
        inv_freq = 1.0 / (
            self.rotary_base
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            Tensor of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]` holding the
            per-position angles, duplicated along the last dimension.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        if (
            self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None
        ):
            if (
                max_seq_len
                > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
            ):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        # Outer product: freqs[i, j] = seq[i] * inv_freq[j].
        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        # The same angles are repeated for both halves of the last dimension,
        # matching the half-split rotation used when the embedding is applied.
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))


class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.

    `freqs` is cast to float32 before use. `bshd` inputs are handled by transposing to
    `sbhd` order around the kernel call. `thd` uses the dedicated THD kernel with
    `cu_seqlens` and the context-parallel size/rank. Any other format raises `ValueError`.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        # cu_seqlens may be None for the non-THD formats.
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(
                grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
            )
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # One slot per forward input; only `t` receives a gradient.
        return grad_input, None, None, None, None, None


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
        Should be `cu_seqlens_padded` when cp_size > 1.
    cp_size: int, default = 1.
        Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
    cp_rank: int, default = 0.
        Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.

    Returns
    -------
    torch.Tensor
        The rotated tensor. In the unfused path it has the same shape as `t`: the
        first `d2` channels are rotated and the remaining channels pass through unchanged.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)


class _SplitAlongDim(torch.autograd.Function):
    """Autograd-aware split of a tensor along one dimension.

    ``forward`` behaves like ``torch.split`` and can optionally squeeze the split
    dimension out of every piece. It also accepts FP8 inputs (``Float8Tensor`` and
    ``Float8TensorBase``): it splits their raw ``_data`` and re-wraps each piece with
    the input's FP8 metadata.

    ``backward`` re-assembles the incoming gradients along the split dimension. When
    the gradients already sit back-to-back in one storage (same strides, expected
    shapes and storage offsets), it returns a single view over that storage instead of
    copying with ``torch.cat``.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
        squeeze=False,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        ctx.squeeze = squeeze

        if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
            mixed_x_layer, Float8Tensor
        ):
            ctx.input_ndim = mixed_x_layer._data.dim()
            # Split the raw FP8 data; each piece keeps the input's scale-inverse,
            # FP8 dtype and quantizer.
            return tuple(
                Float8TensorBase(
                    fp8_scale_inv=mixed_x_layer._scale_inv,
                    fp8_dtype=mixed_x_layer._fp8_dtype,
                    data=x.squeeze(split_dim) if squeeze else x,
                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
                    quantizer=mixed_x_layer._quantizer,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )

        # Rank of the un-split input. With ``squeeze=True`` the gradients reaching
        # backward may be missing the split dimension, so their own rank is unreliable.
        ctx.input_ndim = len(mixed_x_layer.shape)

        if isinstance(mixed_x_layer, Float8Tensor):
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x.squeeze(split_dim) if squeeze else x,
                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )

        out_list = torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)
        if squeeze:
            out_list = [x.squeeze(split_dim) for x in out_list]
        return out_list

    @staticmethod
    def backward(ctx, *grad_outputs):
        # pylint: disable=missing-function-docstring
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)

        # Normalize a possibly negative split_dim against the un-split input's rank.
        dims = ctx.input_ndim
        split_dim = (ctx.split_dim + dims) % dims

        if ctx.squeeze:
            # Undo forward's squeeze so every gradient has the input's rank again.
            # torch only squeezes size-1 dims, so pieces with a larger split size
            # already have full rank and are kept as they are.
            restored = []
            for grad in grad_outputs:
                if len(grad.shape) == dims:
                    restored.append(grad)
                elif isinstance(grad, Float8Tensor):
                    data = grad._data.unsqueeze(split_dim)
                    restored.append(Float8Tensor.make_like(grad, data=data, shape=data.shape))
                else:
                    restored.append(grad.unsqueeze(split_dim))
            grad_outputs = tuple(restored)

        if isinstance(grad_outputs[0], Float8Tensor):
            # Zero-copy path: check whether the gradients' raw data are adjacent
            # slices of one storage.
            noop_ok = True
            strides = grad_outputs[0].stride()
            data_ptr = grad_outputs[0]._data.untyped_storage().data_ptr()
            shape = list(grad_outputs[0].shape)
            for i, tensor in enumerate(grad_outputs):
                shape_i = list(shape)
                shape_i[split_dim] = split_sizes[i]
                offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
                if (
                    tensor.stride() != strides
                    or list(tensor.shape) != shape_i
                    or tensor._data.untyped_storage().data_ptr() != data_ptr
                    or tensor.storage_offset() != offset_size
                ):
                    noop_ok = False
                    break
            if noop_ok:
                ret = torch.Tensor().to(
                    device=grad_outputs[0].device, dtype=grad_outputs[0]._data.dtype
                )
                new_shape = list(shape)
                new_shape[split_dim] = sum(split_sizes)
                ret.set_(
                    grad_outputs[0]._data.untyped_storage(),
                    grad_outputs[0]._data.storage_offset(),
                    new_shape,
                    strides,
                )
                return (
                    Float8Tensor.make_like(grad_outputs[0], data=ret, shape=ret.shape),
                    None,
                    None,
                    None,
                )

            grad_outputs_data = [x._data for x in grad_outputs]
            data = torch.cat(grad_outputs_data, dim=split_dim)
            return (
                Float8Tensor.make_like(grad_outputs[0], data=data, shape=data.shape),
                None,
                None,
                None,
            )

        # Zero-copy path for regular tensors: the gradients must be adjacent views
        # into the same storage with matching strides.
        noop_ok = True
        strides = grad_outputs[0].stride()
        data_ptr = grad_outputs[0].untyped_storage().data_ptr()
        shape = list(grad_outputs[0].shape)
        for i, tensor in enumerate(grad_outputs):
            shape_i = list(shape)
            shape_i[split_dim] = split_sizes[i]
            offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
            if (
                tensor.stride() != strides
                or list(tensor.shape) != shape_i
                or tensor.untyped_storage().data_ptr() != data_ptr
                or tensor.storage_offset() != offset_size
            ):
                noop_ok = False
                break
        if noop_ok:
            ret = torch.Tensor().to(device=grad_outputs[0].device, dtype=grad_outputs[0].dtype)
            new_shape = list(shape)
            new_shape[split_dim] = sum(split_sizes)
            ret.set_(
                grad_outputs[0].untyped_storage(),
                grad_outputs[0].storage_offset(),
                new_shape,
                strides,
            )
            # One slot per forward input (mixed_x_layer, split_dim,
            # split_size_or_sections, squeeze). Autograd drops the trailing None
            # when squeeze was not passed.
            return ret, None, None, None

        return torch.cat(grad_outputs, dim=split_dim), None, None, None


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2

    Reference (non-fused) attention path built from plain PyTorch matmuls. Inputs in
    ``bshd`` format are transposed to ``sbhd`` and processed by the ``sbhd`` code path.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_type: str = "self",
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        """
        Parameters
        ----------
        softmax_scale: float
            Scale applied to Q @ K^T before the softmax.
        attention_type: str, default = "self"
            Forwarded to ``get_full_mask`` when building the attention mask.
        attention_dropout: float, default = 0.0
            Dropout probability applied to the attention probabilities.
        attention_dropout_ctx: Callable, default = nullcontext
            Context-manager factory entered around the dropout call.
        layer_number: Optional[int], default = None
            Used for QK layer scaling (see ``NVTE_APPLY_QK_LAYER_SCALING``).
        """
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_type = attention_type
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            4D tensors laid out as ``[s, b, h, d]`` (sbhd) or ``[b, s, h, d]`` (bshd),
            as described by ``qkv_layout``. Key/value may have fewer heads than query
            (GQA); they are then repeated to match.
        qkv_layout: str, default = "sbh3d"
            Must be one of ``QKVLayouts``; only its format (sbhd/bshd) is used here.
        cu_seqlens_q, cu_seqlens_kv: Optional[torch.Tensor]
            Unused by this backend.
        attn_mask_type, attention_mask, window_size:
            Passed to ``get_full_mask`` to build the full attention mask.
        core_attention_bias_type: str, default = "no_bias"
            One of "no_bias", "pre_scale_bias", "post_scale_bias", "alibi".
        core_attention_bias: Optional[torch.Tensor]
            Required for "pre_scale_bias" and "post_scale_bias".
        alibi_slopes: Optional[torch.Tensor]
            Forwarded to ``get_alibi`` when ``core_attention_bias_type == "alibi"``.

        Returns
        -------
        torch.Tensor
            ``[sq, b, h * d]`` for sbhd input, ``[b, sq, h * d]`` for bshd input.

        Raises
        ------
        ValueError
            If ``core_attention_bias_type`` is not one of the supported values.
        """
        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]
        # From here on, q/k/v are indexed as [s, b, h, d].
        batch_size, max_seqlen_q, max_seqlen_kv = (
            query_layer.shape[1],
            query_layer.shape[0],
            key_layer.shape[0],
        )

        # NOTE(review): get_full_mask is defined elsewhere; it returns the (possibly
        # updated) mask type, the mask tensor and the actual sequence lengths.
        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
            max_seqlen_q,
            max_seqlen_kv,
            attn_mask_type=attn_mask_type,
            attention_mask=attention_mask,
            window_size=window_size,
            attention_type=self.attention_type,
        )

        # QK layer scaling only applies to FP16 inputs.
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]  (np = number of query heads)
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            # GQA/MQA: repeat each KV head so K/V match the number of query heads.
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                query_layer.shape[2] // key_layer.shape[2], dim=2
            )
            value_layer = value_layer.repeat_interleave(
                query_layer.shape[2] // value_layer.shape[2], dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        # Allocate on the inputs' device so baddbmm never mixes devices.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            # beta=0.0: the preallocated buffer's contents are ignored.
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            ).view(*output_size)

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            # Bias is added before scaling.
            matmul_result = matmul_result.view(*output_size) + core_attention_bias
            matmul_result *= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = get_alibi(
                    output_size[1],
                    output_size[2],
                    output_size[3],
                    actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
                    actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
                    alibi_slopes=alibi_slopes,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            # Bias is added after scaling; cast back in case the bias promoted the dtype.
            matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
                dtype=query_layer.dtype
            )

        else:
            # Without this, uninitialized torch.empty memory would reach the softmax.
            raise ValueError(
                f"UnfusedDotProductAttention does not support "
                f"core_attention_bias_type = {core_attention_bias_type}!"
            )

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            matmul_result, attention_mask, attn_mask_type, softmax_scale
        )

        # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
        # the columns (pad tokens from k) are already zeroed out during softmax
        # (masked_fill zeroes positions where the mask is True)
        if "padding" in attn_mask_type:
            attention_probs = attention_probs.masked_fill(attention_mask, 0)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(max_seqlen_q, batch_size, -1)

        if qkv_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, max_seqlen_q, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """This class converts QKV from interleaved (s, b, ...) layout
    to separate contiguous q, k, v tensors in (b, s, ...) layout.

    The heavy lifting is done by the ``tex.fa_prepare_fwd`` / ``tex.fa_prepare_bwd``
    extension kernels.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # All inputs received are non-contiguous tensors.
        # The `query_layer` tensor is used to access the
        # full memory region of the QKV tensor.
        # `key_layer` and `value_layer` are only part of the signature so autograd
        # routes their gradients through backward().
        qkv = tex.fa_prepare_fwd(query_layer)
        # Split into 3 chunks along dim 0 and drop that (now size-1) dim.
        q, k, v = split_tensor_along_dim(qkv, 0, 3)
        query_layer = torch.squeeze(q, 0)
        key_layer = torch.squeeze(k, 0)
        value_layer = torch.squeeze(v, 0)
        return query_layer, key_layer, value_layer

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        # Pack the three gradients back into a single buffer, then hand autograd
        # three views of it split along the last dimension.
        dqkv = tex.fa_prepare_bwd(dq, dk, dv)
        dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
        return dq, dk, dv


def get_qkv_layout(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qkv_format: str = "sbhd",
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of tokens in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    -------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
        Query tensor. It may be different from input `q` as we try to fit tensors to
        a supported layout.
    k: torch.Tensor
        Key tensor. It may be different from input `k` as we try to fit tensors to
        a supported layout.
    v: torch.Tensor
        Value tensor. It may be different from input `v` as we try to fit tensors to
        a supported layout.

    Raises
    ------
    AssertionError
        If any of `q`, `k`, `v` does not have stride 1 in its last dimension.
    RuntimeError
        If no supported layout can be found, even after copying the tensors.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def run_iteratively(q, k, v):
        """Classify the memory layout of (q, k, v), or return "not_supported"."""
        # check data pointers
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        # check tensor shapes
        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        # k and v may differ only in head size
        check_shapes_kv = shape[:-1] == v.shape[:-1]

        # check tensor strides
        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        # compare strides normalized by head size so k and v with different head sizes match
        check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
            sv / v.shape[-1] for sv in v.stride()[:-1]
        )

        # check tensor offsets for h3d and 3hd layouts
        prod_h_d = q.shape[-1] * q.shape[-2]
        check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v]))
        check_h3d_offsets = all(
            x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v])
        )

        # check tensor offsets for hd_h2d and hd_2hd layouts
        prod_all_dims = [np.prod(x.shape) for x in [q, k]]
        # when q and kv share one buffer, kv is expected to start right after q
        offset = prod_all_dims[0] if check_ptrs_qkv else 0
        prod_h_d = k.shape[-1] * k.shape[-2]
        check_2hd_offsets = all(
            x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v])
        )
        check_h2d_offsets = all(
            x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v])
        )

        # check tensor offsets for hd_hd_hd layouts
        check_hd_offsets_qkv = (
            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v]))
            if check_ptrs_qkv
            else all(x.storage_offset() == 0 for x in [q, k, v])
        )
        check_hd_offsets_qk = (
            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k]))
            if not check_ptrs_qkv and check_ptrs_qk
            else all(x.storage_offset() == 0 for x in [q, k])
        )
        check_hd_offsets_kv = (
            all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v]))
            if not check_ptrs_qkv and check_ptrs_kv
            else all(x.storage_offset() == 0 for x in [k, v])
        )

        if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets:
            # sb3hd, bs3hd, t3hd
            # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv
            return qkv_format[:-2] + "3" + qkv_format[-2:]
        if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets:
            # sbh3d, bsh3d, th3d
            # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv
            return qkv_format[:-1] + "3" + qkv_format[-1:]
        if check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets:
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv
            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
            # have the same data pointer, i.e. check_ptrs_qkv=True
            return qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        if check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets:
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv
            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
            # have the same data pointer, i.e. check_ptrs_qkv=True
            return qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
        if (
            check_strides_kv
            and check_shapes_kv
            and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk)
        ):
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            # three chunks of memory, q, k and v, which may be disjoint or consecutive, and
            # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
            # check_ptrs_qk=True or check_ptrs_kv=True
            return "_".join([qkv_format] * 3)
        return "not_supported"

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # .contiguous() returns the tensor itself when it is already contiguous, so
        # contiguous views carved out of one shared buffer at unexpected offsets are
        # still unsupported here. As a last resort give each tensor its own storage.
        q, k, v = [x.clone() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        raise RuntimeError("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v


def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Check if sliding window size is compliant with attention mask type.
    If not, set it to the appropriate size.

         attn_mask_type                              |   window_size
    -------------------------------------------------------------------------
    no_mask, padding, arbitrary                      | (-1, -1) or (>=0, >=0)
    causal, padding_causal                           | (-1,  0) or (>=0, 0)
    causal_bottom_right, padding_causal_bottom_right | (-1,  0) or (>=0, 0)

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type. Any type containing "causal" uses the causal rules.
    window_size: Optional[Tuple[int, int]], default = None
        Requested (left, right) window. Any 2-element sequence is accepted and is
        normalized to a tuple. None selects the default for `attn_mask_type`.

    Returns
    -------
    Tuple[int, int]
        The compliant window size. Fixable mismatches are corrected and reported with
        `warnings.warn`.

    Raises
    ------
    AssertionError
        If `attn_mask_type` is unknown or `window_size` cannot be made compliant.
    """
    # Normalize so that lists such as [-1, -1] compare equal to the tuple sentinels below.
    orig_window_size = None if window_size is None else tuple(window_size)
    window_size = orig_window_size
    if "causal" in attn_mask_type:
        if orig_window_size is None:
            window_size = (-1, 0)
        elif orig_window_size == (-1, -1) or (
            orig_window_size[0] >= 0 and orig_window_size[1] != 0
        ):
            # Causal masks cannot look right; clamp the right window to 0.
            window_size = (orig_window_size[0], 0)
            warnings.warn(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
        elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
            # Explicit raise (not `assert`) so validation survives `python -O`.
            raise AssertionError(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
    elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
        if orig_window_size is None:
            window_size = (-1, -1)
        elif orig_window_size == (-1, 0):
            window_size = (-1, -1)
            warnings.warn(
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
        elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
            raise AssertionError(
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
    else:
        raise AssertionError("Invalid attn_mask_type: " + attn_mask_type)
    return window_size


class FlashAttention(torch.nn.Module):
5545
    """Dot product attention, using HazyResearch flash-attn package:
5546
    https://github.com/Dao-AILab/flash-attention
5547
5548
5549
5550
    """

    def __init__(
        self,
5551
        softmax_scale: float,
5552
5553
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
5554
5555
        attention_type: str = "self",
        layer_number: Optional[int] = None,
5556
        deterministic: bool = False,
5557
5558
5559
    ) -> None:
        super().__init__()

5560
5561
5562
5563
5564
5565
5566
        if _flash_attn_is_installed:
            assert (
                _flash_attn_version >= _flash_attn_version_required
            ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
            assert (
                _flash_attn_version <= _flash_attn_max_version
            ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
5567

5568
        self.softmax_scale = softmax_scale
5569
5570
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
5571
5572
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
5573
        self.deterministic = deterministic
5574
5575
5576
5577
        self.logger = logging.getLogger("FlashAttention")
        self.logger.setLevel(_log_level)
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)
5578
5579
5580
5581
5582
5583

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
5584
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
5585
5586
5587
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
5588
5589
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
5590
        attn_mask_type: str = "causal",
5591
        window_size: Optional[Tuple[int, int]] = None,
5592
        alibi_slopes: Optional[torch.Tensor] = None,
5593
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
5594
        cp_global_ranks: List[int] = None,
5595
        cp_stream: torch.cuda.Stream = None,
5596
        cp_comm_type: str = "p2p",
5597
5598
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
5599
        quantizers=None,
5600
5601
5602
    ) -> torch.Tensor:
        """flash-attn fprop"""

5603
5604
5605
5606
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors."
5607
5608
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
5609
        ), "FlashAttention currently only supports CUDA tensors."
5610
5611
        assert (
            qkv_layout in QKVLayouts
5612
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"
5613

5614
5615
5616
5617
5618
5619
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
5620
        context_parallel = cp_size > 1
5621

5622
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
5623

5624
5625
5626
5627
5628
5629
5630
5631
5632
5633
5634
5635
5636
        if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
            if qkv_format == "sbhd":
                # For now just 128, will make it more general in the future
                if (
                    query_layer.shape[-1] == 128
                    and query_layer.shape[0] * query_layer.shape[1] >= 512
                    and qkv_layout == "sbh3d"
                ):
                    query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                        query_layer, key_layer, value_layer
                    )
                else:
                    query_layer, key_layer, value_layer = [
5637
                        x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
5638
                    ]
5639
            if context_parallel:
5640
                query_layer, key_layer, value_layer = [
5641
5642
5643
5644
5645
                    x.contiguous() for x in (query_layer, key_layer, value_layer)
                ]
        else:
            if qkv_format == "sbhd":
                query_layer._data, key_layer._data, value_layer._data = [
5646
                    x.transpose(0, 1)
5647
5648
                    for x in (query_layer._data, key_layer._data, value_layer._data)
                ]
5649
                query_layer, key_layer, value_layer = [
5650
                    Float8Tensor.make_like(x, data=x._data, shape=x._data.shape)
5651
5652
                    for x in (query_layer, key_layer, value_layer)
                ]
5653
            if context_parallel:
5654
5655
                query_layer._data, key_layer._data, value_layer._data = [
                    x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
5656
                ]
5657

5658
        batch_size = query_layer.shape[0]
5659

5660
        if qkv_format in ["sbhd", "bshd"]:
5661
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
5662
5663
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size
5664
5665
5666

            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"
5667
5668
                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
5669
                    x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
5670
5671
5672
5673
5674
5675
5676
                    for x in [query_layer, key_layer, value_layer]
                ]

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
5677
                    if cu_seqlens_q is None:
5678
5679
5680
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
5681
5682
5683
5684
5685
5686
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
5687
5688
                    )
                else:
5689
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
5690
5691
5692
5693
5694
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
5695
5696
5697
5698
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = PackTensors.apply(indices_q, query_layer)
5699
                    key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer)
5700
            else:
5701
5702
5703
5704
5705
5706
5707
5708
5709
5710
5711
5712
5713
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
5714
5715
5716
5717
        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
5718
5719
5720
5721
5722
5723
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()
5724

5725
5726
5727
        if context_parallel and all(
            not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
        ):
5728
5729
5730
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
5731
            with self.attention_dropout_ctx():
5732
                output = attn_forward_func_with_cp(
5733
5734
5735
5736
5737
5738
5739
5740
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
5741
5742
                    cu_seqlens_q if qkv_format == "thd" else None,
                    cu_seqlens_kv if qkv_format == "thd" else None,
5743
                    self.attention_dropout if self.training else 0.0,
5744
5745
5746
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
5747
                    cp_comm_type,
5748
                    softmax_scale=self.softmax_scale,
5749
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
5750
                    attn_mask_type=attn_mask_type,
5751
                    deterministic=self.deterministic,
5752
                    window_size=window_size,
5753
                    quantizers=quantizers,
5754
5755
                )
        else:
5756
5757

            from .cpu_offload import CPUOffloadEnabled
5758

5759
5760
5761
5762
5763
5764
            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

5765
            with self.attention_dropout_ctx():
5766
                fa_optional_forward_kwargs = {}
5767
5768
                if _flash_attn_2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = window_size
5769
5770
5771
5772
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
5773
5774
5775
5776
                fa_optional_forward_args_thd = []
                if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
                    func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
                else:
5777
5778
                    if _flash_attn_2_5_7_plus:
                        fa_optional_forward_kwargs["block_table"] = None
5779
5780
5781
5782
5783
5784
5785
5786
5787
5788
                    func = (
                        flash_attn_varlen_func
                        if not _use_flash_attn_3
                        else flash_attn_varlen_func_v3
                    )
                    fa_optional_forward_args_thd.append(cu_seqlens_q)
                    fa_optional_forward_args_thd.append(cu_seqlens_kv)
                    fa_optional_forward_args_thd.append(max_seqlen_q)
                    fa_optional_forward_args_thd.append(max_seqlen_kv)
                if _use_flash_attn_3:
5789
5790
5791
                    fa_3_optional_forward_kwargs = {}
                    fa_3_optional_forward_kwargs["window_size"] = window_size
                    fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
5792
                    if fp8:
5793
                        QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
5794
                        torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
5795
                        torch_orig_dtype = query_layer.dtype
5796
5797
5798
5799
5800
5801
5802
5803
5804
5805
5806

                        def convert_to_torch_float8(tensor, dtype):
                            out = torch.Tensor().to(device=tensor.device, dtype=dtype)
                            out.set_(
                                tensor._data.untyped_storage(),
                                tensor._data.storage_offset(),
                                tensor._data.shape,
                                tensor._data.stride(),
                            )
                            return out

5807
5808
5809
5810
5811
                        # "fp8_mha" decides outputs in fp8, while inputs are inferred from
                        # the real dtype
                        assert isinstance(key_layer, query_layer.__class__) and isinstance(
                            value_layer, query_layer.__class__
                        ), "q, k, and v must have the same type."
5812
                        if not isinstance(query_layer, Float8Tensor):
5813
                            query_layer, key_layer, value_layer = (
5814
                                QKV_quantizer(x) for x in [query_layer, key_layer, value_layer]
5815
                            )
5816
5817
                        fa_3_optional_forward_kwargs["descale_q"] = (
                            query_layer._scale_inv.unsqueeze(0)
5818
                        )
5819
5820
                        fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze(
                            0
5821
                        )
5822
5823
                        fa_3_optional_forward_kwargs["descale_v"] = (
                            value_layer._scale_inv.unsqueeze(0)
5824
                        )
5825
5826
5827
                        query_layer, key_layer, value_layer = (
                            convert_to_torch_float8(x, torch_dtype)
                            for x in [query_layer, key_layer, value_layer]
5828
                        )
5829
5830
5831
5832
5833
5834
5835
5836
5837
5838
5839
5840
5841
5842
5843
5844
5845
5846
5847
5848
5849
5850
5851
5852
5853
                    try:
                        output, _ = func(
                            query_layer,
                            key_layer,
                            value_layer,
                            *fa_optional_forward_args_thd,
                            softmax_scale=self.softmax_scale,
                            causal="causal" in attn_mask_type,
                            **fa_3_optional_forward_kwargs,
                        )
                    except TypeError as e:
                        if _flash_attn_3_0_0_beta:
                            e.args = (
                                e.args[0]
                                + ". Please update your flash-attn v3 (beta) installation as it "
                                + "may have added more supported arguments to its API. \n"
                                + _flash_attn_3_installation_steps,
                            ) + e.args[1:]
                        raise

                    if fp8:
                        output = output.to(dtype=torch_orig_dtype)
                    if fp8 and fp8_meta["recipe"].fp8_mha:
                        O_quantizer = quantizers["scaling_fwd"][META_O]
                        output = O_quantizer(output)
5854
                else:
5855
5856
5857
5858
5859
5860
5861
5862
5863
                    output = func(
                        query_layer,
                        key_layer,
                        value_layer,
                        *fa_optional_forward_args_thd,
                        self.attention_dropout if self.training else 0.0,
                        softmax_scale=self.softmax_scale,
                        causal="causal" in attn_mask_type,
                        **fa_optional_forward_kwargs,
5864
                    )
5865

5866
5867
5868
5869
5870
5871
5872
5873
5874
5875
5876
5877
5878
5879
5880
5881
5882
5883
5884
5885
5886
5887
5888
5889
5890
5891
5892
5893
5894
5895
5896
5897
5898
5899
5900
5901
5902
5903
5904
5905
5906
5907
5908
5909
5910
5911
5912
5913
5914
5915
5916
5917
5918
5919
        if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            if fp8 and fp8_meta["recipe"].fp8_mha:
                output_data = (
                    output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
                    .transpose(0, 1)
                    .contiguous()
                )
                output = Float8Tensor.make_like(
                    output,
                    data=output_data,
                    shape=output_data.shape,
                )
            else:
                output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
        elif qkv_format == "bshd":
            # (bs)hd -> bs(hd)
            output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
        elif qkv_format == "thd":
            # thd -> t(hd)
            output = output.reshape(output.shape[0], -1)

        return output.contiguous()


def _combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a new dimension as a zero-copy view.

    A new dimension of size ``len(tensors)`` is inserted at position ``dim``.
    The result is built with ``Tensor.set_`` over the storage, storage offset,
    shape and strides of ``tensors[0]`` only; no data is copied.

    The caller must ensure the inputs really are equally shaped views that are
    interleaved in one shared storage. An example is q, k and v split out of a
    packed ``sb3hd`` buffer. This is not checked, and violating it produces a
    view over the wrong memory.

    Args:
        tensors: views to combine. ``Float8Tensor`` inputs are handled by
            building the view over their ``_data`` buffer and re-wrapping it
            with ``Float8Tensor.make_like``.
        dim: index of the new dimension. Must be >= 1, because the stride of
            the new dimension is derived from the stride of dimension
            ``dim - 1``.

    Returns:
        A view of shape ``tensors[0].shape`` with ``len(tensors)`` inserted at
        ``dim``. It is a ``Float8Tensor`` when the inputs are.

    Raises:
        ValueError: if ``dim < 1``.
    """
    if dim < 1:
        # With dim == 0, `new_stride[dim - 1]` would silently read the *last*
        # stride and build a view whose new dimension aliases tensors[0].
        raise ValueError(
            f"_combine_tensors requires dim >= 1 (stride is derived from dim - 1), got {dim}"
        )

    num_tensors = len(tensors)
    first = tensors[0]
    is_fp8 = isinstance(first, Float8Tensor)
    # For FP8 the raw uint8 buffer carries storage/strides; the wrapper carries shape.
    data = first._data if is_fp8 else first

    new_shape = list(first.shape)
    new_shape.insert(dim, num_tensors)
    new_stride = list(data.stride())
    # Neighbouring views are interleaved inside dimension dim - 1, so one step along
    # the new dimension is that stride divided by the number of tensors.
    new_stride.insert(dim, new_stride[dim - 1] // num_tensors)

    combined_tensor = torch.Tensor().to(device=first.device, dtype=data.dtype)
    combined_tensor.set_(
        data.untyped_storage(),
        data.storage_offset(),
        new_shape,
        new_stride,
    )
    if is_fp8:
        combined_tensor = Float8Tensor.make_like(first, data=combined_tensor, shape=new_shape)
    return combined_tensor

5924

5925
5926
5927
5928
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
5929
5930
5931
5932
5933
5934
5935
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
5936
5937
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
5938
5939
5940
5941
5942
5943
5944
5945
5946
5947
5948
        q,
        k,
        v,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
5949
        window_size,
5950
5951
5952
5953
5954
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
5955
        quantizers,
5956
        deterministic,
5957
    ):
5958
        # pylint: disable=missing-function-docstring
5959
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
5960
        is_input_fp8 = False
5961
5962
5963
5964
5965
5966
        is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False
        fake_dtype = q.dtype

        QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
            get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
        )
5967
5968
        if fp8:
            fused_attention_backend = FusedAttnBackend["FP8"]
5969
5970
5971
            assert isinstance(k, q.__class__) and isinstance(
                v, q.__class__
            ), "q, k, and v must have the same type."
5972

5973
            is_input_fp8 = isinstance(q, Float8Tensor)
5974
            q_fp8, k_fp8, v_fp8 = None, None, None
5975
            if is_input_fp8:
5976
                q_fp8, k_fp8, v_fp8 = q, k, v
5977
5978
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
5979
                qkv_group = len(qkv_layout.split("_"))
5980
5981
5982
5983
5984
5985
5986
5987
5988
5989
5990
5991
5992
5993
5994
5995
5996
5997
5998
5999
                match qkv_group:
                    case 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                        qkv_fp8 = QKV_quantizer(qkv)
                        q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True)
                    case 2:
                        q_fp8 = QKV_quantizer(q)
                        dim = qkv_layout.split("_")[1].find("2")
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                        kv_fp8 = QKV_quantizer(kv_c)
                        k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1], True)
                    case 3:
                        q_fp8 = QKV_quantizer(q)
                        k_fp8 = QKV_quantizer(k)
                        v_fp8 = QKV_quantizer(v)
                    case _:
                        raise "Invalid qkv_layout " + qkv_layout
6000
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
6001
6002
6003
6004
6005
6006
6007
6008
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
6009
                fake_dtype,
6010
6011
                fused_attention_backend,
                attn_bias,
6012
6013
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
6014
6015
                S_quantizer,
                O_quantizer,
6016
6017
6018
6019
6020
6021
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
6022
                window_size,
6023
6024
                rng_gen,
            )
6025
            if is_output_fp8:
6026
                out_ret = out_fp8
6027
            else:
6028
                out_ret = out_fp8.dequantize().view(out_fp8.shape)
6029
6030
            out_save = out_ret

6031
            if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
6032
                # 1: qkv packed, 2: kv packed, 3: qkv separate
6033
6034
6035
6036
6037
6038
                if is_input_fp8:
                    qkv_group = len(qkv_layout.split("_"))
                    if qkv_group == 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
6039
6040
                        qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape)
                        q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True)
6041
                    if qkv_group == 2:
6042
                        q = q.dequantize()
6043
6044
6045
                        dim = qkv_layout.split("_")[1].find("2")
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
6046
6047
                        kv_no_fp8 = kv.dequantize()
                        k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True)
6048
                    if qkv_group == 3:
6049
6050
6051
                        q = q.dequantize()
                        k = k.dequantize()
                        v = v.dequantize()
6052
                if is_output_fp8:
6053
6054
6055
                    out_save = out_fp8.dequantize()

            fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
6056
        else:
6057

6058
            out_ret, aux_ctx_tensors = fused_attn_fwd(
6059
6060
6061
6062
6063
6064
6065
6066
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
6067
                fake_dtype,
6068
6069
                fused_attention_backend,
                attn_bias,
6070
6071
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
6072
6073
                None,  # s_quantizer
                None,  # o_quantizer
6074
6075
6076
6077
6078
6079
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
6080
                window_size,
6081
6082
                rng_gen,
            )
6083
            out_save = out_ret
6084
            fp8_tensors = (None, None, None, None)
6085

6086
6087
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

6088
        from .cpu_offload import CPUOffloadEnabled
6089

6090
        if CPUOffloadEnabled:
6091
6092
6093
6094
6095
6096
6097
            if ctx.fp8:
                tensor_list = fp8_tensors
            else:
                tensor_list = [q, k, v, out_save]

            tensor_list.extend(aux_ctx_tensors)

6098
            qkv_layout = "sbhd_sbhd_sbhd"
6099
6100
6101
6102
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

6103
6104
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
6105
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
6106
6107
        tensors_to_save, tensor_objects = prepare_for_saving(
            *fp8_tensors,
6108
6109
6110
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
6111
6112
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
6113
6114
            *aux_ctx_tensors,
        )
6115
6116
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
6117
        ctx.fp8_meta = fp8_meta
6118
6119
6120
6121
6122
6123

        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = S_quantizer

6124
6125
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
6126
        ctx.fake_dtype = fake_dtype
6127
6128
6129
6130
6131
6132
6133
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
6134
        ctx.window_size = window_size
6135
        ctx.fused_attention_backend = (
6136
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
6137
        )
6138
        ctx.use_FAv2_bwd = use_FAv2_bwd
6139
        ctx.deterministic = deterministic
6140

6141
        return out_ret
6142
6143
6144

    @staticmethod
    def backward(ctx, d_out):
6145
        # pylint: disable=missing-function-docstring
6146
        if ctx.is_output_fp8:
6147
6148
6149
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
6150

6151
        d_out = d_out.contiguous()
6152
        (
6153
6154
6155
6156
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
6157
6158
6159
6160
6161
6162
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
6163
6164
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
6165
6166
6167
6168
6169
            *other_tensors,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

        aux_ctx_tensors = other_tensors

6170
6171
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
6172
        rest = [None]
6173
        if ctx.use_FAv2_bwd:
6174
            softmax_lse, rng_state = aux_ctx_tensors
6175
6176
6177
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
6178
            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)]
6179
            flash_attn_cuda_bwd(
6180
6181
6182
6183
6184
6185
6186
6187
6188
6189
6190
6191
6192
6193
6194
6195
6196
6197
6198
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
6199
            )
6200
6201
6202
            dq = dq[..., : d_out.shape[-1]]
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
6203
        else:
6204
6205
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
6206
                    if ctx.is_output_fp8:
6207
6208
                        d_out_fp8 = d_out
                    else:
6209
                        d_out_fp8 = ctx.dO_quantizer(d_out)
6210
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
6211
6212
6213
6214
6215
6216
6217
6218
6219
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
6220
6221
                        ctx.fake_dtype,
                        ctx.qkv_dtype,
6222
                        aux_ctx_tensors,
6223
                        ctx.fused_attention_backend,
6224
6225
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
6226
6227
6228
                        ctx.S_quantizer,
                        ctx.dP_quantizer,
                        ctx.dQKV_quantizer,
6229
6230
6231
6232
6233
6234
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
6235
6236
                        ctx.window_size,
                        ctx.deterministic,
6237
                    )
6238

6239
                    if not ctx.is_input_fp8:
6240
                        qkv_group = len(ctx.qkv_layout.split("_"))
6241
                        if qkv_group == 1:
6242
                            dim = ctx.qkv_layout.find("3")
6243
6244
                            dqkv_fp8_data = _combine_tensors(
                                [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim
6245
                            )
6246
6247
6248
6249
6250
                            dqkv_fp8 = dq_fp8.make_like(
                                tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape
                            )
                            dqkv = dqkv_fp8.dequantize()
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True)
6251
                        if qkv_group == 2:
6252
                            dq = dq_fp8.dequantize()
6253
6254
6255
6256
6257
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(
                                -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                            )
6258
6259
                            dkv = dkv_c_fp8.dequantize()
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1], True)
6260
                        if qkv_group == 3:
6261
6262
6263
6264
6265
                            dq = dq_fp8.dequantize()
                            dk = dk_fp8.dequantize()
                            dv = dv_fp8.dequantize()
                    else:
                        dq, dk, dv = dq_fp8, dk_fp8, dv_fp8
6266
                else:
6267
6268
                    if isinstance(d_out, QuantizedTensor):
                        d_out = d_out.dequantize()
6269
                    dq, dk, dv, *rest = fused_attn_bwd(
6270
6271
6272
6273
6274
6275
6276
6277
6278
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
6279
                        ctx.fake_dtype,
6280
6281
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
6282
                        ctx.fused_attention_backend,
6283
6284
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
6285
6286
6287
6288
6289
6290
6291
6292
6293
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
6294
6295
                        ctx.window_size,
                        ctx.deterministic,
6296
                    )
6297

6298
6299
        # if no_bias or alibi, return dqkv
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
6300
6301
6302
6303
6304
6305
6306
6307
6308
6309
6310
6311
6312
6313
6314
6315
6316
6317
6318
6319
6320
6321
6322
6323
6324
6325
            return (
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                dq,
                dk,
                dv,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
6326
6327
                None,
                None,
6328
                None,
6329
            )
6330
        # else, return (dqkv, dbias)
6331
6332
6333
6334
6335
6336
6337
6338
6339
6340
6341
6342
6343
6344
6345
6346
6347
6348
6349
6350
6351
6352
6353
6354
6355
6356
        return (
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            dq,
            dk,
            dv,
            None,
            rest[0],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
6357
6358
            None,
            None,
6359
            None,
6360
        )
6361

6362

6363
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Store the attention configuration and register the checkpoint-compat hook.

        Parameters
        ----------
        softmax_scale : float
            Scale applied to the attention scores before softmax.
        attention_dropout : float, default = 0.0
            Dropout probability; only applied while ``self.training`` is True.
        attention_dropout_ctx : Callable, default = nullcontext
            Context-manager factory entered around the attention computation
            (e.g. an RNG-tracker fork).
        attention_type : str, default = "self"
            "self" or "cross"; selects how padding masks are turned into cu_seqlens.
        layer_number : Optional[int], default = None
            Layer index; stored as 1 when ``None``.
        deterministic : bool, default = False
            Forwarded to the fused attention kernels.
        """
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # Opt-in (NVTE_FUSED_ATTN_USE_FAv2_BWD=1) and only enabled on compute capability 9.0.
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading Transformer Engine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in Transformer Engine 2.0.
            """
            # Iterate over snapshots: removing from a list while iterating it
            # directly would skip the element following each removed key.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
        quantizers=None,
    ) -> torch.Tensor:
        """Fused attention forward pass.

        Validates the inputs, derives any sequence-length metadata that can be
        computed from the tensors, then dispatches to ``attn_forward_func_with_cp``
        when the total context-parallel size is > 1, or to ``FusedAttnFunc`` otherwise.

        Parameters
        ----------
        query_layer, key_layer, value_layer : torch.Tensor
            FP16/BF16 CUDA tensors, or ``Float8Tensor`` instances.
        qkv_layout : str, default = "sbh3d"
            Must be in ``QKVLayouts``. The qkv format (``sbhd``/``bshd``/``thd``) is
            derived from the letters of its first "_"-separated component.
        cu_seqlens_q, cu_seqlens_kv : Optional[torch.Tensor]
            Cumulative sequence lengths. For ``sbhd``/``bshd`` with a padding mask they
            are computed from ``attention_mask`` when either is ``None``; without a
            padding mask, a missing one is filled with full-length values. Required
            for ``thd``.
        cu_seqlens_q_padded, cu_seqlens_kv_padded : Optional[torch.Tensor]
            For ``thd``: if either is ``None``, both are replaced by
            ``cu_seqlens_q``/``cu_seqlens_kv``.
        max_seqlen_q, max_seqlen_kv : Optional[int]
            For ``sbhd``/``bshd`` these are overwritten from the tensor shapes
            (multiplied by the context-parallel size). Required for ``thd``.
        attn_mask_type : str, default = "causal"
            Mask type; padding masks are rejected under context parallelism.
        attention_mask : tensor or pair of tensors, optional
            Padding mask(s): one tensor for self-attention, ``(q_mask, kv_mask)``
            for cross-attention.
        window_size : Optional[Tuple[int, int]]
            Sliding-window size forwarded to the kernels.
        fused_attention_backend : tex.NVTE_Fused_Attn_Backend
            Selected cuDNN sub-backend; ``NVTE_No_Backend`` is rejected.
        core_attention_bias_type, core_attention_bias
            Bias type and tensor forwarded to the kernels.
        fast_zero_fill : bool, default = True
            Forwarded to ``FusedAttnFunc``.
        cp_group, cp_global_ranks, cp_stream, cp_comm_type
            Context-parallel configuration. ``cp_group`` may be a single group or a
            list of groups whose world sizes are multiplied.
        fp8 : bool, default = False
            Run FP8 attention; requires the FP8 sub-backend and ``fp8_meta``.
        fp8_meta : Optional[Dict[str, Any]]
            FP8 metadata; with context parallelism its recipe must set ``reduce_amax``.
        quantizers
            Quantizers forwarded to the attention implementation.

        Returns
        -------
        torch.Tensor
            Attention output with its last two dimensions (heads, head size)
            flattened into one.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        # Total context-parallel size; a list of groups (e.g. for "a2a+p2p") multiplies out.
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        # Keep only the letters of the first layout component, e.g. "sbh3d" -> "sbhd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format in ["sbhd", "bshd"]:
            if qkv_format == "sbhd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1],
                    query_layer.shape[0],
                    key_layer.shape[0],
                )
            elif qkv_format == "bshd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0],
                    query_layer.shape[1],
                    key_layer.shape[1],
                )
            # NOTE(review): presumably each CP rank holds a 1/cp_size shard of the
            # sequence, so the local lengths are scaled back to the full length.
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size
            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                # No padding: every sequence spans the full (scaled) max length.
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )

        if qkv_format == "thd":
            assert (
                max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
            ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"

        if qkv_format == "thd" and (cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None):
            cu_seqlens_q_padded = cu_seqlens_q
            cu_seqlens_kv_padded = cu_seqlens_kv

        qkv_dtype = TE_DType[query_layer.dtype]

        # The FAv2 backward is only used without bias and with the arbitrary-seqlen backend.
        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if fp8:
            assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                " is required for FP8 attention!"
            )
            assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
            assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
                "Amax reduction across TP+CP group is necessary when using context parallelism with"
                " FP8!"
            )

        if context_parallel:
            assert (
                fp8
                or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    deterministic=self.deterministic,
                    use_fused_attention=True,
                    window_size=window_size,
                    fp8=fp8,
                    fp8_meta=fp8_meta,
                    quantizers=quantizers,
                )
        else:
            with self.attention_dropout_ctx():
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_dtype,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                    quantizers,
                    self.deterministic,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
6632
6633


6634
class DotProductAttention(TransformerEngineBaseModule):
6635
6636
6637
6638
6639
6640
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

6641
        Argument :attr:`attention_mask` in the `forward` call is only used when
6642
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
6643
6644
6645

    .. warning::

6646
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
6647
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
6648
6649
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
6650

6651
6652
6653
6654
6655
6656
6657
    .. note::

        Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing.
        As the FP8 attention support expands from one backend to multiple backends, the location
        of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).


6658
6659
6660
6661
    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
6662
6663
6664
    kv_channels : Union[int, Tuple[int, int]]
                the head size in key and value tensors. If the same, :attr:`kv_channels` can be
                an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
6665
6666
6667
6668
6669
6670
6671
6672
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
6673
6674
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
6675
    attn_mask_type: str, default = `causal`
6676
                   type of attention mask passed into softmax operation, options are "`no_mask`",
6677
6678
6679
6680
6681
6682
6683
6684
6685
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
                   "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
                   "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
                   are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
                   `forward` method. It is useful for cases involving compilation/tracing, e.g.
                   ONNX export, and the forward arg is useful for dynamically changing mask types,
                   e.g. a different mask for training and inference.
                   1. For "`no_mask`", no attention mask is applied.
                   2. For "`causal`", "`causal_bottom_right`", or the causal mask in
6686
                   "`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
6687
6688
6689
6690
6691
6692
6693
6694
6695
6696
6697
6698
6699
6700
                   calculates and applies an upper triangular mask to the softmax input.
                   No user input is needed. Causal masks without the "`bottom_right`" appendix align
                   the diagonal line to the top left corner of the softmax matrix. With
                   "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
                   often used in inference/KV caching.
                   3. For "`padding`", or the padding mask in "`padding_causal`" and
                   "`padding_causal_bottom_right`", users need to provide the locations of padded
                   tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
                   [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
                   in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
                   cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
                   [batch_size, 1, 1, max_seqlen_kv]).
                   4. For "`arbitrary`", users need to provide a mask that is broadcastable to
                   the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
6701
6702
6703
6704
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
6705
6706
6707
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
6708
                be overridden by :attr:`window_size` in `forward` as well.
6709
6710
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
6711
6712
6713
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
6714
6715
6716
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
6717
               `h` the number of heads, `d` head size, and `t` the total number of tokens
6718
6719
6720
6721
6722
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
6723
               For that, please use `get_qkv_layout` to gain the layout information.
6724
6725
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
6726
                `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
6727
6728
6729
6730
6731
6732
6733
6734
6735

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
6736
    cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None`
6737
              context parallel process group.
6738
6739
6740
              ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
              List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
              and cp_group[1] are for a2a and p2p communications respectively.
6741
6742
6743
6744
6745
6746
6747
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
6748
    cp_comm_type : str, default = `p2p`
6749
                  inter-gpu communication type for context parallelism.
6750
                  Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
6751
6752
6753
6754
6755
6756
                  "p2p": Exchange KV chunks with P2P communications in ring topology.
                         P2P is async and can be overlapped with attention compute.
                  "all_gather": All-gather to get full sequence of KV before attention.
                                The all-gather is not async, and cannot be overlapped.
                  "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                         group, and gather to get full sequence of QKV.
6757
6758
6759
                  "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                  across each CP sub-group (e.g., via NVLink), then exchanging KV with
                  p2p between sub-groups (e.g., via IBLink).
6760
6761
6762
6763
6764
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: Union[int, Tuple[int, int]],
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
    ) -> None:
        """Configure the module and instantiate the flash, fused and unfused backends.

        Normalizes the mask type, resolves tensor/context-parallel settings, derives
        the default softmax scale and determinism mode, and registers a post hook
        that tolerates older checkpoints lacking ``core_attention._extra_state``.
        See the class docstring for the meaning of each parameter.
        """
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.logger.setLevel(_log_level)
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)
        self.qkv_format = qkv_format
        # Normalize mask spellings: "padding,causal" -> "padding_causal" and
        # "causal,padding" / "causal_padding" -> "padding_causal".
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
        self.cp_comm_type = cp_comm_type

        # kv_channels is either a shared head size or a (k_head_size, v_head_size) pair.
        self.hidden_size_per_attention_head_k = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[0]
        )
        self.hidden_size_per_attention_head_v = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[1]
        )

        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"

        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(
                kv_channels if isinstance(kv_channels, int) else kv_channels[0]
            )

        # Deterministic when explicitly disallowed via env var, or when PyTorch's
        # deterministic-algorithms mode is on.
        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )

        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = get_cudnn_version()
        if (8, 9, 5) <= cudnn_version < (9, 0, 0):
            if self.deterministic:
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        self.flash_attention = FlashAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        self.fused_attention = FusedAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale,
            attention_type=attention_type,
            **attn_kwargs,
            layer_number=layer_number,
        )

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
            when loading older Transformer Engine checkpoints. Will phase out
            this hook in Transformer Engine 2.0.
            """
            # Iterate over a snapshot: removing from the list being iterated
            # would skip the key that follows each removed one.
            for key in list(incompatible_keys.missing_keys):
                if "core_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)

        self.register_load_state_dict_post_hook(remove_extra_states_check)

6908
6909
6910
6911
6912
6913
6914
6915
6916
6917
6918
6919
6920
6921
6922
6923
6924
6925
6926
6927
6928
6929
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load state while accepting the legacy location of the FP8 attention metadata.

        Transformer Engine 1.6 and 1.7 checkpoints store FP8 attention metadata under the
        `core_attention.fused_attention._extra_state` key instead of
        `core_attention._extra_state`. When only the legacy key is present anywhere in
        `state_dict`, loading is redirected to the `fused_attention.` sub-prefix. Please see
        `FP8 checkpoint compatibility
        <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_ for more details.
        """
        # Both scans look at every key in the state dict, not only those under `prefix`.
        has_legacy_key = any(
            "core_attention.fused_attention._extra_state" in name for name in state_dict.keys()
        )
        has_current_key = any(
            "core_attention._extra_state" in name for name in state_dict.keys()
        )
        effective_prefix = (
            prefix + "fused_attention." if has_legacy_key and not has_current_key else prefix
        )
        super()._load_from_state_dict(
            state_dict,
            effective_prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

6930
6931
6932
6933
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Run `attention_func` under activation checkpointing.

        All positional and keyword arguments are forwarded to `attention_func` through
        `checkpoint`, with saved activations not distributed and this module's RNG
        state tracker and tensor-parallel group supplied to the checkpoint call.
        """

        def _invoke_attention(*call_args, **call_kwargs):
            # Local closure that simply forwards every argument to the attention backend.
            return attention_func(*call_args, **call_kwargs)

        return checkpoint(
            _invoke_attention,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

6952
6953
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        The arguments are stored as-is on the module (``cp_group``,
        ``cp_global_ranks``, ``cp_stream``, ``cp_comm_type``); no validation is
        performed here.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup], None]
                  context parallel process group.
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   CUDA stream for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-GPU communication type for context parallelism.
                      Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
        """
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
        self.cp_comm_type = cp_comm_type
6991

6992
    @no_torch_dynamo(recursive=False)
6993
6994
6995
6996
6997
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
6998
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
6999
7000
7001
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
7002
7003
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
7004
7005
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
7006
        attn_mask_type: Optional[str] = None,
7007
        window_size: Optional[Tuple[int, int]] = None,
7008
        checkpoint_core_attention: bool = False,
7009
7010
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
7011
        alibi_slopes: Optional[torch.Tensor] = None,
7012
        fast_zero_fill: bool = True,
7013
        inference_params: Optional[InferenceParams] = None,
7014
7015
7016
7017
7018
7019
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

7020
7021
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes '"padding"' or `"arbitrary"`.
7022

7023
7024
        .. note::

7025
7026
7027
7028
7029
7030
7031
7032
7033
7034
7035
7036
7037
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://arxiv.org/pdf/2305.13245.pdf>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
7038
            and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
7039
7040
7041
7042
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
7043
7044
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
7045
            optimizations in FusedAttention. When unset, Transformer Engine determines the code path
7046
7047
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
7048

7049
7050
7051
7052
7053
7054
7055
7056
7057
7058
7059
7060
7061
7062
7063
7064
7065
7066
7067
7068
7069
7070
7071
7072
7073
7074
7075
7076
7077
7078
7079
7080
7081
7082
7083
7084
7085
7086
7087
7088
7089
7090
7091
7092
7093
7094
7095
7096
7097
7098
7099
7100
7101
7102
        .. note::
            .. _cu_seqlens note:

            When training data has variable sequence lengths, users have two options.

            1. Manipulate the data and pad all sequences to the same length. Use
               :attr:`qkv_format` = {"bshd", "sbhd"} and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`
               (which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide
               the real sequence length information. For example, a batch of 3 sequences
               [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative
               sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

            2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`,
               as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed
               without any padding, and the sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

               In certain use cases, a varying number of identifier tokens are inserted between
               sequences. These tokens do not participate in the attention calculation.
               :attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified
               in such cases to correctly identify the start and end of each sequence in a batch.
               For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and
               :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13]
               for self-attention.

        .. note::
            .. _max_seqlen note:

            When :attr:`qkv_format` = {"bshd", "sbhd"}, sequences are of equal length in a batch.
            :attr:`max_seqlen_q` and :attr:`max_seqlen_kv` should be the same as the "s" dimension of
            :attr:`query_layer` and :attr:`key_layer` tensors. When unset, Transformer Engine will
            infer them as such.

            When :attr:`qkv_format` = "thd", sequences have varying lengths. :attr:`max_seqlen_q` and
            :attr:`max_seqlen_kv` should be the maximum query and key/value sequence length in a batch.
            When unset, Transformer Engine deduces them from :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`.
            This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this
            overhead, users are recommended to obtain the maximum sequence lengths from the data loaders
            and pass them in.

            - As the maximum sequence lengths, batch size, and number of tokens change from batch to batch,
              dynamic shapes need to be supported for tensor construction. FlashAttention and
              UnfusedDotProductAttention naturally do so, while FusedAttention requires parameters to be static
              to create graphs before performance heuristics analysis. To reduce the number of graphs created
              per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size,
              :attr:`max_seqlen_q`, :attr:`max_seqlen_kv`}, and for cuDNN >= 9.6, {"t" dimension of
              :attr:`query_layer`, "t" dimension of :attr:`key_layer`}.

7103
7104
7105
7106
7107
7108
7109
7110
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
7111
7112
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
7113
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
7114
7115
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
7116
7117
7118
7119
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
7120
7121
7122
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
7123
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
7124
                   with shape [batch_size + 1] and dtype torch.int32.
7125
                   See :ref:`note<cu_seqlens note>` for more details.
7126
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
7127
7128
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
7129
                   See :ref:`note<cu_seqlens note>` for more details.
7130
7131
7132
7133
7134
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for
                   `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_q_padded = cu_seqlens_q`.
7135
                   See :ref:`note<cu_seqlens note>` for more details.
7136
7137
7138
7139
7140
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_kv_padded = cu_seqlens_kv`.
7141
                   See :ref:`note<cu_seqlens note>` for more details.
7142
7143
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
7144
                      See :ref:`note<max_seqlen note>` for more details.
7145
7146
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
7147
                       See :ref:`note<max_seqlen note>` for more details.
7148
7149
7150
7151
7152
7153
7154
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
                       'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
                       'arbitrary'}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
                       are equivalent. By default, causal masks are aligned to the top left corner
                       of the softmax matrix. When "`bottom_right`" is specified in the mask type,
                       causal masks are aligned to the bottom right corner.
7155
        window_size: Optional[Tuple[int, int]], default = `None`
7156
                    Sliding window size for local attention.
7157
7158
7159
7160
7161
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
7162
        core_attention_bias_type: str, default = `no_bias`
7163
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
7164
        core_attention_bias: Optional[torch.Tensor], default = `None`
7165
7166
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
7167
7168
7169
7170
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
7171
        fast_zero_fill: bool, default = `True`
7172
                    Whether to use the fast path to set output tensors to 0 or not.
7173
7174
7175
7176
7177
7178
7179
7180
7181
7182
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
7183
        """
7184

7185
7186
7187
7188
7189
7190
7191
7192
7193
        with self.prepare_forward(
            query_layer,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:
            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
7194
                        self.logger.warning(
7195
7196
7197
                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                            """fp8_meta["recipe"].fp8_mha=True"""
                        )
7198
7199
7200
7201
7202
7203
7204
7205
7206
7207
7208

            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
7209

7210
7211
7212
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
7213
7214
7215
            assert (
                query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
7216
7217
7218
            assert (
                key_layer.shape[:-1] == value_layer.shape[:-1]
            ), "Keys and values must have the same batch size, sequence length and number of heads!"
7219
7220
7221
7222
7223
7224
7225
7226
            assert (
                key_layer.shape[-1] == self.hidden_size_per_attention_head_k
            ), f"Keys have head_dim = {key_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_k}!"
            assert (
                value_layer.shape[-1] == self.hidden_size_per_attention_head_v
            ), f"Values have head_dim = {value_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_v}!"
7227

7228
7229
7230
            if qkv_format is None:
                qkv_format = self.qkv_format

7231
7232
7233
7234
7235
7236
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
7237
            assert (
7238
7239
7240
7241
7242
7243
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
7244

7245
7246
7247
7248
            if window_size is None:
                window_size = self.window_size
            window_size = check_set_window_size(attn_mask_type, window_size)

7249
7250
7251
7252
7253
7254
7255
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
7256

7257
7258
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
7259

7260
7261
7262
7263
7264
                # convert causal to causal_bottom_right in inference when KV-caching is in use
                # so users can run with the same attn_mask_type for training and inference
                if attn_mask_type in ["causal", "padding_causal"]:
                    attn_mask_type = attn_mask_type + "_bottom_right"

7265
7266
7267
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
7268

7269
7270
7271
7272
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]
7273

7274
7275
7276
                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)
7277

7278
7279
7280
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)
7281

7282
7283
7284
7285
7286
7287
7288
7289
7290
                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
7291

7292
7293
7294
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
7295

7296
7297
                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()
7298
7299

            assert (
7300
7301
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
7302
7303
7304
7305
            ), (
                "Keys and values must have num_gqa_group ="
                f" {self.num_gqa_groups_per_partition} heads!"
            )
7306
7307
7308
7309
7310
7311
7312
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
7313
                assert all(
7314
7315
7316
7317
7318
7319
7320
7321
7322
7323
7324
7325
7326
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
7327
                batch_size = len(cu_seqlens_q) - 1
7328
                if max_seqlen_q is None:
7329
7330
7331
7332
                    if cu_seqlens_q_padded is not None:
                        seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
                    else:
                        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
7333
                    max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
7334
                if max_seqlen_kv is None:
7335
7336
7337
7338
                    if cu_seqlens_kv_padded is not None:
                        seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
                    else:
                        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
7339
                    max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
7340

7341
7342
7343
7344
7345
7346
            cp_size = 1
            if isinstance(self.cp_group, dist_group_type):
                cp_size = get_distributed_world_size(self.cp_group)
            elif isinstance(self.cp_group, list):
                for group in self.cp_group:
                    cp_size *= get_distributed_world_size(group)
7347
7348
            context_parallel = cp_size > 1

7349
            if qkv_format in ["sbhd", "bshd"]:
7350
                assert all(
7351
7352
7353
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
7354
7355
                    max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv
7356
                    batch_size = query_layer.shape[1]
7357
                else:
7358
7359
                    max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv
7360
                    batch_size = query_layer.shape[0]
7361
7362
                max_seqlen_q *= cp_size
                max_seqlen_kv *= cp_size
7363
7364
7365
7366
7367
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
7368
                        the sequence dimension in 'query_layer'!"""
7369
7370
7371
7372
7373
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
7374
                        the sequence dimension in 'key_layer' and 'value_layer'!"""
7375
7376
7377
7378
7379
                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
7380
                        if self.attention_type == "self":
7381
7382
7383
7384
7385
7386
7387
7388
7389
7390
7391
7392
7393
7394
7395
7396
                            cu_seqlens_q = get_cu_seqlens(attention_mask)
                            cu_seqlens_kv = cu_seqlens_q
                        else:
                            cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                            cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
                    else:
                        cu_seqlens_q = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
                        cu_seqlens_kv = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
7397

7398
7399
7400
7401
7402
            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
7403
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
7404
7405
7406
                    query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                )
            else:
7407
                qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
7408
7409
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )
7410

7411
7412
7413
7414
7415
7416
7417
7418
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
7419
            bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
7420
7421
7422
7423
7424
7425
7426
7427
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
7428
                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
7429
7430
7431
7432
7433
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

7434
7435
            core_attention_bias_shape = None
            if core_attention_bias is not None:
7436
                if (
7437
7438
                    core_attention_bias.shape[0] == batch_size
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
7439
                ):
7440
7441
7442
7443
7444
7445
7446
7447
7448
7449
7450
7451
7452
7453
7454
7455
7456
7457
7458
                    core_attention_bias_shape = "bhss"
                elif (
                    core_attention_bias.shape[0] == 1
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
                ):
                    core_attention_bias_shape = "1hss"
                elif (
                    core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
                ):
                    core_attention_bias_shape = "b1ss"
                elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
                    core_attention_bias_shape = "11ss"
                else:
                    assert (
                        False
                    ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

            pad_between_seqs = (
                cu_seqlens_q_padded is not None
7459
                and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
7460
7461
            ) or (
                cu_seqlens_kv_padded is not None
7462
                and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
7463
            )
7464

7465
            attention_params = AttentionParams(
7466
7467
7468
7469
7470
7471
7472
7473
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
                num_heads=query_layer.shape[-2],
                num_gqa_groups=key_layer.shape[-2],
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
7474
7475
                head_dim_qk=query_layer.shape[-1],
                head_dim_v=value_layer.shape[-1],
7476
7477
7478
7479
7480
7481
7482
7483
7484
7485
7486
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias_shape=core_attention_bias_shape,
                core_attention_bias_requires_grad=(
                    core_attention_bias.requires_grad if core_attention_bias is not None else False
                ),
                pad_between_seqs=pad_between_seqs,
                attention_dropout=self.attention_dropout,
                context_parallel=context_parallel,
7487
7488
                deterministic=self.deterministic,
                is_training=self.training,
7489
7490
7491
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
            )
7492
            global _attention_backends, _use_flash_attn_3
7493
7494
7495
7496
7497
7498
7499
            if (
                _attention_backends["attention_params"] is None
                or attention_params != _attention_backends["attention_params"]
            ):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
7500
                _use_flash_attn_3 = _flash_attn_3_is_installed
7501
7502
7503
7504
7505
7506
7507
7508
                (
                    use_flash_attention,
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
                    _,
                ) = get_attention_backend(attention_params)
                if use_flash_attention:
7509
7510
                    self.logger.info(
                        "Running with FlashAttention backend (version %s)",
7511
                        _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version,
7512
                    )
7513
7514
7515
7516
                elif use_fused_attention:
                    self.logger.info(
                        "Running with FusedAttention backend (sub-backend %s)",
                        int(fused_attention_backend),
7517
                    )
7518
7519
7520
7521
7522
7523
7524
                elif use_unfused_attention:
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]
7525

7526
7527
7528
7529
7530
7531
7532
7533
7534
7535
7536
7537
7538
7539
7540
7541
7542
7543
7544
7545
7546
7547
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
7548
                    cp_comm_type=self.cp_comm_type,
7549
7550
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
7551
7552
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
7553
                    quantizers=self.quantizers,
7554
                )
7555

7556
            if use_fused_attention:
7557
7558
                fu_core_attention_bias_type = core_attention_bias_type
                fu_core_attention_bias = core_attention_bias
7559
7560
7561
                if core_attention_bias_type == "alibi" and (
                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
                ):
7562
7563
7564
7565
7566
7567
7568
                    fu_core_attention_bias_type = "post_scale_bias"
                    _, fu_core_attention_bias = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                        bias_dtype=query_layer.dtype,
7569
                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
7570
                    )
7571
7572
7573
7574
7575
7576
7577
7578
7579
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
7580
7581
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
7582
7583
7584
7585
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
7586
                        window_size=window_size,
7587
7588
7589
7590
7591
7592
7593
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
7594
                        cp_comm_type=self.cp_comm_type,
7595
7596
7597
7598
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
                    )
                return self.fused_attention(
7599
7600
7601
7602
7603
7604
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
7605
7606
                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
7607
7608
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
7609
7610
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
7611
                    window_size=window_size,
7612
                    fused_attention_backend=fused_attention_backend,
7613
7614
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
7615
7616
7617
7618
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
7619
                    cp_comm_type=self.cp_comm_type,
7620
7621
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
7622
                    quantizers=self.quantizers,
7623
                )
7624

7625
            from .cpu_offload import CPUOffloadEnabled
7626

7627
7628
7629
7630
7631
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
7632

7633
7634
7635
7636
7637
7638
7639
7640
7641
7642
7643
7644
            if use_unfused_attention:
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
7645
                        window_size=window_size,
7646
7647
7648
7649
7650
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
7651
7652
7653
                    query_layer,
                    key_layer,
                    value_layer,
7654
7655
7656
7657
7658
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
7659
                    window_size=window_size,
7660
7661
7662
7663
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )
7664

7665
            raise ValueError("No dot product attention support for the provided inputs!")
7666
7667


7668
7669
7670
7671
7672
7673
7674
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

7675
7676
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
7677

7678
7679
7680
7681
7682
7683
7684
7685
7686
7687
7688
7689
7690
7691
7692
7693
7694
7695
7696
7697
7698
7699
7700
7701
7702
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing QKV weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
7703
7704
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                   'padding_causal_bottom_right','arbitrary'},
7705
                   default = `causal`
7706
7707
7708
7709
7710
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
7711
7712
7713
7714
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
7715
7716
7717
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
7718
                be overridden by :attr:`window_size` in `forward` as well.
7719
7720
7721
7722
7723
7724
7725
7726
7727
7728
7729
7730
7731
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
7732
7733
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
7734
7735
7736
7737
7738
7739
7740
7741
7742
7743
7744
7745
7746
7747
7748
7749
7750
7751
7752
7753
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, this module will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
7754
          The device on which the parameters of the model will be allocated. It is the user's
7755
7756
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
7757
7758
7759
7760
7761
7762
7763
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
7764
            For that, please use `get_qkv_layout` to gain the layout information.
7765
7766
7767
7768
7769
7770
7771
7772
7773
7774
7775
7776
7777
7778
7779
7780
7781
7782
7783
7784
7785
7786
7787
7788
7789
7790
7791
7792
7793
7794
7795
7796
7797
7798
7799
7800
7801
7802
7803
7804

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, the QKV layers are used as Column Parallel
                      whereas the PROJ layer is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, this module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
7805
7806
7807
7808
7809
7810
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """Build the QKV projection(s), the core attention module and the output projection.

        All arguments are documented in the class docstring. Submodules are
        registered in this order: QKV projection(s), ``core_attention``, ``proj``.

        Raises
        ------
        AssertionError
            if ``attention_type`` is unsupported, ``layer_number`` is not positive,
            ``num_attention_heads`` is not divisible by the number of GQA groups, or
            the number of GQA groups is not divisible by the tensor-parallel size.
        """
        super().__init__()

        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias
        # Context parallelism is off by default; `set_context_parallel_group`
        # overwrites these two attributes.
        self.cp_size = 1
        self.cp_rank = 0

        # Head size defaults to an even split of the hidden size across heads.
        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Interleaving only applies to a single fused QKV parameter.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # When a process group is given, its world size overrides the `tp_size` argument.
        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)

        # Without an explicit GQA group count, every head has its own K/V (plain MHA).
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)

        # Projection widths summed over all query heads / all GQA groups.
        self.hidden_size_per_attention_head = kv_channels
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

        # Keyword arguments shared by every projection GEMM below.
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            # Unfused params: expose separate query/key/value weights of the given widths.
            parameters_split = None
            if not fuse_qkv_params:
                parameters_split = collections.OrderedDict(
                    [
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    ]
                )
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=parameters_split,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=parameters_split,
                    **common_gemm_kwargs,
                )
        elif self.attention_type == "cross":
            # Cross attention projects the query from `hidden_states` and K/V
            # through a separate `key_value` projection.
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query",) if not fuse_qkv_params else None,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    self.hidden_size_q,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key", "value") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            self.hidden_size_per_attention_head,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # Linear
        self.proj = Linear(
            self.hidden_size_q,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            ub_name="proj",
            **common_gemm_kwargs,
        )

    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Allocates memory for KV cache.

        Returns an uninitialized tensor of shape
        ``(inference_max_sequence_len, batch_size, num_gqa_groups_per_partition,
        hidden_size_per_attention_head)`` with the given ``dtype``, placed on the
        current CUDA device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        The group is stored on this module and is also forwarded to every
        descendant submodule that defines ``set_tensor_parallel_group``. The
        projections were built with the ``tp_group`` passed to ``__init__``, so
        without this forwarding a group supplied after construction would never
        reach them.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        # `modules()` yields `self` first; skip it to avoid infinite recursion.
        for module in self.modules():
            if module is self:
                continue
            if hasattr(module, "set_tensor_parallel_group"):
                module.set_tensor_parallel_group(tp_group)

8043
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given
        module and all of its submodules before executing the forward pass.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup], None]
                  context parallel process group.
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
                  `None` disables context parallelism (cp_size = 1, cp_rank = 0).
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-gpu communication type for context parallelism.
                      Can be "p2p", "all_gather", "a2a", or "a2a+p2p".
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
        """
        if cp_group is None:
            # No context parallelism. Reset explicitly so that a previously configured
            # CP size/rank is not reused by RoPE in `forward`.
            self.cp_size = 1
            self.cp_rank = 0
        elif isinstance(cp_group, dist_group_type):
            self.cp_size = get_distributed_world_size(cp_group)
            self.cp_rank = get_distributed_rank(cp_group)
        elif isinstance(cp_group, list):
            assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
            assert (
                cp_comm_type == "a2a+p2p"
            ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
            cp_size_a2a = get_distributed_world_size(cp_group[0])
            cp_rank_a2a = get_distributed_rank(cp_group[0])
            cp_size_p2p = get_distributed_world_size(cp_group[1])
            cp_rank_p2p = get_distributed_rank(cp_group[1])
            self.cp_size = cp_size_a2a * cp_size_p2p
            # The p2p rank selects the outer sub-group and the a2a rank the position within it.
            self.cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a

        # Deep iterate but skip self (index 0 of `modules()`) to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        fast_zero_fill: bool = True,
    ) -> Union[torch.Tensor, Tuple[Union[torch.Tensor, None], ...]]:
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                       'padding_causal_bottom_right','arbitrary'},
                       default = `None`
                       type of attention mask passed into softmax operation. By default,
                       causal masks are aligned to the top left corner of the softmax matrix.
                       When "`bottom_right`" is specified in the mask type, causal masks are
                       aligned to the bottom right corner. Falls back to `self.attn_mask_type`
                       when `None`.
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention. Falls back to `self.window_size`
                    when `None`.
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        inference_params: Optional[InferenceParams], default = `None`
                         KV-cache state for inference. When given (and `self.layer_number`
                         is set), KV-cache buffers for this layer are allocated in
                         `inference_params.key_value_memory_dict` on first use, RoPE
                         embeddings are sliced starting at `sequence_len_offset`, and the
                         object is forwarded to core attention. Not supported with
                         `qkv_format="thd"`.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.

        Returns
        -------
        The attention output projection. If `self.return_bias` is set, the projection
        bias is appended; if `self.input_layernorm` and `self.return_layernorm_output`
        are set, the layernorm output is appended. A tuple is returned when there is
        more than one item, otherwise the single output tensor.
        """
        # hidden_states: [sq, b, h]

        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
        if window_size is None:
            window_size = self.window_size
        window_size = check_set_window_size(attn_mask_type, window_size)

        if "padding" in attn_mask_type and attention_mask is not None:
            # Self-attention passes a single tensor, cross-attention a tuple. Normalize so
            # each mask tensor is checked once, instead of iterating over the batch
            # dimension of a single tensor.
            masks = (
                attention_mask if isinstance(attention_mask, (tuple, list)) else (attention_mask,)
            )
            for mask in masks:
                assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"

        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

        # =================================================
        # Pre-allocate memory for key-values for inference
        # =================================================

        if inference_params is not None and self.layer_number is not None:
            assert (
                self.qkv_format != "thd"
            ), "qkv_format == thd is not supported for an inference with KV-cache!"
            # Buffers are allocated once per layer and reused on later calls.
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
                )
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )

        # ======================
        # Query, Key, and Value
        # ======================

        fp8_mha = (
            FP8GlobalStateManager.is_fp8_enabled()
            and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
        )
        # FP8 outputs from the projection layers are only requested when RoPE is not
        # applied, because the RoPE path below rejects Float8Tensor inputs.
        fp8_output = fp8_mha and rotary_pos_emb is None

        layernorm_output = None
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )

            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
            if self.qkv_weight_interleaved:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along third last dimension
                split_dim = -3

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, ng, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
            )

            if self.qkv_format == "thd":
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
            else:
                # query: -> [sq, b, np, hn]
                # key, value: -> [sq, b, ng, hn]
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
                fp8_output=fp8_output,
            )

            if self.qkv_weight_interleaved:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, ng, 2 * hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, 2 * ng, hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    2 * self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            key_layer, value_layer = _SplitAlongDim.apply(
                mixed_kv_layer,
                split_dim,
                mixed_kv_layer.shape[split_dim] // 2,
            )
            key_layer, value_layer = (
                x.reshape(
                    x.size(0),
                    x.size(1),
                    -1,
                    self.hidden_size_per_attention_head,
                )
                for x in (key_layer, value_layer)
            )

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================

        if rotary_pos_emb is not None:
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
            # duplicate the pos_emb for self attention
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = (rotary_pos_emb,) * 2

            q_pos_emb, k_pos_emb = rotary_pos_emb

            # With a KV cache, only the positions of the current step are needed:
            # slice the embeddings starting at the cache's sequence offset.
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)
                else:
                    raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.")

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

            query_layer = apply_rotary_pos_emb(
                query_layer,
                q_pos_emb,
                self.qkv_format,
                fused=True,
                cu_seqlens=cu_seqlens_q,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
            )
            key_layer = apply_rotary_pos_emb(
                key_layer,
                k_pos_emb,
                self.qkv_format,
                fused=True,
                cu_seqlens=cu_seqlens_kv,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
            )

        # ===========================
        # Core attention computation
        # ===========================

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            qkv_format=self.qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
            window_size=window_size,
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            alibi_slopes=alibi_slopes,
            fast_zero_fill=fast_zero_fill,
            inference_params=inference_params,
        )

        # ===================
        # Output. [sq, b, h]
        # ===================
        projection_output = self.proj(
            context_layer,
            is_first_microbatch=is_first_microbatch,
            fp8_grad=isinstance(context_layer, QuantizedTensor),
        )

        if self.return_bias:
            # The projection returns (output, bias) in this mode.
            attention_output, attention_bias = projection_output
            outputs = (attention_output, attention_bias)
        else:
            outputs = (projection_output,)

        if self.input_layernorm and self.return_layernorm_output:
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]