attention.py 387 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
from importlib.metadata import PackageNotFoundError
10
import math
11
import os
12
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
import warnings
14
import logging
15
import functools
16

17
from dataclasses import dataclass, fields
cyanguwa's avatar
cyanguwa committed
18
import numpy as np
19
from packaging.version import Version as PkgVersion
20
21

import torch
22
import torch.nn.functional as F
23

24
import transformer_engine_torch as tex
25
import transformer_engine as te
26
27
28
29
30
from transformer_engine.pytorch.utils import (
    get_cudnn_version,
    nvtx_range_pop,
    nvtx_range_push,
)
31
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
32
33
    fused_attn_fwd,
    fused_attn_bwd,
34
35
36
37
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
38
39
40
41
42
43
44
45
46
47
48
49
50
    META_QKV,
    META_DQKV,
    META_O,
    META_DO,
    META_S,
    META_DP,
    META_O_CP,
    META_DQKV_CP,
)
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    get_fp8_te_dtype,
    get_fp8_torch_dtype,
51
)
52
from transformer_engine.pytorch.float8_tensor import Float8Tensor
53
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
54
from transformer_engine.pytorch.module import LayerNormLinear, Linear
55
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
56
57
58
59
60
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
61
    get_default_init_method,
62
63
64
65
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
66
    AttnBiasTypes,
67
    QKVLayouts,
68
    dist_group_type,
69
    TE_DType,
70
71
72
73
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
74
    get_distributed_rank,
75
    checkpoint,
76
77
78
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
79
80
    gather_along_first_dim,
    reduce_scatter_along_first_dim,
81
)
82
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
83
from transformer_engine.pytorch.graph import is_graph_capturing
84
85
86
87
88
from transformer_engine.pytorch.tensor.quantized_tensor import (
    QuantizedTensor,
    prepare_for_saving,
    restore_from_saved,
)
89

90

91
92
93
94
95
96
97
98
99
100
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
# Translate NVTE_DEBUG * NVTE_DEBUG_LEVEL into a `logging` level. The product is 0
# (WARNING) unless debug mode is on; any value outside {0, 1, 2} maps to DEBUG.
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels.get(_NVTE_DEBUG * _NVTE_DEBUG_LEVEL, logging.DEBUG)

# Console handler shared by the loggers configured in this module.
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

# Logger for flash-attn detection messages. The shared handler is attached only
# when no handler is already reachable from this logger.
fa_logger = logging.getLogger(__name__)
fa_logger.setLevel(_log_level)
if not fa_logger.hasHandlers():
    fa_logger.addHandler(_stream_handler)


@functools.lru_cache(maxsize=None)
def _get_supported_versions(version_min, version_max):
    """Format an inclusive version range, e.g. ``">= 2.1.1, <= 2.7.4.post1"``.

    Both bounds are rendered with ``str()``; results are memoized via ``lru_cache``,
    so the arguments must be hashable.
    """
    lower, upper = str(version_min), str(version_max)
    return f">= {lower}, <= {upper}"

111

112
113
114
# Global on/off switches for the three attention backends ("1", the default, enables).
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))

# Detect flash-attn v2 in the environment
_flash_attn_is_installed = False
_flash_attn_version = PkgVersion("0")
# Supported (inclusive) version windows: [required, max] below compute capability 10.0,
# [required_blackwell, max] at compute capability 10.0 and above.
_flash_attn_version_required = PkgVersion("2.1.1")
_flash_attn_version_required_blackwell = PkgVersion("2.7.3")
_flash_attn_max_version = PkgVersion("2.7.4.post1")
# Version feature flags; only updated when a supported flash-attn v2 is found.
_flash_attn_2_plus = False
_flash_attn_2_1_plus = False
_flash_attn_2_3_plus = False
_flash_attn_2_4_plus = False
_flash_attn_2_4_1_plus = False
_flash_attn_2_5_7_plus = False
_flash_attn_2_6_0_plus = False
_flash_attn_2_7_0_plus = False

# flash-attn v2 entry points; they stay None unless a supported version is installed.
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
_flash_attn_fwd = None
_flash_attn_bwd = None
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None

try:
    _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
    if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
        fa_logger.debug(
            "flash-attn v2 is not installed. To use, please install it by"
            """ "pip3 install flash-attn".""",
        )
else:
    # Devices with compute capability >= 10.0 need a newer minimum version.
    if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
        if _flash_attn_version_required_blackwell <= _flash_attn_version <= _flash_attn_max_version:
            _flash_attn_is_installed = True
    elif _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
        _flash_attn_is_installed = True

    if _flash_attn_is_installed:
        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
        from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
        from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_forward as _flash_attn_varlen_fwd,
        )
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_backward as _flash_attn_varlen_bwd,
        )

        _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
        _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
        _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
        _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
        _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
        _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
        _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
        _flash_attn_2_7_0_plus = _flash_attn_version >= PkgVersion("2.7.0")
    elif (
        torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN
    ):
        # Installed but outside the supported window for this device: report the window.
        fa_logger.warning(
            "Supported flash-attn versions are %s. Found flash-attn %s.",
            _get_supported_versions(
                (
                    _flash_attn_version_required
                    if get_device_compute_capability() < (10, 0)
                    else _flash_attn_version_required_blackwell
                ),
                _flash_attn_max_version,
            ),
            _flash_attn_version,
        )

# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
_flash_attn_3_is_installed = False
_flash_attn_3_version = PkgVersion("0")
_flash_attn_3_0_0_beta = False
_use_flash_attn_3 = False
# TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved
# https://github.com/Dao-AILab/flash-attention/issues/1452
_flash_attn_3_installation_steps = """\
(1) pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python3 -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py"""
try:
    _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
except PackageNotFoundError:
    if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN:
        fa_logger.debug(
            "flash-attn v3 is not installed. To use, please install it by \n%s",
            _flash_attn_3_installation_steps,
        )
else:
    # Package metadata alone does not guarantee the interface module exists: step (4)
    # above copies flash_attn_interface.py in by hand. If it is missing, treat FA3 as
    # not installed instead of failing the import of this whole module.
    try:
        from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
        from flashattn_hopper.flash_attn_interface import (
            flash_attn_varlen_func as flash_attn_varlen_func_v3,
        )
        from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
        from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
        from flashattn_hopper.flash_attn_interface import (
            _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3,
        )
        from flashattn_hopper.flash_attn_interface import (
            _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
        )
    except ImportError as _fa3_import_error:
        fa_logger.warning(
            "flashattn-hopper %s is installed but its interface could not be imported (%s). "
            "To use flash-attn v3, please install it by \n%s",
            _flash_attn_3_version,
            _fa3_import_error,
            _flash_attn_3_installation_steps,
        )
    else:
        _flash_attn_3_is_installed = True
        # "3.0.0b" normalizes to 3.0.0b0, so this matches versions strictly between
        # 3.0.0b0 and the final 3.0.0 release.
        _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0")
        _use_flash_attn_3 = True
229

230
231
232
233
234
235
236
# Module-level record of an attention-backend selection. Its keys mirror the input
# (`attention_params`) and the outputs of `get_attention_backend`.
# NOTE(review): presumably "backend_selection_requires_update" is set by callers to
# force re-selection — confirm against the code that reads this dict.
_attention_backends = {
    "attention_params": None,
    "use_flash_attention": None,
    "use_fused_attention": None,
    "fused_attention_backend": None,
    "use_unfused_attention": None,
    "backend_selection_requires_update": False,
}
238
239


240
241
@dataclass(eq=True)
class AttentionParams:
    """
    Attention parameters used to determine which backend to be used.

    Parameters
    ----------
    qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
        Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
    qkv_dtype: torch.dtype, default = `torch.bfloat16`
        Data type of query/key/value tensors.
    qkv_layout: str, default = "sbh3d"
        Query/key/value tensor memory layout.
    batch_size: int, default = 1
        Batch size.
    num_heads: int, default = 16
        Number of attention heads in the query tensor.
    num_gqa_groups: int, default = 16
        Number of attention heads in key and value tensors.
    max_seqlen_q: int, default = 128
        Maximum sequence length of the query tensor.
    max_seqlen_kv: int, default = 128
        Maximum sequence length of the key and value tensors.
    head_dim_qk: int, default = 64
        The size of each attention head in query and key tensors.
    head_dim_v: int, default = 64
        The size of each attention head in the value tensor.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
        `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window attention size.
    alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
        Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
    core_attention_bias_type: str, default = `no_bias`
        Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
    core_attention_bias_shape: str, default = `1hss`
        Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
    core_attention_bias_requires_grad: bool, default = `True`
        Whether attention bias requires gradient.
    pad_between_seqs: bool, default = `False`
        Whether there is padding between sequences in a batch.
        This only applies to `qkv_format=thd`.
    attention_dropout: float, default = 0.0
        Attention dropout.
    context_parallel: bool, default = `False`
        Whether context parallelism is used or not.
    deterministic: bool, default = `False`
        Whether to run `DotProductAttention` with determinism or not.
    is_training: bool, default = `True`
        Whether in training mode (`True`) or inference mode (`False`)
    fp8: bool, default = `False`
        Whether `DotProductAttention` is in an `fp8_autocast` region.
    fp8_meta: Optional[Dict[str, Any]], default = `None`
        The FP8 metadata of `DotProductAttention`. Only its "recipe" entry takes part
        in equality comparisons; `None` compares like a dict without a "recipe".
    """

    qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
    qkv_dtype: torch.dtype = torch.bfloat16
    qkv_layout: str = "sbh3d"
    batch_size: int = 1
    num_heads: int = 16
    num_gqa_groups: int = 16
    max_seqlen_q: int = 128
    max_seqlen_kv: int = 128
    head_dim_qk: int = 64
    head_dim_v: int = 64
    attn_mask_type: str = "no_mask"
    window_size: Union[Tuple[int, int], None] = None
    alibi_slopes_shape: Union[torch.Size, List, None] = None
    core_attention_bias_type: str = "no_bias"
    core_attention_bias_shape: str = "1hss"
    core_attention_bias_requires_grad: bool = True
    pad_between_seqs: bool = False
    attention_dropout: float = 0.0
    context_parallel: bool = False
    deterministic: bool = False
    is_training: bool = True
    fp8: bool = False
    fp8_meta: Union[Dict[str, Any], None] = None

    # `dataclass(eq=True)` leaves an `__eq__` defined in the class body in place,
    # so this method is the one used for `==` / `!=`.
    def __eq__(self, other):
        """
        Overwrite dataclass.__eq__ so that only fp8_meta["recipe"] is compared,
        since all other entries of fp8_meta are unused in get_attention_backend.

        An absent fp8_meta (`None`, the field default) is compared as if it had no
        "recipe" entry, rather than raising `AttributeError` on `None.get`.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        for field in fields(self):
            fname = field.name
            sf = getattr(self, fname)
            of = getattr(other, fname)
            if fname != "fp8_meta":
                if sf != of:
                    return False
            elif (sf or {}).get("recipe", None) != (of or {}).get("recipe", None):
                return False
        return True

339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354

# Module-level ALiBi state. NOTE(review): the "*_require_update" flags presumably mark
# cached slopes/bias for recomputation — confirm against the ALiBi helper that reads them.
_alibi_cache = dict(
    _num_heads=None,
    _alibi_slopes=None,
    _max_seqlen_q=None,
    _max_seqlen_kv=None,
    _bottom_right_alignment=True,
    _alibi_bias=None,
    _alibi_slopes_require_update=False,
    _alibi_bias_require_update=False,
)


# Public API re-exported by `from ... import *`.
__all__ = [
    "DotProductAttention",
    "InferenceParams",
    "MultiheadAttention",
]


355
356
357
358
359
def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    """Ensure the innermost dimension of ``tensor`` has unit stride.

    Returns ``tensor`` unchanged when ``tensor.stride(-1) == 1``; otherwise returns
    ``tensor.contiguous()``.
    """
    if tensor.stride(-1) == 1:
        return tensor
    return tensor.contiguous()


360
361
362
363
364
365
366
367
368
def get_attention_backend(
    attention_params: AttentionParams = None,
):
    """
    Select the appropriate attention backend/sub-backend based on user input and runtime environment.

    Parameters
    ----------
    See `AttentionParams`.
369
370
371
372
373
374
375

    Returns
    ----------
    use_flash_attention: bool
        Whether the `FlashAttention` backend has been selected.
    use_fused_attention: bool
        Whether the `FusedAttention` backend has been selected.
376
377
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend
        If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
378
379
380
381
382
383
    use_unfused_attention: bool
        Whether the `UnfusedDotProductAttention` backend has been selected.
    available_backends: List[bool]
        All available backends that could support the provided input. A list of Booleans
        in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
    """
384
385
386
387
388
389
390
391
    qkv_type = attention_params.qkv_type
    qkv_dtype = attention_params.qkv_dtype
    qkv_layout = attention_params.qkv_layout
    batch_size = attention_params.batch_size
    num_heads = attention_params.num_heads
    num_gqa_groups = attention_params.num_gqa_groups
    max_seqlen_q = attention_params.max_seqlen_q
    max_seqlen_kv = attention_params.max_seqlen_kv
392
393
    head_dim_qk = attention_params.head_dim_qk
    head_dim_v = attention_params.head_dim_v
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
    attn_mask_type = attention_params.attn_mask_type
    window_size = attention_params.window_size
    alibi_slopes_shape = attention_params.alibi_slopes_shape
    core_attention_bias_type = attention_params.core_attention_bias_type
    core_attention_bias_shape = attention_params.core_attention_bias_shape
    core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
    pad_between_seqs = attention_params.pad_between_seqs
    attention_dropout = attention_params.attention_dropout
    context_parallel = attention_params.context_parallel
    deterministic = attention_params.deterministic
    is_training = attention_params.is_training
    fp8 = attention_params.fp8
    fp8_meta = attention_params.fp8_meta

    # Run config
409
    logger = logging.getLogger("DotProductAttention")
410
411
412
    logger.setLevel(_log_level)
    if not logger.hasHandlers():
        logger.addHandler(_stream_handler)
413
414
415
416
417
    device_compute_capability = get_device_compute_capability()
    cudnn_version = get_cudnn_version()
    run_config = {
        "transformer_engine_version": te.__version__,
        "compute_capability": "sm"
418
        + str(10 * device_compute_capability[0] + device_compute_capability[1]),
419
420
421
422
423
424
        "flash_attn_version": (
            str(_flash_attn_version) if _flash_attn_is_installed else "not installed"
        ),
        "flash_attn_3_version": (
            str(_flash_attn_3_version) if _flash_attn_3_is_installed else "not installed"
        ),
425
426
427
428
429
430
431
432
433
        "cudnn_version": ".".join([str(i) for i in cudnn_version]),
    }
    attention_params_dict = {
        field.name: getattr(attention_params, field.name) for field in fields(attention_params)
    }
    run_config.update(attention_params_dict)
    if fp8:
        run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
    logger.debug("Running with config=%s", run_config)
434

435
436
437
438
439
440
    # The following sections check if `FlashAttention` supports the provided attention params,
    # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
    # necessary for performance/functionality, a warning will be issued to prompt users to
    # install an appropriate FA version.
    global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3

441
    # Filter: Environment variables
442
443
444
445
    use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
    use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
    use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
    if not use_flash_attention and _flash_attn_is_installed:
446
447
448
449
450
451
452
453
        logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
    if not use_fused_attention:
        logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
    if not use_unfused_attention:
        logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")

    # Filter: Compute capability
    if device_compute_capability < (8, 0):
454
        if use_flash_attention and _flash_attn_is_installed:
455
            logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
456
        use_flash_attention = False
457
458
459
        if use_fused_attention:
            logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
            use_fused_attention = False
460
    if device_compute_capability < (9, 0):
461
        if use_flash_attention and _flash_attn_3_is_installed:
462
            logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
463
        _use_flash_attn_3 = False
464
465

    # Filter: Data type
466
467
468
469
    if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
        torch.Tensor,
        Float8Tensor,
    ]:
470
        if use_flash_attention and _flash_attn_is_installed:
471
472
473
474
475
476
            logger.debug(
                "Disabling FlashAttention due to unsupported QKV data type. "
                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                "Found: qkv_dtype = %s.",
                qkv_dtype,
            )
477
        use_flash_attention = False
478
479
480
481
482
483
484
485
        if use_fused_attention:
            logger.debug(
                "Disabling FusedAttention due to unsupported QKV data type. "
                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                "Found: qkv_dtype = %s.",
                qkv_dtype,
            )
            use_fused_attention = False
486
487
488

    # Filter: Execution type
    if fp8 and fp8_meta["recipe"].fp8_dpa:
489
        if use_flash_attention and not _use_flash_attn_3:
490
491
            if _flash_attn_is_installed:
                logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
492
493
494
495
496
            use_flash_attention = False
        if use_flash_attention and _use_flash_attn_3 and is_training:
            logger.debug(
                "Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
            )
497
498
499
500
501
502
            use_flash_attention = False
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
            use_unfused_attention = False

    # Filter: Head dimension
503
    if use_flash_attention and head_dim_qk != head_dim_v:
504
505
        if _flash_attn_is_installed:
            logger.debug("Disabling FlashAttention as it does not support MLA.")
506
        use_flash_attention = False
507
    if use_flash_attention and (
508
509
        head_dim_qk > 256
        or head_dim_qk % 8 != 0
510
511
512
513
        or (
            head_dim_qk > 192
            and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
        )
514
    ):
515
516
517
518
        if _flash_attn_is_installed:
            logger.debug(
                "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
                "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
519
                "head_dim_qk <= 256 (>192 requires sm80/90/100+). "
520
521
522
523
524
                "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
                head_dim_qk,
                head_dim_v,
                ".".join([str(i) for i in device_compute_capability]),
            )
525
        use_flash_attention = False
526
527
528
529
530
531
532
    qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
    if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
        logger.debug(
            "Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
            qkv_layout,
        )
        use_fused_attention = False
533
534
535
536
537
538
539
540

    # Filter: QKV layout
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
            use_unfused_attention = False
        if use_flash_attention and pad_between_seqs:
541
542
543
544
545
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention for qkv_format = thd when there is "
                    "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
                )
546
547
            use_flash_attention = False

548
    # Filter: Dropout
549
550
551
    if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
        logger.debug("Disabling FlashAttention 3 for dropout")
        _use_flash_attn_3 = False
552

553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
    # Filter: Context parallelism
    # qkv_format | attn_mask_type              | attn_bias_type           | supported backends
    # ----------------------------------------------------------------------------------------------------
    # bshd, sbhd | self-attention:             | no_bias, post_scale_bias | FlashAttention, FusedAttention
    #            |     no_mask, causal         |                          |
    #            | cross-attention:            |                          |
    #            |     no_mask                 |                          |
    # thd        | self-attention:             | no_bias                  | FlashAttention, FusedAttention
    #            |     padding, padding_causal |                          | if no padding between sequences,
    #            | cross-attention:            |                          | FusedAttention
    #            |     padding                 |                          | if there is padding between sequences
    # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v.
    if context_parallel and use_unfused_attention:
        logger.debug(
            "Disabling UnfusedDotProductAttention as it does not support context parallelism"
        )
        use_unfused_attention = False
    if context_parallel and use_flash_attention:
571
        if fp8 and fp8_meta["recipe"].fp8_dpa:
572
573
574
575
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with FP8"
                )
576
            use_flash_attention = False
577
        if "bottom_right" in attn_mask_type:
578
579
580
581
582
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " causal_bottom_right masking"
                )
583
584
            use_flash_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
585
586
587
588
589
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " causal masking for cross-attention"
                )
590
591
            use_flash_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
592
593
594
595
596
597
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with bias"
                    " type of %s",
                    core_attention_bias_type,
                )
598
599
            use_flash_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
600
601
602
603
604
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " attention bias for THD format"
                )
605
            use_flash_attention = False
606

607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
    if context_parallel and use_fused_attention:
        if "bottom_right" in attn_mask_type:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with"
                " causal_bottom_right masking"
            )
            use_fused_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with causal"
                " masking for cross-attention"
            )
            use_fused_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with bias type"
                " of %s",
                core_attention_bias_type,
            )
            use_fused_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with attention"
                " bias for THD format"
            )
            use_fused_attention = False
        elif head_dim_qk != head_dim_v:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with MLA"
            )
            use_fused_attention = False

639
    # Filter: Attention mask
640
641
642
643
644
645
646
647
648
649
650
651
652
653
    # attn_mask_type              | attention_mask                       | supported backends
    # ----------------------------------------------------------------------------------------
    # no_mask                     | None                                 | All
    # padding                     |                                      | All
    #     self-attention          | One tensor in shape [b, 1, 1, sq]    |
    #     cross-attention         | Tuple of two tensors in shapes       |
    #                             | [b, 1, 1, sq] and [b, 1, 1, skv]     |
    # causal                      | None                                 |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # padding_causal              | Same as "padding"                    |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # causal_bottom_right         | None                                 | All
654
    # padding_causal_bottom_right | Same as "padding"                    | All
655
656
    # arbitrary                   | One tensor in shape broadcastable to | UnfusedDotProductAttention
    #                             | [b, h, sq, skv]                      |
657
    if attn_mask_type == "arbitrary":
658
        if use_flash_attention and _flash_attn_is_installed:
659
660
661
662
663
            logger.debug("Disabling FlashAttention for arbitrary mask")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention for arbitrary mask")
        use_fused_attention = False
664
665
    if (
        use_flash_attention
666
        and _use_flash_attn_3
667
668
669
670
671
672
673
674
675
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
            "causal mask since flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        _use_flash_attn_3 = False
676
677
678
679
680
    if (
        use_flash_attention
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
681
682
683
684
685
686
687
688
689
        if _flash_attn_2_1_plus:
            logger.warning(
                "Disabling FlashAttention as it only supports bottom-right-diagonal "
                "causal mask since flash-attn 2.1. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False
        if not _flash_attn_is_installed:
            _flash_attn_max_version = PkgVersion("2.1")
690
691
692
693
694
    if (
        use_flash_attention
        and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
        and max_seqlen_q != max_seqlen_kv
    ):
695
696
697
698
699
700
701
702
703
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.1")
        elif not _flash_attn_2_1_plus and not _use_flash_attn_3:
            logger.warning(
                "Disabling FlashAttention as it only supports top-left-diagonal "
                "causal mask before flash-attn 2.1. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False
704
705
706
707
708
709
710
711
712
    if (
        use_flash_attention
        and _use_flash_attn_3
        and fp8
        and fp8_meta["recipe"].fp8_dpa
        and "padding" in attn_mask_type
    ):
        logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
        _use_flash_attn_3 = False
713
714

    # Filter: Sliding window attention
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
    #    backend                 |      window_size       | diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | (-1, -1) or (>=0, >=0) | bottom right
    # FusedAttention             | (-1,  0) or (>=0, 0)   | top left
    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
    #                            |                        | converts window_size to an 'arbitrary' mask
    if window_size is None:
        window_size = check_set_window_size(attn_mask_type, window_size)
    else:
        if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention"
                    " for FP8"
                )
                use_fused_attention = False
731
            elif window_size[1] != 0 or attention_dropout != 0.0:
732
733
                logger.debug(
                    "Disabling FusedAttention as it only supports sliding window attention "
734
                    "with (left, 0) and no dropout"
735
736
                )
                use_fused_attention = False
737
            elif max_seqlen_q > max_seqlen_kv:
738
739
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
740
                    "with s_q > s_kv for cross-attention"
741
742
                )
                use_fused_attention = False
743
744
745
746
747
748
749
750
751
752
753
754
755
        if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if _use_flash_attn_3:
                logger.debug(
                    "Disabling FlashAttention 3 as it does not support sliding window attention"
                )
                _use_flash_attn_3 = False
            if not _flash_attn_is_installed:
                _flash_attn_version_required = PkgVersion("2.3")
            elif not _flash_attn_2_3_plus:
                logger.debug(
                    "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
                )
                use_flash_attention = False
756
757

    # Filter: Attention bias
758
759
760
761
762
763
764
765
    #    backend                 |      bias types              | ALiBi diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | no_bias, alibi/alibi_slopes  | bottom right
    # FusedAttention             | no_bias, post_scale_bias     |
    #                            | alibi/alibi_slopes           | top left,
    #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
    # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
    #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
766
    if use_flash_attention and core_attention_bias_type == "alibi":
767
        if _use_flash_attn_3:
768
769
            logger.debug("Disabling FlashAttention 3 for ALiBi")
            _use_flash_attn_3 = False
770
771
772
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.4")
        elif not _flash_attn_2_4_plus:
773
774
            logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
            use_flash_attention = False
775

776
777
778
779
    if use_flash_attention and (
        core_attention_bias_type not in ["no_bias", "alibi"]
        or core_attention_bias_shape is not None
    ):
780
781
        if _flash_attn_is_installed:
            logger.debug("Disabling FlashAttention for pre/post_scale_bias")
782
783
784
785
786
787
788
789
        use_flash_attention = False

    fu_core_attention_bias_type = core_attention_bias_type
    fu_core_attention_bias_shape = core_attention_bias_shape
    fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
    if (
        use_fused_attention
        and core_attention_bias_type == "alibi"
790
        and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
791
792
793
    ):
        fu_core_attention_bias_type = "post_scale_bias"
        fu_core_attention_bias_requires_grad = False
794
795
796
797
798
        if alibi_slopes_shape is None:
            fu_core_attention_bias_shape = "1hss"
        elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
            fu_core_attention_bias_shape = "1hss"
        elif (
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
            len(alibi_slopes_shape) == 2
            and alibi_slopes_shape[0] == batch_size
            and alibi_slopes_shape[1] == num_heads
        ):
            fu_core_attention_bias_shape = "bhss"

    if (
        use_fused_attention
        and fu_core_attention_bias_type == "post_scale_bias"
        and fu_core_attention_bias_shape != "1hss"
    ):
        if fu_core_attention_bias_requires_grad:
            # remove this line when cuDNN adds bwd support for
            # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
            logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
            use_fused_attention = False
        else:
            # max512 backend will only support [1, h, s, s]
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

    # Filter: cuDNN support
    fused_attention_backend = None
    if use_fused_attention:
        q_type = TE_DType[qkv_dtype]
        kv_type = q_type
        if fp8 and fp8_meta["recipe"].fp8_dpa:
            q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            kv_type = q_type
        fused_attention_backend = tex.get_fused_attn_backend(
            q_type,
            kv_type,
            QKVLayout[qkv_layout],
            AttnBiasType[fu_core_attention_bias_type],
            AttnMaskType[attn_mask_type],
            attention_dropout,
            num_heads,
            num_gqa_groups,
            max_seqlen_q,
            max_seqlen_kv,
838
839
            head_dim_qk,
            head_dim_v,
840
841
            window_size[0],
            window_size[1],
842
        )
843
        if fused_attention_backend == FusedAttnBackend["No_Backend"]:
844
845
            logger.debug("Disabling FusedAttention as no backend supports the provided input")
            use_fused_attention = False
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
            fused_attention_backend = None
        if (
            use_fused_attention
            and window_size is not None
            and window_size[0] != -1
            and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
        ):
            logger.debug(
                "Disabling FusedAttention as only sub-backend %s does not support "
                "slidng window attention",
                int(fused_attention_backend),
            )
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
863
864
865
866
867
868
869
870
            and fu_core_attention_bias_type == "post_scale_bias"
            and fu_core_attention_bias_shape != "1hss"
        ):
            logger.debug(
                "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
                " [1, H, S, S] shape"
            )
            use_fused_attention = False
871
            fused_attention_backend = None
872
873
874
875
876
877
878
879
880
881
882
883
884

    # Filter: Determinism
    # backend                      | deterministic
    # ---------------------------------------------
    # FlashAttention               |
    #     flash-attn >=2.0, <2.4.1 | no
    #     flash-attn >=2.4.1       | yes
    # FusedAttention               |
    #     sub-backend 0            | yes
    #     sub-backend 1            | workspace optimization path and sm90+: yes;
    #                              | otherwise: no
    #     sub-backend 2            | no
    # UnfusedDotProductAttention   | yes
885
886
887
888
889
890
891
892
893
894
    if use_flash_attention and deterministic:
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.4.1")
        elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3:
            logger.warning(
                "Disabling FlashAttention as version <2.4.1 does not support deterministic "
                "execution. To use FlashAttention with deterministic behavior, "
                "please install flash-attn >= 2.4.1."
            )
            use_flash_attention = False
895
896
897
898
899
900
901
902
903
904
905
    if use_fused_attention and deterministic:
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
            and (
                device_compute_capability < (9, 0)
                or core_attention_bias_requires_grad
                or cudnn_version < (8, 9, 5)
906
            )
907
908
909
        ):
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
910
911
912

    # All available backends
    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929

    # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
    # When `FusedAttention` does not support the provided attention params, and `FlashAttention`
    # does, we recommend users to install flash-attn if not installed already.
    if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed:
        logger.warning(
            "flash-attn may provide important feature support or performance improvement."
            " Please install flash-attn %s.",
            _get_supported_versions(
                _flash_attn_version_required,
                _flash_attn_max_version,
            ),
        )
    if use_flash_attention and not _flash_attn_is_installed:
        use_flash_attention = False
        available_backends[0] = False

930
931
932
933
934
935
936
937
938
939
940
941
    logger.debug(
        "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
        " UnfusedDotProductAttention=%s}",
        bool(available_backends[0]),
        bool(available_backends[1]),
        (
            f" (sub-backend {int(fused_attention_backend)})"
            if fused_attention_backend is not None
            else ""
        ),
        bool(available_backends[2]),
    )
942
943
944
945
946
947
948

    # Select FusedAttention for performance
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    ):
949
        if device_compute_capability >= (9, 0):
950
951
952
953
954
            logger.debug(
                "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                "for performance reasons"
            )
            use_flash_attention = False
955
956
957
958
959
960
961
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["FP8"]
        and _use_flash_attn_3
    ):
        logger.debug(
962
963
            "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
            "in FP8 execution"
964
965
966
        )
        use_flash_attention = False

967
968
969
970
971
972
    # Selected backend
    if use_flash_attention:
        use_fused_attention = False
        use_unfused_attention = False
    elif use_fused_attention:
        use_unfused_attention = False
973
    selected_backend = "NoBackend"
974
975
976
977
978
979
    if use_flash_attention:
        selected_backend = "FlashAttention"
    elif use_fused_attention:
        selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
    elif use_unfused_attention:
        selected_backend = "UnfusedDotProductAttention"
980
    logger.debug("Selected backend = %s", selected_backend)
981

982
983
984
985
986
987
    global _attention_backends
    _attention_backends["use_flash_attention"] = use_flash_attention
    _attention_backends["use_fused_attention"] = use_fused_attention
    _attention_backends["fused_attention_backend"] = fused_attention_backend
    _attention_backends["use_unfused_attention"] = use_unfused_attention
    _attention_backends["backend_selection_requires_update"] = False
988
989
990
991

    return (
        use_flash_attention,
        use_fused_attention,
992
        fused_attention_backend,
993
994
995
996
997
        use_unfused_attention,
        available_backends,
    )


998
class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Offsets start at zero here; they are presumably advanced by callers during
        # incremental decoding (not visible in this class).
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # Maps a layer key (presumably the layer number) to a
        # (key_cache, value_cache) tuple of tensors whose batch dimension is dim 1.
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty.
        AssertionError
            If `len(batch_indices)` differs from the batch size (dim 1) of a cached
            key tensor.
        """
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        # Re-assigning existing keys while iterating is safe: the dict size never changes.
        for layer_number, inference_memory in self.key_value_memory_dict.items():
            inference_key_memory, inference_value_memory = inference_memory
            # make sure batch size is the same (batch is dim 1 of the cached tensors)
            assert len(batch_indices) == inference_key_memory.shape[1], (
                f"len(batch_indices)={len(batch_indices)} does not match the KV cache batch "
                f"size {inference_key_memory.shape[1]} for layer {layer_number}"
            )
            new_inference_key_memory = inference_key_memory[:, batch_indices]
            new_inference_value_memory = inference_value_memory[:, batch_indices]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )
1042

1043

1044
@torch.no_grad()
def get_full_mask(
    max_seqlen_q: int,
    max_seqlen_kv: int,
    attn_mask_type: str = "no_mask",
    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
    window_size: Tuple[int, int] = None,
    attention_type: str = "self",
    bottom_right_alignment: bool = True,
) -> Tuple[str, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
    `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
    on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::

       attn_mask_type              output shape                                 diagonal alignment
       --------------------------------------------------------------------------------------------
       no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]          follow bottom_right_alignment
       causal                      [1, 1, max_seqlen_q, max_seqlen_kv]          always top left
       causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]          always bottom right
       padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
       padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left
       padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right
       arbitrary                   same as attention_mask                       follow bottom_right_alignment

    .. note::

    For "padding_causal_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True,
    the bottom right diagonal comes from the bottom right corner of the
    [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix, i = 0,...,batch_size-1, not the
    [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4,
    max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
    [[False, False,  True, True], [False, False, False, False]],
    [[False, False, False, True], [False,  True,  True,  True]]), the returned full attention mask
    has [2, 1, 4, 4] shape and is (singleton head dimension omitted),::

      [[[False, False, False, True],
        [False, False, False, True],
        [ True,  True,  True, True],
        [ True,  True,  True, True]],
       [[False,  True,  True, True],
        [False,  True,  True, True],
        [False,  True,  True, True],
        [False,  True,  True, True]]]

    New tensors are created on the device of `attention_mask` when it is given, and on the
    current CUDA device otherwise.

    Parameters
    ----------
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
        "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        default = `None`
        Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
        for the requirements of `attention_mask` for different `attn_mask_type`s.
    window_size: Tuple[int, int], default = `None`
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
        `attn_mask_type`.
    attention_type: str, default = "self"
        Attention type, {"self", "cross"}
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
        or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
        specifies "causal" or "causal_bottom_right".

    Returns
    ----------
    attn_mask_type: str
        "arbitrary" when `window_size` is given and is neither (-1, -1) nor (-1, 0);
        otherwise, the same as input `attn_mask_type`
    attention_mask: torch.Tensor
        The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
    actual_seqlens_q: torch.Tensor
        For padding masks, the actual sequence lengths for queries, in shape [batch_size].
        For other masks, `None`.
    actual_seqlens_kv: torch.Tensor
        For padding masks, the actual sequence lengths for keys and values, in shape [batch_size].
        For other masks, `None`.
    """
    # perform basic checks
    change_type = window_size is not None and (
        window_size[0] != -1 or window_size[1] not in [-1, 0]
    )
    if window_size is None:
        window_size = (-1, -1)
    if "causal" in attn_mask_type:
        window_size = (window_size[0], 0)
    # -1 means "unbounded" on that side; replace it with a size that spans the whole matrix
    window_size = (
        max_seqlen_kv if window_size[0] == -1 else window_size[0],
        max_seqlen_q if window_size[1] == -1 else window_size[1],
    )

    # Create new tensors on the device of the provided mask(s) so they can be combined with it;
    # fall back to the current CUDA device when no mask tensor is given.
    device = "cuda"
    if isinstance(attention_mask, torch.Tensor):
        device = attention_mask.device
    elif (
        isinstance(attention_mask, (tuple, list))
        and len(attention_mask) > 0
        and isinstance(attention_mask[0], torch.Tensor)
    ):
        device = attention_mask[0].device

    # apply padding mask
    actual_seqlens_q = None
    actual_seqlens_kv = None
    if "padding" in attn_mask_type:
        # Combine the [b, 1, 1, s] row/column padding masks into one [b, 1, sq, skv] mask
        if attention_type == "self":
            attention_mask = torch.logical_or(
                attention_mask.squeeze(1).unsqueeze(3), attention_mask
            )
        else:
            attention_mask = torch.logical_or(
                attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
            )
        # Lengths are read from column 0 / row 0, so they assume token 0 of each
        # sequence is not padding.
        m = attention_mask.logical_not()
        actual_seqlens_q = m[:, 0, :, 0].sum(dim=1)
        actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1)

    # apply SWA mask; mask[..., i, j] = i - j
    mask = torch.arange(max_seqlen_q, dtype=torch.int32, device=device).view(
        1, 1, max_seqlen_q, 1
    ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device=device).view(1, 1, 1, max_seqlen_kv)
    swa_left = None
    swa_right = None
    if attn_mask_type == "causal_bottom_right" or (
        attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment
    ):
        swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
        swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
    elif attn_mask_type in ["causal", "padding_causal"] or (
        attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment
    ):
        swa_left = mask - window_size[0]
        swa_right = mask + window_size[1]
    elif attn_mask_type == "padding_causal_bottom_right" or (
        attn_mask_type == "padding" and bottom_right_alignment
    ):
        # Align to the bottom right corner of each sample's actual [seqlen_q, seqlen_kv] matrix
        batch_size = attention_mask.shape[0]
        swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
            actual_seqlens_kv - actual_seqlens_q - window_size[0]
        ).view(batch_size, 1, 1, 1)
        swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
            actual_seqlens_kv - actual_seqlens_q + window_size[1]
        ).view(batch_size, 1, 1, 1)
    # Inside the window (swa_left <= 0 and swa_right >= 0) the difference is 1 -> kept (False);
    # outside it the difference is 0 -> masked out (True).
    swa_mask = torch.logical_not(
        torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
    )
    if attention_mask is not None:
        attention_mask = torch.logical_or(swa_mask, attention_mask)
    else:
        attention_mask = swa_mask

    # change mask type
    if change_type:
        attn_mask_type = "arbitrary"

    return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv
1196
1197


1198
1199
1200
1201
1202
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    actual_seqlens_q: Optional[torch.Tensor] = None,
    actual_seqlens_kv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
    bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get ALiBi slopes and ALiBi bias, recomputing each one only when the module-level
    `_alibi_cache` flags it as requiring an update, and storing the results in that cache.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    actual_seqlens_q: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for queries, in shape [batch_size].
    actual_seqlens_kv: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for keys and values, in shape [batch_size].
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the ALiBi bias to the bottom right corner of
        the matrix (`True`) or top left (`False`).

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. Its shape is
        (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
        and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
        (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
        [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
        `actual_seqlens_q` and `actual_seqlens_kv` are not `None`.

    Raises
    ----------
    AssertionError
        If no slopes are cached when the bias is built, or if exactly one of
        `actual_seqlens_q` / `actual_seqlens_kv` is `None`. These are explicit raises rather than
        `assert` statements so the checks still run under `python -O`.
    ValueError
        If the slopes have more than 2 dimensions.
    """
    global _alibi_cache
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            # Default slopes: powers 1..n of 2^(-8/n) for the largest power of two n <= num_heads;
            # any remaining heads get the odd powers 1, 3, 5, ... of 2^(-4/n).
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        if _alibi_cache["_alibi_slopes"] is None:
            raise AssertionError("ALiBi slopes can not be None!")
        if _alibi_cache["_alibi_slopes"].dim() == 1:
            slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
        elif _alibi_cache["_alibi_slopes"].dim() == 2:
            slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
        else:
            raise ValueError("ALiBi slopes cannot exceed 2 dimensions.")

        # bias[..., i, j] = i - j (top-left aligned distance)
        bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
            1, 1, max_seqlen_q, 1
        ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
            1, 1, 1, max_seqlen_kv
        )
        if actual_seqlens_q is None and actual_seqlens_kv is None:
            if bottom_right_alignment:
                bias = bias + max_seqlen_kv - max_seqlen_q
        elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
            # Per-sample alignment to the bottom right corner of the actual-length matrix
            batch_size = actual_seqlens_q.shape[0]
            bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
            if bottom_right_alignment:
                bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
        else:
            raise AssertionError(
                "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
            )
        # ALiBi penalty: -|distance| scaled by each head's slope
        bias = bias.abs().mul(-1)
        bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
1295
1296
1297
1298
1299
1300
1301
1302
1303


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    `True` entries of `mask` are treated as padding; every `False` entry counts towards
    the length of its sample. The result is placed on the same device as `mask`.
    """
    mask = mask.squeeze(1).squeeze(1)
    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Allocate on mask's device (not a hard-coded "cuda") so the concatenation never
    # mixes devices.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    return cu_seqlens

1311

1312
1313
1314
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the flat indices of the valid (non-padded) tokens.

    ``True`` entries in ``mask`` mark padded positions. The valid-token indices come
    first, in row-major order; the remaining slots are filled with the out-of-range
    value ``batch_size * max_seqlen``. Both tensors are created on ``mask``'s device.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Follow the input's device instead of assuming "cuda".
    zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # nonzero() returns int64 indices of shape [num_valid, 1]; add a trailing
    # axis to get [num_valid, 1, 1].
    indices = mask.reshape(-1).logical_not().nonzero().unsqueeze(-1)

    # Pad dim 0 up to bs * seqlen entries using the sentinel value bs * seqlen.
    pad_amount = bs * seqlen - indices.shape[0]
    indices = F.pad(
        input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen)
    )

    return cu_seqlens, indices


1340
1341
1342
1343
1344
1345
1346
1347
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch.

    Sample ``i`` contributes the flat indices ``i * max_seqlen + j`` for
    ``0 <= j < cu_seqlens[i + 1] - cu_seqlens[i]``, in order. The remaining slots
    are filled with the out-of-range value ``batch_size * max_seqlen``. The result
    is created on ``cu_seqlens``'s device.

    The indices are computed with tensor ops in int64: no per-token Python loop
    (which would synchronize once per element on GPU inputs) and no float32
    round-trip (which would lose precision above 2**24).
    """
    device = cu_seqlens.device
    bs = len(cu_seqlens) - 1
    # Negative lengths contribute no tokens (like an empty ``range``).
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int64).clamp(min=0)
    num_nonzeros = int(seqlens.sum())

    # Offset of each token's sample in the padded [bs, max_seqlen] layout.
    sample_starts = torch.repeat_interleave(
        torch.arange(bs, dtype=torch.int64, device=device) * max_seqlen, seqlens
    )
    # Position of each token inside its own sample: its global packed position
    # minus the number of tokens in all preceding samples.
    packed_starts = torch.repeat_interleave(torch.cumsum(seqlens, dim=0) - seqlens, seqlens)
    token_offsets = torch.arange(num_nonzeros, dtype=torch.int64, device=device) - packed_starts
    indices = (sample_starts + token_offsets).view(-1, 1, 1)

    pad_amount = bs * max_seqlen - num_nonzeros
    indices = F.pad(
        input=indices,
        pad=(0, 0, 0, 0, 0, pad_amount),
        mode="constant",
        value=float(bs * max_seqlen),
    )

    return indices

1362

1363
_cu_seqlens_cache = {}


def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length, so the result is the
    int32 tensor ``[0, max_seqlen, 2 * max_seqlen, ..., batch_size * max_seqlen]``.

    Results are memoized in ``_cu_seqlens_cache``. The (normalized) device is part
    of the cache key so that a tensor built for one device is never handed back
    for a request on another device.
    """
    device = torch.device(device)
    key = (batch_size, max_seqlen, device)
    # Only the dict is mutated (never rebound), so no ``global`` is needed.
    if key not in _cu_seqlens_cache:
        _cu_seqlens_cache[key] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[key]
1386
1387


1388
@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Packs the given tensor using the `indices`.

    Rows (dim 0) of ``tensor`` are gathered in the order given by ``indices``.
    One all-zero row is appended before gathering, so an index equal to
    ``tensor.shape[0]`` selects zeros. For a ``Float8Tensor`` the gather runs on
    its raw ``_data`` and the result is re-wrapped with ``Float8Tensor.make_like``.
    """
    is_fp8 = isinstance(tensor, Float8Tensor)
    data = tensor._data if is_fp8 else tensor
    # The padding row must use the storage dtype of the data being concatenated;
    # otherwise torch.cat type-promotes. NOTE(review): for Float8Tensor,
    # `tensor.dtype` is presumably the nominal high-precision dtype rather than
    # `_data.dtype`, so `data.dtype` is used here.
    padding_row = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=data.dtype, device=tensor.device
    )
    # assumes indices is [num_rows, 1, 1]; broadcast it over the two trailing dims.
    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    gathered = torch.gather(torch.cat((data, padding_row), dim=0), 0, indices)
    if is_fp8:
        return Float8Tensor.make_like(tensor, data=gathered, shape=gathered.shape)
    return gathered


1412
@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack two tensors with one shared `indices`.

    Equivalent to applying ``pack_tensor(indices, t)`` to ``t1`` and then ``t2``.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


1426
@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pack three tensors with one shared `indices`.

    Equivalent to applying ``pack_tensor(indices, t)`` to ``t1``, ``t2`` and ``t3``
    in that order.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2), pack_tensor(indices, t3)


1442
@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`.

    Scatters the rows of ``tensor`` to the row positions given by ``indices`` in a
    zero-initialized buffer with ``dim0`` rows. The buffer is allocated with one
    extra trailing row that absorbs any index equal to ``dim0`` (presumably the
    padding sentinel produced alongside ``indices``); that row is dropped before
    returning. For a ``Float8Tensor`` the scatter runs on its raw ``_data`` and the
    result is re-wrapped with ``Float8Tensor.make_like``.
    """
    is_fp8 = isinstance(tensor, Float8Tensor)
    data = tensor._data if is_fp8 else tensor
    # assumes indices is [num_rows, 1, 1]; broadcast it over the two trailing dims.
    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    # scatter_ requires the destination dtype to match the source dtype, so the
    # buffer uses the dtype of the data actually scattered. NOTE(review): for
    # Float8Tensor, `tensor.dtype` is presumably the nominal dtype, not `_data.dtype`.
    unpacked = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=data.dtype, device=tensor.device
    )
    unpacked.scatter_(0, indices, data)
    unpacked = unpacked[0:-1, :, :]
    if is_fp8:
        return Float8Tensor.make_like(tensor, data=unpacked, shape=unpacked.shape)
    return unpacked


1465
@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`.

    Applies ``unpack_tensor(indices, dim0, t)`` to ``t1`` and then ``t2``.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


1480
@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`.

    Applies ``unpack_tensor(indices, dim0, t)`` to ``t1``, ``t2`` and ``t3`` in order.
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )


class PackTensors(torch.autograd.Function):
    """
    Autograd function that packs one to three tensors with shared ``indices``.

    Forward dispatches to ``pack_tensor`` / ``pack_2_tensors`` / ``pack_3_tensors``
    and records the leading dimension of the first input. Backward unpacks every
    incoming gradient back to that leading dimension; ``indices`` gets no gradient.
    """

    @staticmethod
    def forward(
        ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # pylint: disable=missing-function-docstring
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        ctx.save_for_backward(indices)
        # Row count before packing, needed to size the unpacked gradients.
        ctx.dim0 = tensors[0].shape[0]
        num_tensors = len(tensors)
        if num_tensors >= 3:
            packed = pack_3_tensors(indices, *tensors)
        elif num_tensors == 2:
            packed = pack_2_tensors(indices, *tensors)
        else:
            packed = pack_tensor(indices, *tensors)
        return packed

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        # pylint: disable=missing-function-docstring
        (indices,) = ctx.saved_tensors
        dim0 = ctx.dim0
        num_grads = len(grad_outputs)
        if num_grads >= 3:
            grads = unpack_3_tensors(indices, dim0, *grad_outputs)
        elif num_grads == 2:
            grads = unpack_2_tensors(indices, dim0, *grad_outputs)
        else:
            return None, unpack_tensor(indices, dim0, *grad_outputs)
        # Leading None is the (absent) gradient for ``indices``.
        return (None, *grads)
1525
1526
1527
1528
1529
1530


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function that scatters a packed tensor back to its unpacked layout.

    Forward calls ``unpack_tensor``; backward re-packs the incoming gradient with
    ``pack_tensor`` using the saved ``indices``. ``indices`` and ``dim0`` receive
    no gradient.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        ctx.save_for_backward(indices)
        unpacked = unpack_tensor(indices, dim0, tensor)
        return unpacked

    @staticmethod
    def backward(ctx, grad_output):
        # pylint: disable=missing-function-docstring
        (saved_indices,) = ctx.saved_tensors
        grad_tensor = pack_tensor(saved_indices, grad_output)
        return None, None, grad_tensor
1548
1549


1550
1551
1552
def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Posts a non-blocking send of ``send_tensor`` to ``send_dst`` and a non-blocking
    receive into ``recv_tensor`` from ``recv_src`` within ``cp_group``. Even ranks
    post the send first, odd ranks the receive first. With ``batch_p2p_comm`` the
    two operations are submitted together via ``torch.distributed.batch_isend_irecv``;
    otherwise ``isend`` / ``irecv`` are called directly in that posting order.

    Returns the list of request handles, in posting order.
    """
    dist = torch.distributed
    send_first = rank % 2 == 0

    if batch_p2p_comm:
        # P2POp objects only describe the operations; list order is the posting order.
        send_op = dist.P2POp(dist.isend, send_tensor, send_dst, cp_group)
        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_src, cp_group)
        ordered_ops = [send_op, recv_op] if send_first else [recv_op, send_op]
        return dist.batch_isend_irecv(ordered_ops)

    # Unbatched: the call order itself is the posting order.
    if send_first:
        send_req = dist.isend(send_tensor, send_dst, cp_group)
        recv_req = dist.irecv(recv_tensor, recv_src, cp_group)
        return [send_req, recv_req]
    recv_req = dist.irecv(recv_tensor, recv_src, cp_group)
    send_req = dist.isend(send_tensor, send_dst, cp_group)
    return [recv_req, send_req]


1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
@jit_fuser
def flash_attn_fwd_out_correction_init(
    out_init_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_init_step: torch.Tensor,
    seq_dim: int,
):
    """Merge partial outputs of the first step in Attention with context parallelism

    Scales ``out_init_step`` by ``exp(softmax_lse_init_step - softmax_lse)``. The
    LSE axis 2 is moved to ``seq_dim`` and a trailing singleton axis is appended so
    the factor broadcasts over the output's last axis. The product is cast back to
    ``out_init_step``'s dtype.
    """
    correction = torch.exp(softmax_lse_init_step - softmax_lse)
    correction = correction.movedim(2, seq_dim).unsqueeze(-1)
    return (out_init_step * correction).to(out_init_step.dtype)


1606
@jit_fuser
def flash_attn_fwd_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    seq_dim: int,
):
    """Merge partial outputs of each step in Attention with context parallelism

    Accumulates ``out_per_step * exp(softmax_lse_per_step - softmax_lse)`` into
    ``out`` in place (LSE axis 2 moved to ``seq_dim``, plus a trailing singleton
    axis for broadcasting). Returns nothing.
    """
    correction = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
    out.add_(out_per_step * correction.unsqueeze(-1))


1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
@jit_fuser
def flash_attn_fwd_second_half_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    seq_dim: int,
):
    """Merge second half of partial outputs of each step in Attention with context parallelism

    Like ``flash_attn_fwd_out_correction`` but restricted to index 1 of ``out`` along
    ``seq_dim`` and to the second half of ``softmax_lse``'s last axis. ``out`` is
    updated in place through a view; nothing is returned.
    """
    # In-place view of index 1 along seq_dim (assumes a size-2 "half" axis there).
    out_second_half = out.select(seq_dim, 1)
    # Split the last LSE axis into [2, -1] and keep the second half.
    lse_second_half = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)[..., 1, :]
    correction = torch.exp(softmax_lse_per_step - lse_second_half)
    correction = correction.movedim(2, seq_dim).unsqueeze(-1)
    out_second_half.add_(out_per_step * correction)


1638
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge softmax stats of each step in Attention with context parallelism

    Overwrites ``softmax_lse`` in place with
    ``log(exp(softmax_lse) + exp(softmax_lse_per_step))``, evaluated as
    ``max + log1p(exp(min - max))`` to avoid overflow.
    """
    larger = torch.max(softmax_lse, softmax_lse_per_step)
    smaller = torch.min(softmax_lse, softmax_lse_per_step)
    softmax_lse.copy_(larger + torch.log1p(torch.exp(smaller - larger)))
1648
1649


1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
@jit_fuser
def flash_attn_fwd_second_half_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge second half of softmax stats of each step in Attention with context parallelism

    Same log-sum-exp merge as ``flash_attn_fwd_softmax_lse_correction``, applied in
    place to ``softmax_lse[..., 1, :]`` (assumes ``softmax_lse`` already carries a
    size-2 half axis second from last).
    """
    lse_second_half = softmax_lse[..., 1, :]
    larger = torch.max(lse_second_half, softmax_lse_per_step)
    smaller = torch.min(lse_second_half, softmax_lse_per_step)
    lse_second_half.copy_(larger + torch.log1p(torch.exp(smaller - larger)))


1663
1664
@jit_fuser
def get_cu_seqlens_on_cp_rank(
    cu_seqlens: torch.Tensor,
    cu_seqlens_padded_on_cp_rank: torch.Tensor,
    cp_size: int,
    cp_rank: int,
    first_half: bool,
    second_half: bool,
):
    """Compute cu_seqlens of a context parallelism rank

    The rank's per-sequence padded length is split into two equal chunks: chunk
    ``cp_rank`` (first half) and chunk ``2 * cp_size - cp_rank - 1`` (second half).
    For each requested half, the number of real tokens that fall inside that chunk
    (clamped to ``[0, chunk length]``) is added per sequence, and the totals are
    turned into cumulative lengths with a leading 0.
    """
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    chunk_lens = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
    lower_bound = torch.zeros_like(seqlens)
    result = torch.zeros_like(cu_seqlens)
    if first_half:
        tokens_in_first = (seqlens - cp_rank * chunk_lens).clamp(lower_bound, chunk_lens)
        result[1:].add_(tokens_in_first)
    if second_half:
        second_chunk_id = 2 * cp_size - cp_rank - 1
        tokens_in_second = (seqlens - second_chunk_id * chunk_lens).clamp(
            lower_bound, chunk_lens
        )
        result[1:].add_(tokens_in_second)
    result.cumsum_(dim=0)
    return result


1689
@jit_fuser
def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device):
    """
    Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
    To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks to
    be contiguous before attention compute. This function is to compute sequence chunk ids for
    reordering.

    Returns an int32 tensor on ``device`` with ``chunk_ids[r] = 2 * r`` and
    ``chunk_ids[r + cp_size] = 2 * cp_size - 2 * r - 1`` for ``r`` in ``range(cp_size)``,
    i.e. ``[0, 2, ..., 2*cp_size - 2, 2*cp_size - 1, 2*cp_size - 3, ..., 1]``.
    """
    # Two aranges instead of 2 * cp_size element-wise writes into a device tensor.
    first_half = torch.arange(0, 2 * cp_size, 2, dtype=torch.int32, device=device)
    second_half = torch.arange(2 * cp_size - 1, 0, -2, dtype=torch.int32, device=device)
    return torch.cat((first_half, second_half))


1704
@jit_fuser
def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device):
    """
    Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
    We need to reorder sequence chunks back to discontiguous after attention compute. This function
    is to compute sequence chunk ids for reordering.

    Returns an int32 tensor on ``device`` with ``chunk_ids[2 * r] = r`` and
    ``chunk_ids[2 * r + 1] = 2 * cp_size - r - 1`` for ``r`` in ``range(cp_size)``,
    i.e. ``[0, 2*cp_size - 1, 1, 2*cp_size - 2, ...]``.
    """
    ranks = torch.arange(cp_size, dtype=torch.int32, device=device)
    # Interleave the (r, 2*cp_size - 1 - r) pairs in a single vectorized op instead
    # of 2 * cp_size element-wise writes into a device tensor.
    return torch.stack((ranks, 2 * cp_size - 1 - ranks), dim=1).reshape(-1)


@jit_fuser
def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
    """Reorder sequence chunk for A2A communication before attention compute.

    Moves the leading CP axis to ``seq_dim``, re-splits the merged (cp, s) axes into
    ``2 * cp_size`` chunks, and permutes those chunks with ``chunk_ids_for_a2a``.
    """
    # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
    # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
    moved = x.movedim(0, seq_dim).contiguous()
    leading_dims = moved.shape[:seq_dim]
    trailing_dims = moved.shape[(seq_dim + 2) :]
    # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
    # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
    chunked = moved.view(*leading_dims, cp_size * 2, -1, *trailing_dims)
    return torch.index_select(chunked, dim=seq_dim, index=chunk_ids_for_a2a)


@jit_fuser
def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
    """Reorder sequence chunk for A2A communication after attention compute.

    Moves the chunk axis at ``seq_dim`` to the front, permutes the chunks with
    ``chunk_ids_for_a2a``, and splits the front axis into ``[cp_size, 2]``.
    """
    # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
    # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
    chunks_first = x.movedim(seq_dim, 0).contiguous()
    reordered = torch.index_select(chunks_first, dim=0, index=chunk_ids_for_a2a)
    # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn]
    # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
    return reordered.view(cp_size, 2, *reordered.shape[1:])


def flash_attn_a2a_communicate(
    a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
    chunk_ids_for_a2a: torch.Tensor,
    seq_dim: int,
    cp_size: int,
    cp_group: dist_group_type,
    cp_stream: torch.cuda.Stream,
    before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """A2A communication for context parallelism.

    With ``before_attn=True`` each input goes from ``[b, s, np, hn]`` (or
    ``[s, b, np, hn]``) to ``[b, cp*s, np//cp, hn]`` (or ``[cp*s, b, np//cp, hn]``);
    with ``before_attn=False`` the reverse exchange is performed. The exchange uses
    ``torch.distributed.all_to_all_single`` over ``cp_group``, with sequence chunks
    permuted by ``chunk_ids_for_a2a``.

    The loop is software-pipelined over the inputs: in iteration ``i`` input ``i`` is
    reshaped on the current stream, the async A2A for input ``i - 1`` is launched,
    and the result of input ``i - 2`` is waited on and post-processed on
    ``cp_stream``. Before returning, the current stream waits on ``cp_stream``.

    Args:
        a2a_inputs: a tensor or a list of tensors. A list passed by the caller is
            not modified.
        chunk_ids_for_a2a: index tensor used to permute the sequence chunks.
        seq_dim: position of the sequence dimension ("s") in the layout.
        cp_size: CP group size; heads are split into ``cp_size`` groups and the
            sequence into ``2 * cp_size`` chunks.
        cp_group: process group used for the A2A.
        cp_stream: side stream on which A2A results are waited for and reordered.
        before_attn: direction of the exchange (see above).

    Returns:
        A single tensor when exactly one input is given, otherwise a list with one
        output per input.
    """
    # Work on a private list: entries are replaced by reshaped copies below and
    # must not leak back into a list owned by the caller.
    a2a_inputs = list(a2a_inputs) if isinstance(a2a_inputs, list) else [a2a_inputs]
    a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
    if before_attn:
        for i in range(len(a2a_inputs) + 2):
            # Launch the async A2A for the input prepared in the previous iteration.
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            # Wait for the A2A launched one iteration earlier and post-process it.
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # reorder the sequence chunks
                    x = reorder_seq_chunks_for_a2a_before_attn(
                        x, chunk_ids_for_a2a, seq_dim, cp_size
                    )
                    # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                    # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
                    a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
            # Prepare input i for the next iteration's A2A.
            if i < len(a2a_inputs):
                x = a2a_inputs[i]
                # [b, s, np, hn] -> [b, s, cp, np//cp, hn]
                # or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
                x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
                # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn]
                # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
                a2a_inputs[i] = x.movedim(-3, 0).contiguous()
    else:
        for i in range(len(a2a_inputs) + 2):
            # Launch the async A2A for the input prepared in the previous iteration.
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            # Prepare input i for the next iteration's A2A.
            if i < len(a2a_inputs):
                x = a2a_inputs[i]
                # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
                # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
                x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
                # reorder the sequence chunks
                a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn(
                    x, chunk_ids_for_a2a, seq_dim, cp_size
                )
            # Wait for the A2A launched one iteration earlier and post-process it.
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn]
                    # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
                    x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
                    # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn]
                    # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
                    a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
    # Outputs were finalized on cp_stream; make the caller's stream wait for them.
    torch.cuda.current_stream().wait_stream(cp_stream)
    return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs


1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
    """Get the list of quantizers used in attention from the quantizers list.

    Args:
        fp8: whether FP8 is enabled. If falsy, a list of ``None`` placeholders is
            returned (8 entries when ``cp_specific_quantizers`` is set, else 6).
        quantizers: mapping with ``"scaling_fwd"`` and ``"scaling_bwd"`` entries,
            each indexed by the ``META_*`` slot constants.
        cp_specific_quantizers: additionally return the context-parallel
            ``O_CP`` and ``dQKV_CP`` quantizers.

    Returns:
        ``(QKV, O, S, dQKV, dO, dP)``, or ``(QKV, O, O_CP, S, dQKV, dQKV_CP, dO, dP)``
        when ``cp_specific_quantizers`` is set. The FP8 path returns a tuple; the
        non-FP8 path returns a list of ``None``.

    Side effects:
        The quantizer objects inside ``quantizers`` are modified in place: all eight
        are set to rowwise-only usage, and ``internal = True`` is set on the QKV, S,
        dO and dQKV_CP quantizers. O and O_CP keep their existing ``internal`` value.
    """
    if not fp8:
        num_of_nones = 8 if cp_specific_quantizers else 6
        return [None] * num_of_nones
    QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
    QKV_quantizer.internal = True
    QKV_quantizer.set_usage(rowwise=True, columnwise=False)
    O_quantizer = quantizers["scaling_fwd"][META_O]
    O_quantizer.set_usage(rowwise=True, columnwise=False)
    S_quantizer = quantizers["scaling_fwd"][META_S]
    S_quantizer.internal = True
    S_quantizer.set_usage(rowwise=True, columnwise=False)
    dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
    # NOTE(review): `interal` is misspelled, so this line does NOT set `internal`; it
    # only attaches a new attribute (presumably never read). Changing it to `internal`
    # may change what this quantizer produces; confirm how dQKV outputs are consumed
    # (e.g. whether they are returned to autograd) before fixing.
    dQKV_quantizer.interal = True
    dQKV_quantizer.set_usage(rowwise=True, columnwise=False)
    dO_quantizer = quantizers["scaling_bwd"][META_DO]
    dO_quantizer.set_usage(rowwise=True, columnwise=False)
    dO_quantizer.internal = True
    dP_quantizer = quantizers["scaling_bwd"][META_DP]
    dP_quantizer.set_usage(rowwise=True, columnwise=False)
    # NOTE(review): same `interal` misspelling as above; `internal` is left unchanged.
    dP_quantizer.interal = True
    dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP]
    dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False)
    dQKV_CP_quantizer.internal = True
    O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP]
    O_CP_quantizer.set_usage(rowwise=True, columnwise=False)

    if cp_specific_quantizers:
        return (
            QKV_quantizer,
            O_quantizer,
            O_CP_quantizer,
            S_quantizer,
            dQKV_quantizer,
            dQKV_CP_quantizer,
            dO_quantizer,
            dP_quantizer,
        )

    return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer


1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
_cu_seqlens_info_with_cp_cache = {}


def _get_cu_seqlens_info_with_cp(
    batch_size: int,
    max_seqlen: int,
    cp_size: int,
    cu_seqlens: torch.Tensor,
):
    """Cumulative sequence lengths with CP being considered.

    Returns the pair ``(cu_seqlens // cp_size, cu_seqlens // (cp_size * 2))``.

    Results are memoized per ``(batch_size, max_seqlen, cp_size, cu_seqlens.device)``.
    The device is part of the key so a tensor cached for one device is never
    returned for another.

    NOTE(review): the key does not include the *values* of ``cu_seqlens``, so this
    assumes ``cu_seqlens`` is fully determined by ``(batch_size, max_seqlen)``
    (e.g. all sequences at full length). Confirm at the call sites.
    """
    key = (batch_size, max_seqlen, cp_size, cu_seqlens.device)
    # Only the dict is mutated (never rebound), so no ``global`` is needed.
    if key not in _cu_seqlens_info_with_cp_cache:
        _cu_seqlens_info_with_cp_cache[key] = (
            cu_seqlens // cp_size,
            cu_seqlens // (cp_size * 2),
        )
    return _cu_seqlens_info_with_cp_cache[key]


1876
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
1877
    """
1878
1879
1880
    Attention implementation with context parallelism. Exchange KV between CP ranks
    with P2P in ring topology. Split attention compute into multiple steps, and overlap
    current-step compute with next-step communication.
1881
1882
1883
1884
1885

    This implementation also supports hierarchical CP, which parallelizes attention
    heads in low-level CP groups and parallelizes sequence dimension in high-level CP
    groups. For more details, please refer to `LongVILA <https://arxiv.org/abs/2408.10188>`_
    and `USP <https://arxiv.org/abs/2405.07719>`_.
1886
1887
1888
    """

    @staticmethod
1889
1890
1891
1892
1893
1894
1895
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
1896
        cu_seqlens_kv,
1897
        max_seqlen_q,
1898
        max_seqlen_kv,
1899
1900
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
1901
1902
1903
1904
1905
1906
1907
1908
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
1909
1910
        fp8,
        fp8_meta,
1911
1912
1913
        cp_group,
        cp_global_ranks,
        cp_stream,
1914
        quantizers,
1915
        pad_between_seqs,
1916
    ):
1917
        # pylint: disable=missing-function-docstring
1918
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
1919
1920
1921
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
        if isinstance(cp_group, list):
            assert (
                qkv_format != "thd"
            ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!"
            assert attn_bias_type == "no_bias", (
                f"{attn_bias_type} bias type is not supported with hierarchical CP implementation"
                " yet!"
            )
            cp_group_a2a = cp_group[0]
            cp_size_a2a = get_distributed_world_size(cp_group_a2a)
            rank_a2a = get_distributed_rank(cp_group_a2a)
            cp_group = cp_group[1]
        else:
            cp_group_a2a = None
            cp_size_a2a = 1
            rank_a2a = 0

1939
1940
        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
1941
1942
        send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
1943
1944
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1945
1946
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
1947

1948
        batch_dim = None
1949
        seq_dim = None
1950
        cu_seqlens_q_half, cu_seqlens_kv_half = None, None
1951
        if qkv_format in ["bshd", "sbhd"]:
1952
            seq_dim = qkv_format.index("s")
1953
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
1954
1955
1956
1957
1958
1959
1960
1961
1962
            cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None
            if use_fused_attention:
                batch_dim = qkv_format.index("b")
                cu_seqlens_q, cu_seqlens_q_half = _get_cu_seqlens_info_with_cp(
                    q.shape[batch_dim], max_seqlen_q, cp_size, cu_seqlens_q
                )
                cu_seqlens_kv, cu_seqlens_kv_half = _get_cu_seqlens_info_with_cp(
                    q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv
                )
1963
1964
        else:
            qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
1965
1966
            cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
            cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
1967
1968
1969
1970
1971

        max_seqlen_q = max_seqlen_q // cp_size
        max_seqlen_kv = max_seqlen_kv // cp_size
        cu_seqlens_q_per_step = [None for _ in range(cp_size)]
        cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
1972

1973
        fused_attn_backend = None
1974
        qkv_dtype = q.dtype
1975
1976
1977
        amax_per_step = None
        S_quantizer_per_step = [None for _ in range(cp_size)]
        O_CP_quantizer_per_step = [None for _ in range(cp_size)]
1978
1979
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
        is_input_fp8 = False
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
        is_output_fp8 = False

        (
            QKV_quantizer,
            O_quantizer,
            O_CP_quantizer,
            S_quantizer,
            dQKV_quantizer,
            dQKV_CP_quantizer,
            dO_quantizer,
            dP_quantizer,
        ) = get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True)

1993
1994
1995
        if fp8:
            if use_fused_attention:
                fused_attn_backend = FusedAttnBackend["FP8"]
1996

1997
1998
1999
2000
                assert isinstance(k, q.__class__) and isinstance(
                    v, q.__class__
                ), "q, k, and v must have the same type."
                is_input_fp8 = isinstance(q, Float8Tensor)
2001
2002
2003
2004
2005
                is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
                if is_input_fp8:
                    QKV_quantizer = q._quantizer
                    q, k, v = q._data, k._data, v._data
                else:
2006
2007
                    q_f16, k_f16, v_f16 = q, k, v
                    if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
2008
                        q = QKV_quantizer(q_f16)._data
2009
                    if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
2010
2011
2012
2013
2014
2015
2016
2017
                        k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]]
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
                # partial result quantizer
                for i in range(cp_size):
                    S_quantizer_per_step[i] = S_quantizer.copy()
                    S_quantizer_per_step[i].amax = amax_per_step[0][i]
                    O_CP_quantizer_per_step[i] = O_CP_quantizer.copy()
                    O_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
2018
2019
2020
2021
2022
2023
2024
2025
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            q_f16 = q
            if use_fused_attention:
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if cp_size_a2a > 1:
2026
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device)
2027

2028
2029
2030
2031
2032
            q, k, v = flash_attn_a2a_communicate(
                [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
            )
            if not fp8:
                q_f16 = q
2033
            elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
2034
                q_f16 = q
2035
                q = QKV_quantizer(q_f16)._data
2036

2037
2038
2039
        assert qkv_format == "thd" or (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"
2040
        if causal:
2041
2042
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
2043
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
2044
2045
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
2046
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
2047
        if attn_bias is not None:
2048
            assert len(attn_bias.shape) == 4, (
2049
2050
2051
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
2052
2053
2054
            assert (
                attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0
            ), "Sequence length does not meet divisible requirements!"
2055
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
2056
2057
2058
2059
2060
2061
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
2062
2063
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
2064
2065
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
2066
            )
2067
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
2068

2069
2070
2071
2072
2073
2074
2075
        softmax_lse_in_packed_format = False
        if qkv_format == "thd":
            if use_fused_attention:
                softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
            else:
                softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3

2076
        flash_attn_fwd = None
2077
2078
2079
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
2080
2081
2082
2083
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
2084
2085
                fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
            else:
2086
2087
2088
2089
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
2090
2091
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
2092
                if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or _use_flash_attn_3:
2093
                    fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
2094
2095
2096
                elif _flash_attn_2_7_0_plus:
                    fa_forward_kwargs["window_size_left"] = -1
                    fa_forward_kwargs["window_size_right"] = 0 if causal else -1
2097
2098
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
2099
                if _flash_attn_2_5_7_plus and qkv_format == "thd":
2100
                    fa_forward_kwargs["block_table"] = None
2101
2102
                if _flash_attn_2_6_0_plus:
                    fa_forward_kwargs["softcap"] = 0.0
2103

2104
2105
2106
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
2107
        attn_bias_inputs = [None, None]
2108
2109
2110
2111
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
2112
        attn_biases = [None for _ in range(cp_size)]
2113
2114
2115
2116
2117
2118
2119

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
2120
        if qkv_format in ["bshd", "sbhd"]:
2121
2122
2123
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
        else:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
2124
2125
        send_recv_reqs = [[], []]

2126
        out = None
2127
        for i in range(cp_size + 1):
2128
            if i < cp_size:
2129
                with torch.cuda.stream(flash_attn_streams[i % 2]):
2130
                    # wait until KV is received
2131
                    for req in send_recv_reqs[(i + 1) % 2]:
2132
2133
                        req.wait()

2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

2146
                    if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
2147
2148
2149
                        kv_inputs[i % 2] = p2p_comm_buffers[i]
                    else:
                        # KV exchange is in BF16/FP16, cast received KV in each step
2150
                        kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data
2151
2152
                    if causal:
                        if i == 0:
2153
                            if pad_between_seqs:
2154
2155
2156
2157
2158
2159
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
                                )
2160
2161
                            elif qkv_format == "thd":
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
2162
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2163
2164
2165
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    k.shape[0], -1, 2, *k.shape[-2:]
                                )
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                                # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    -1, k.shape[2], 2, *k.shape[-2:]
                                )
                            elif qkv_format == "thd":
                                q_inputs[i % 2] = q
2182
                            if use_fused_attention:
2183
2184
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2185
2186
2187
2188
2189
2190
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
2191
                                    ).contiguous()
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
2204
                                fp8_meta_kwargs = {}
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
2215
2216
                                    fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                    fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
2217

2218
2219
2220
2221
2222
2223
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
2224
2225
2226
2227
2228
                                    q_part,
                                    k_part,
                                    v_part,
                                    fake_dtype=qkv_dtype,
                                    fused_attention_backend=fused_attn_backend,
2229
2230
2231
2232
2233
2234
2235
2236
2237
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
2238
                                )
2239
2240
2241
2242
2243
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2244
                            else:
2245
2246
2247
2248
2249
2250
2251
2252
                                fa_forward_args_thd = []
                                if qkv_format == "thd":
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q,
                                        max_seqlen_kv,
                                    ]
2253
                                fa_outputs = flash_attn_fwd(
2254
                                    q_inputs[i % 2],
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
2266
                                    causal=True,
2267
                                    **fa_forward_kwargs,
2268
                                )
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
                                if not _flash_attn_2_7_0_plus:
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[3]
2279
                        elif i <= rank:
2280
                            if pad_between_seqs:
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    False,
                                )
2292
2293
                            elif qkv_format == "thd":
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
2294
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
2295
2296
2297
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...]
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                                # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2][0]
                            elif qkv_format == "thd":
                                q_inputs[i % 2] = q
                                # [2, t, np, hn] -> [2, t/2, np, hn]
                                kv_inputs[i % 2] = tex.thd_read_half_tensor(
                                    kv_inputs[i % 2], cu_seqlens_kv_padded, 0
                                )
2314
                            if use_fused_attention:
2315
                                kv_inputs[i % 2] = kv_inputs[i % 2].contiguous()
2316
2317
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2318
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
2331
                                fp8_meta_kwargs = {}
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
2342
2343
                                    fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                    fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
2344
2345
2346
2347
2348
2349
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv // 2,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
2350
2351
2352
2353
                                    q_part,
                                    k_part,
                                    v_part,
                                    qkv_dtype,
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=(
                                        None
                                        if cu_seqlens_kv_padded is None
                                        else cu_seqlens_kv_padded // 2
                                    ),
                                    **fp8_meta_kwargs,
2368
                                )
2369
2370
2371
2372
2373
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2374
                            else:
2375
                                fa_forward_args_thd = []
2376
                                if qkv_format == "thd":
2377
2378
2379
2380
2381
2382
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q,
                                        max_seqlen_kv // 2,
                                    ]
2383
2384
2385
                                if _use_flash_attn_3 or (
                                    _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                                ):
2386
                                    fa_forward_kwargs["window_size"] = (-1, -1)
2387
2388
2389
                                elif _flash_attn_2_7_0_plus:
                                    fa_forward_kwargs["window_size_left"] = -1
                                    fa_forward_kwargs["window_size_right"] = -1
2390
                                fa_outputs = flash_attn_fwd(
2391
                                    q_inputs[i % 2],
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
2403
                                    causal=False,
2404
                                    **fa_forward_kwargs,
2405
                                )
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
                                if not _flash_attn_2_7_0_plus:
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[3]
2416
                        else:
2417
                            if pad_between_seqs:
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
                                )
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    True,
                                )
2429
2430
                            elif qkv_format == "thd":
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
2431
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2432
2433
2434
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q_half
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                                q_inputs[i % 2] = q[:, 1, ...]
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    k.shape[0], -1, 2, *k.shape[-2:]
                                )
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                                q_inputs[i % 2] = q[1]
                                # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    -1, k.shape[2], 2, *k.shape[-2:]
                                )
                            elif qkv_format == "thd":
                                # [t, np, hn] -> [t/2, np, hn]
                                q_inputs[i % 2] = tex.thd_read_half_tensor(
                                    q, cu_seqlens_q_padded, 1
                                )
2454
                            if use_fused_attention:
2455
                                q_inputs[i % 2] = q_inputs[i % 2].contiguous()
2456
2457
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2458
2459
2460
2461
2462
2463
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
2464
                                    ).contiguous()
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
2477
                                fp8_meta_kwargs = {}
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
2488
2489
                                    fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                    fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
2490
2491
2492
2493
2494
2495
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q // 2,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
2496
2497
2498
2499
                                    q_part,
                                    k_part,
                                    v_part,
                                    qkv_dtype,
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=(
                                        None
                                        if cu_seqlens_q_padded is None
                                        else cu_seqlens_q_padded // 2
                                    ),
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
2514
                                )
2515
2516
2517
2518
2519
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2520
                            else:
2521
                                fa_forward_args_thd = []
2522
                                if qkv_format == "thd":
2523
2524
2525
2526
2527
2528
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q // 2,
                                        max_seqlen_kv,
                                    ]
2529
2530
2531
                                if _use_flash_attn_3 or (
                                    _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                                ):
2532
                                    fa_forward_kwargs["window_size"] = (-1, -1)
2533
2534
2535
                                elif _flash_attn_2_7_0_plus:
                                    fa_forward_kwargs["window_size_left"] = -1
                                    fa_forward_kwargs["window_size_right"] = -1
2536
                                fa_outputs = flash_attn_fwd(
2537
                                    q_inputs[i % 2],
2538
2539
2540
2541
2542
2543
2544
2545
2546
2547
2548
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
2549
                                    causal=False,
2550
                                    **fa_forward_kwargs,
2551
                                )
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
                                if not _flash_attn_2_7_0_plus:
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
                                    if not _use_flash_attn_3:
                                        rng_states[i] = fa_outputs[3]
2562
                    else:
2563
                        if pad_between_seqs:
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
                            cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                            )
                            cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_kv,
                                cu_seqlens_kv_padded,
                                cp_size,
                                (rank - i) % cp_size,
                                True,
                                True,
                            )
2575
2576
                        elif qkv_format == "thd":
                            cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
2577
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2578
2579
2580
                        else:
                            cu_seqlens_q_per_step[i] = cu_seqlens_q
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv
2581
                        if use_fused_attention:
2582
2583
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
2584
2585
2586
2587
2588
2589
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
2590
                                ).contiguous()
2591
2592
2593
2594
2595
2596
2597
2598
2599
2600
2601
2602

                            q_part = q
                            k_part = (
                                kv_inputs[i % 2][..., 0, :, :]
                                if qkv_format in ["bshd", "sbhd"]
                                else kv_inputs[i % 2][0]
                            )
                            v_part = (
                                kv_inputs[i % 2][..., 1, :, :]
                                if qkv_format in ["bshd", "sbhd"]
                                else kv_inputs[i % 2][1]
                            )
2603
                            fp8_meta_kwargs = {}
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
                            if fp8:
                                q_part = QKV_quantizer.create_tensor_from_data(
                                    q_part, fake_dtype=qkv_dtype, internal=True
                                )
                                k_part = QKV_quantizer.create_tensor_from_data(
                                    k_part, fake_dtype=qkv_dtype, internal=True
                                )
                                v_part = QKV_quantizer.create_tensor_from_data(
                                    v_part, fake_dtype=qkv_dtype, internal=True
                                )
2614
2615
                                fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
2616
2617
2618
2619
2620
2621
                            out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                is_training,
                                max_seqlen_q,
                                max_seqlen_kv,
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
2622
2623
2624
2625
                                q_part,
                                k_part,
                                v_part,
                                qkv_dtype,
2626
2627
2628
2629
2630
2631
2632
2633
2634
2635
                                fused_attn_backend,
                                attn_scale=softmax_scale,
                                dropout=dropout_p,
                                qkv_layout=qkv_layout,
                                attn_mask_type=attn_mask_type,
                                attn_bias_type=attn_bias_type,
                                attn_bias=attn_bias_inputs[i % 2],
                                cu_seqlens_q_padded=cu_seqlens_q_padded,
                                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                **fp8_meta_kwargs,
2636
                            )
2637
2638
2639
2640
2641
                            if fp8:
                                softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                            else:
                                softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                attn_biases[i] = rest[0] if len(rest) > 0 else None
2642
                        else:
2643
2644
2645
2646
2647
2648
2649
2650
                            fa_forward_args_thd = []
                            if qkv_format == "thd":
                                fa_forward_args_thd = [
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                ]
2651
                            fa_outputs = flash_attn_fwd(
2652
2653
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
                                q,
                                (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                ),
                                (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                ),
                                *fa_forward_args_thd,
2664
                                causal=False,
2665
                                **fa_forward_kwargs,
2666
                            )
2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
                            if not _flash_attn_2_7_0_plus:
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
                                if not _use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
                            else:
                                out_per_step[i] = fa_outputs[0]
                                softmax_lse_per_step[i] = fa_outputs[1]
                                if not _use_flash_attn_3:
                                    rng_states[i] = fa_outputs[3]
2677
2678
2679
2680

            if i > 0:
                # wait until fwd results correction of last step is done
                if i > 1:
2681
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
2682

2683
                if use_fused_attention:
2684
2685
                    # [b, np, sq, 1] -> [b, np, sq] or
                    # [t, np, 1] -> [t, np]
2686
                    softmax_lse_per_step[i - 1].squeeze_(-1)
2687
2688
2689
2690
                    if softmax_lse_in_packed_format:
                        softmax_lse_per_step[i - 1] = (
                            softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
                        )
2691

2692
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
2693
                    if fp8:
2694
                        out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32)
2695
2696
                    if i == 1:
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
2697
2698
                        if qkv_format == "thd":
                            out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
2699
2700
2701
2702
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
2703
                    else:
2704
                        if qkv_format == "thd":
2705
                            tex.thd_second_half_lse_correction(
2706
2707
2708
                                softmax_lse,
                                softmax_lse_per_step[i - 1],
                                cu_seqlens_q_padded,
2709
                                softmax_lse_in_packed_format,
2710
                            )
2711
                        else:
2712
2713
2714
                            flash_attn_fwd_second_half_softmax_lse_correction(
                                softmax_lse.view(*softmax_lse.shape[:-1], 2, -1),
                                softmax_lse_per_step[i - 1],
2715
                            )
2716
2717

                if i < cp_size:
2718
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
2719
2720
2721

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

2722
2723
2724
2725
        second_half_lse_seqlen = None
        if causal and rank < (cp_size - 1):
            second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1]

2726
2727
        softmax_lse = softmax_lse.to(torch.float)
        for i in range(cp_size):
2728
            if i <= rank or not causal:
2729
                if qkv_format in ["bshd", "sbhd"]:
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
                    if i == 0:
                        out = flash_attn_fwd_out_correction_init(
                            out_per_step[0],
                            softmax_lse,
                            softmax_lse_per_step[0],
                            seq_dim,
                        )
                        out = out.view(q.shape)
                    else:
                        flash_attn_fwd_out_correction(
                            out.view(*out_per_step[i].shape),
                            out_per_step[i],
                            softmax_lse,
                            softmax_lse_per_step[i],
                            seq_dim,
                        )
2746
                elif qkv_format == "thd":
2747
2748
2749
2750
2751
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2752
                        cu_seqlens_q_padded,
2753
                        False,
2754
                        softmax_lse_in_packed_format,
2755
                    )
2756
            else:
2757
                if qkv_format in ["bshd", "sbhd"]:
2758
2759
                    flash_attn_fwd_second_half_out_correction(
                        out,
2760
                        out_per_step[i],
2761
                        softmax_lse,
2762
                        softmax_lse_per_step[i],
2763
                        seq_dim,
2764
                    )
2765
                elif qkv_format == "thd":
2766
2767
2768
2769
2770
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2771
                        cu_seqlens_q_padded,
2772
                        True,
2773
                        softmax_lse_in_packed_format,
2774
                    )
2775
2776

        kv = p2p_comm_buffers[-1]
2777
2778
2779
2780
2781
2782
2783
2784
        if qkv_format == "bshd":
            out = out.view(out.shape[0], -1, *out.shape[-2:])
            ctx.batch_size = out.shape[0]
        elif qkv_format == "sbhd":
            out = out.view(-1, *out.shape[-3:])
            ctx.batch_size = out.shape[1]

        if cp_size_a2a > 1:
2785
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device)
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
            out = flash_attn_a2a_communicate(
                out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
            )
            if use_fused_attention:
                if qkv_format == "bshd":
                    # [b*s, np, hn] -> [b, s, np, hn]
                    out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                elif qkv_format == "sbhd":
                    # [s*b, np, hn] -> [s, b, np, hn]
                    out = out.view(-1, ctx.batch_size, *out.shape[-2:])
        elif not use_fused_attention:
2797
            out = out.view(-1, *out.shape[-2:])
2798

2799
2800
2801
2802
2803
        if fp8 and use_fused_attention:
            amax_cp_fwd = amax_per_step.amax(dim=1)
            S_quantizer.amax = amax_cp_fwd[0]
            O_CP_quantizer.amax = amax_cp_fwd[1]

2804
        out_fp8 = None
2805
        out_f16 = out.to(qkv_dtype)
2806

2807
        if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
2808
2809
2810
            out_fp8 = O_quantizer(out_f16)  # final result

        out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16
2811
2812

        if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
2813
            q_save, kv_save, out_save = q, kv, out_fp8._data
2814
        elif fp8 and is_input_fp8:
2815
            q_save, kv_save, out_save = q, kv, out_f16
2816
        else:
2817
            q_f16 = q_f16.view(q.shape)
2818
2819
            q_save, kv_save, out_save = q_f16, kv, out_f16

2820
        tensors_to_save, tensor_objects = prepare_for_saving(
2821
2822
2823
            q_save,
            kv_save,
            out_save,
2824
            softmax_lse,
2825
2826
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
2827
2828
            *cu_seqlens_q_per_step,
            *cu_seqlens_kv_per_step,
2829
2830
            *rng_states,
            *attn_biases,
2831
        )
2832
2833
2834
2835
2836
2837
2838
2839
2840
2841
2842
2843
2844
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects

        ctx.qkv_dtype = qkv_dtype
        ctx.QKV_quantizer = QKV_quantizer
        ctx.O_quantizer = O_quantizer
        ctx.O_CP_quantizer = O_CP_quantizer
        ctx.S_quantizer = S_quantizer
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer

2845
2846
2847
        ctx.cp_group_a2a = cp_group_a2a
        ctx.cp_size_a2a = cp_size_a2a
        ctx.rank_a2a = rank_a2a
2848
2849
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
2850
        ctx.cp_stream = cp_stream
2851
2852
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
2853
        ctx.max_seqlen_kv = max_seqlen_kv
2854
        ctx.softmax_scale = softmax_scale
2855
        ctx.qkv_format = qkv_format
2856
        ctx.attn_mask_type = attn_mask_type
2857
2858
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
2859
        ctx.deterministic = deterministic
2860
        ctx.use_fused_attention = use_fused_attention
2861
        ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format
2862
        ctx.second_half_lse_seqlen = second_half_lse_seqlen
2863
2864
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
2865
2866
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
2867
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
2868

2869
        return out_ret
2870
2871
2872

    @staticmethod
    def backward(ctx, dout):
2873
        # pylint: disable=missing-function-docstring
2874
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
2875
2876
2877
        cp_size_a2a = ctx.cp_size_a2a
        rank_a2a = ctx.rank_a2a

2878
2879
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
2880
2881
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
2882
2883
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

2884
        q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
2885
            restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
2886
2887
2888
2889
2890
        )
        cu_seqlens_q_per_step = other_tensors[:cp_size]
        cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
        rng_states = other_tensors[cp_size * 2 : cp_size * 3]
        attn_biases = other_tensors[cp_size * 3 : cp_size * 4]
2891

2892
2893
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
2894
2895

        seq_dim = None
2896
        if ctx.qkv_format in ["bshd", "sbhd"]:
2897
            seq_dim = ctx.qkv_format.index("s")
2898
2899
2900
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
        else:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
2901

2902
        if attn_biases[0] is not None:
2903
2904
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
2905
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
2906
2907
2908
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
2909
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
2910
2911
2912
            )
        else:
            attn_dbias = None
2913
            attn_dbias_ = None
2914

2915
2916
        softmax_lse_ = None
        if causal and ctx.second_half_lse_seqlen is not None:
2917
            if ctx.qkv_format == "thd":
2918
                softmax_lse_ = tex.thd_read_second_half_lse(
2919
2920
2921
2922
                    softmax_lse,
                    cu_seqlens_q_padded,
                    ctx.softmax_lse_in_packed_format,
                    ctx.second_half_lse_seqlen,
2923
                )
2924
2925
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
2926
                softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)
2927
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
2928
2929
2930
2931
2932
2933
            if ctx.use_fused_attention:
                if ctx.softmax_lse_in_packed_format:
                    softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous()
                # [b, np, sq//2] -> [b, np, sq//2, 1] or
                # [t//2, np] -> [t//2, np, 1]
                softmax_lse_.unsqueeze_(-1)
2934
        if ctx.use_fused_attention:
2935
2936
2937
2938
            if ctx.softmax_lse_in_packed_format:
                softmax_lse = softmax_lse.transpose(0, 1).contiguous()
            # [b, np, sq] -> [b, np, sq, 1] or
            # [t, np] -> [t, np, 1]
2939
            softmax_lse.unsqueeze_(-1)
2940
            dout = dout.contiguous()
2941

2942
        dq = None
2943
        dout_dtype = dout.dtype
2944
2945
        fused_attn_backend = None
        fused_attn_dqkv_dtype = None
2946
2947
2948
        amax_per_step = None
        dP_quantizer_per_step = [None for _ in range(cp_size)]
        dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)]
2949
2950
2951
        if ctx.fp8:
            if ctx.use_fused_attention:
                fused_attn_backend = FusedAttnBackend["FP8"]
2952

2953
2954
2955
2956
2957
2958
2959
2960
2961
                dqkv_fp8_torch_dtype = get_fp8_torch_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )
                dq_fp8 = torch.empty(
                    (cp_size, *q.shape), dtype=dqkv_fp8_torch_dtype, device=q.device
                )
                dkv_fp8 = torch.empty(
                    (cp_size, *kv.shape), dtype=dqkv_fp8_torch_dtype, device=kv.device
                )
2962
                dkv_fp8_ = torch.empty_like(dkv_fp8)
2963
                if ctx.is_output_fp8:
2964
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
2965
                    ctx.dO_quantizer = dout._quantizer
2966
                else:
2967
                    dout = ctx.dO_quantizer(dout)
2968
2969
                fused_attn_dqkv_dtype = dout._fp8_dtype
                dout = dout._data
2970
2971
                p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
                fp8_meta_kwargs = {}
2972
                fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
2973
2974
2975
2976
2977
2978
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
                for i in range(cp_size):
                    dP_quantizer_per_step[i] = ctx.dP_quantizer.copy()
                    dP_quantizer_per_step[i].amax = amax_per_step[0][i]
                    dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy()
                    dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
2979
2980
2981
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
            if ctx.fp8_meta is not None:
                if ctx.is_input_fp8:
                    q = ctx.QKV_quantizer.create_tensor_from_data(
                        q, fake_dtype=ctx.qkv_dtype, internal=True
                    )
                    kv = ctx.QKV_quantizer.create_tensor_from_data(
                        kv, fake_dtype=ctx.qkv_dtype, internal=True
                    )
                    q = q.dequantize(dtype=ctx.qkv_dtype)
                    kv = kv.dequantize(dtype=ctx.qkv_dtype)
                if ctx.is_output_fp8:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    if cp_size_a2a == 1:
                        dout = dout.dequantize(dtype=dout_dtype)
                    else:
                        ctx.dO_quantizer = dout._quantizer
                        dout = dout._data
2999
3000
3001
3002
3003
3004
3005
3006
            dq = torch.empty_like(q)
            p2p_comm_buffers = [
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            ]
            p2p_comm_buffers[0][0].copy_(kv)
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
3007
                fused_attn_dqkv_dtype = TE_DType[dout_dtype]
3008
3009
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

3010
3011
3012
3013
        if cp_size_a2a > 1:
            if not ctx.use_fused_attention:
                out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                dout = dout.view(*out.shape)
3014
3015
3016
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(
                cp_size_a2a, out.device
            )
3017
3018
3019
3020
3021
3022
3023
3024
3025
            out, dout = flash_attn_a2a_communicate(
                [out, dout],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                True,
            )
3026
            if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
3027
3028
3029
3030
                dout = ctx.dO_quantizer.create_tensor_from_data(
                    dout, fake_dtype=dout_dtype, internal=True
                )
                dout = dout.dequantize(dtype=dout_dtype)
3031

3032
3033
3034
3035
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        send_recv_reqs = []

3036
        flash_attn_bwd = None
3037
3038
3039
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
3040
3041
3042
3043
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
3044
3045
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
3046
3047
3048
3049
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
3050
3051
3052
3053
3054
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
3055
3056
                if _flash_attn_2_6_0_plus:
                    fa_backward_kwargs["softcap"] = 0.0
3057

3058
3059
3060
3061
3062
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

3063
3064
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
3065
3066
3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
3087
3088
3089
3090
3091
3092
3093
            if ctx.fp8:
                if i < cp_size - 1:
                    send_recv_reqs = flash_attn_p2p_communicate(
                        rank,
                        send_tensor[0],
                        send_dst,
                        recv_tensor[0],
                        recv_src,
                        ctx.cp_group,
                        batch_p2p_comm,
                    )
                else:
                    dkv_a2a_req = torch.distributed.all_to_all_single(
                        dkv_fp8,
                        dkv_fp8_,
                        group=ctx.cp_group,
                        async_op=True,
                    )
                    send_recv_reqs = [dkv_a2a_req]
            else:
                if i == 0:
                    send_tensor = send_tensor[0]
                    recv_tensor = recv_tensor[0]
                if i == (cp_size - 1):
                    send_tensor = send_tensor[1]
                    recv_tensor = recv_tensor[1]
                send_recv_reqs = flash_attn_p2p_communicate(
                    rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
                )
3094

3095
            kv = p2p_comm_buffers[i % 2][0]
3096
3097
            q_, kv_, out_, dout_ = None, None, None, None
            dq_, dk_, dv_ = None, None, None
3098
            # In reversed order of fwd
3099
            if causal:
3100
                if i == (cp_size - 1):
3101
3102
3103
3104
3105
3106
3107
3108
3109
3110
3111
3112
3113
3114
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_, out_, dout_ = [
                            x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
                        ]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                        kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                        q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
                        # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                        kv_ = kv.view(-1, *kv.shape[-4:])
                    elif ctx.qkv_format == "thd":
                        q_, kv_, out_, dout_ = q, kv, out, dout
3115
                    if ctx.use_fused_attention:
3116
3117
3118
3119
3120
3121
3122
3123
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
3124
                        if attn_dbias is not None:
3125
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3126
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
3138
3139
3140
3141
3142
3143
3144
3145
                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
3146
                                dout_part, fake_dtype=dout_dtype, internal=True
3147
                            )
3148
3149
                            fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                            fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
3150
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3151
                            ctx.max_seqlen_q,
3152
3153
3154
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
3155
3156
3157
3158
3159
3160
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
                            ctx.qkv_dtype,
3161
                            fused_attn_dqkv_dtype,
3162
                            aux_ctx_tensors,
3163
                            fused_attn_backend,
3164
3165
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
3166
3167
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
3168
                            qkv_layout=qkv_layout,
3169
                            attn_mask_type=ctx.attn_mask_type,
3170
                            attn_bias_type=ctx.attn_bias_type,
3171
3172
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
3173
                        )
3174
3175
3176
3177
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
3178
                    else:
3179
                        dq_ = torch.empty_like(q_)
3180
                        dkv_ = torch.empty_like(kv_)
3181
3182
3183
3184
3185
3186
3187
3188
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q,
                                ctx.max_seqlen_kv,
                            ]
3189
3190
3191
                        if _use_flash_attn_3 or (
                            _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                        ):
3192
                            fa_backward_kwargs["window_size"] = (-1, 0)
3193
3194
3195
                        elif _flash_attn_2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = 0
3196
3197
3198
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
3199
3200
                            dout_,
                            q_,
3201
3202
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
3203
3204
3205
                            out_,
                            softmax_lse,
                            dq_,
3206
3207
3208
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
3209
3210
                            causal=True,
                            **fa_backward_kwargs,
3211
                        )
3212
                elif i >= (cp_size - rank - 1):
3213
3214
3215
3216
3217
3218
3219
3220
3221
3222
3223
3224
3225
3226
3227
3228
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_, out_, dout_ = [
                            x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
                        ]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                        kv_ = kv[:, 0]
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                        q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
                        # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                        kv_ = kv[0]
                    elif ctx.qkv_format == "thd":
                        q_, out_, dout_ = q, out, dout
                        # [2, t, np, hn] -> [2, t/2, np, hn]
                        kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
3229
                    if ctx.use_fused_attention:
3230
                        kv_ = kv_.contiguous()
3231
3232
3233
3234
3235
3236
3237
3238
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
3239
                        if attn_dbias is not None:
3240
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
3261
                                dout_part, fake_dtype=dout_dtype, internal=True
3262
                            )
3263
3264
                            fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                            fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
3265
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3266
                            ctx.max_seqlen_q,
3267
3268
3269
                            ctx.max_seqlen_kv // 2,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
3270
3271
3272
3273
3274
3275
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
                            ctx.qkv_dtype,
3276
                            fused_attn_dqkv_dtype,
3277
                            aux_ctx_tensors,
3278
                            fused_attn_backend,
3279
3280
3281
3282
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=(
                                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
                            ),
3283
3284
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
3285
                            qkv_layout=qkv_layout,
3286
                            attn_mask_type="padding" if padding else "no_mask",
3287
                            attn_bias_type=ctx.attn_bias_type,
3288
3289
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
3290
                        )
3291
3292
3293
3294
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
3295
                    else:
3296
                        dq_ = torch.empty_like(q_)
3297
                        dkv_ = torch.empty_like(kv_)
3298
3299
3300
3301
3302
3303
3304
3305
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q,
                                ctx.max_seqlen_kv // 2,
                            ]
3306
3307
3308
                        if _use_flash_attn_3 or (
                            _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                        ):
3309
                            fa_backward_kwargs["window_size"] = (-1, -1)
3310
3311
3312
                        if _flash_attn_2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = -1
3313
3314
3315
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
3316
3317
                            dout_,
                            q_,
3318
3319
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
3320
3321
3322
                            out_,
                            softmax_lse,
                            dq_,
3323
3324
3325
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
3326
3327
                            causal=False,
                            **fa_backward_kwargs,
3328
3329
                        )
                else:
3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
3346
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                        kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                        q_, out_, dout_ = q[1], out[1], dout[1]
                        # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                        kv_ = kv.view(-1, *kv.shape[-4:])
                    elif ctx.qkv_format == "thd":
                        # [t, np, hn] -> [t/2, np, hn]
                        q_, out_, dout_ = [
                            tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1)
                            for x in [q, out, dout]
                        ]
                        kv_ = kv
3347
                    if ctx.use_fused_attention:
3348
                        q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]]
3349
3350
3351
3352
3353
3354
3355
3356
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse_,
                                softmax_lse_,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
3357
                        if attn_dbias is not None:
3358
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3359
3360
3361
3362
3363
3364
3365
3366
3367
3368
3369
3370
3371
3372
3373
3374
3375
3376
3377
3378
3379

                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
3380
                                dout_part, fake_dtype=dout_dtype, internal=True
3381
                            )
3382
3383
                            fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                            fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
3384
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3385
                            ctx.max_seqlen_q // 2,
3386
3387
3388
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
3389
3390
3391
3392
3393
3394
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
                            ctx.qkv_dtype,
3395
                            fused_attn_dqkv_dtype,
3396
                            aux_ctx_tensors,
3397
                            fused_attn_backend,
3398
3399
3400
3401
                            cu_seqlens_q_padded=(
                                None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
                            ),
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
3402
3403
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
3404
                            qkv_layout=qkv_layout,
3405
                            attn_mask_type="padding" if padding else "no_mask",
3406
                            attn_bias_type=ctx.attn_bias_type,
3407
3408
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
3409
                        )
3410
3411
3412
3413
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
3414
                    else:
3415
                        dq_ = torch.empty_like(q_)
3416
                        dkv_ = torch.empty_like(kv_)
3417
                        fa_backward_args_thd = []
3418
                        if ctx.qkv_format == "thd":
3419
3420
3421
3422
3423
3424
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q // 2,
                                ctx.max_seqlen_kv,
                            ]
3425
3426
3427
                        if _use_flash_attn_3 or (
                            _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                        ):
3428
                            fa_backward_kwargs["window_size"] = (-1, -1)
3429
3430
3431
                        elif _flash_attn_2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = -1
3432
3433
3434
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
3435
3436
                            dout_,
                            q_,
3437
3438
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
3439
3440
3441
                            out_,
                            softmax_lse_,
                            dq_,
3442
3443
3444
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
3445
3446
                            causal=False,
                            **fa_backward_kwargs,
3447
3448
3449
                        )
            else:
                if ctx.use_fused_attention:
3450
3451
3452
3453
                    if ctx.fp8:
                        aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]]
                    else:
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
3454
                    if attn_dbias is not None:
3455
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3456
3457
3458
3459
3460
3461
3462
3463
                    q_part = q
                    k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0]
                    v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1]
                    out_part = out
                    dout_part = dout

                    if ctx.fp8:
                        q_part = ctx.QKV_quantizer.create_tensor_from_data(
3464
                            q_part, fake_dtype=ctx.qkv_dtype, internal=True
3465
3466
                        )
                        k_part = ctx.QKV_quantizer.create_tensor_from_data(
3467
                            k_part, fake_dtype=ctx.qkv_dtype, internal=True
3468
3469
                        )
                        v_part = ctx.QKV_quantizer.create_tensor_from_data(
3470
                            v_part, fake_dtype=ctx.qkv_dtype, internal=True
3471
3472
                        )
                        out_part = ctx.O_quantizer.create_tensor_from_data(
3473
                            out_part, fake_dtype=ctx.qkv_dtype, internal=True
3474
3475
                        )
                        dout_part = ctx.dO_quantizer.create_tensor_from_data(
3476
                            dout_part, fake_dtype=dout_dtype, internal=True
3477
                        )
3478
3479
                        fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                        fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
3480
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3481
                        ctx.max_seqlen_q,
3482
3483
3484
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
3485
3486
3487
3488
3489
3490
                        q_part,
                        k_part,
                        v_part,
                        out_part,
                        dout_part,
                        ctx.qkv_dtype,
3491
                        fused_attn_dqkv_dtype,
3492
                        aux_ctx_tensors,
3493
                        fused_attn_backend,
3494
3495
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
3496
3497
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
3498
                        qkv_layout=qkv_layout,
3499
                        attn_mask_type=ctx.attn_mask_type,
3500
                        attn_bias_type=ctx.attn_bias_type,
3501
3502
                        deterministic=ctx.deterministic,
                        **fp8_meta_kwargs,
3503
                    )
3504
3505
3506
3507
3508
3509

                    if ctx.fp8:
                        dq_ = dq_._data
                        dk_ = dk_._data
                        dv_ = dv_._data

3510
                else:
3511
3512
3513
3514
3515
3516
3517
3518
3519
3520
                    dq_ = torch.empty_like(q)
                    dkv_ = torch.empty_like(kv)
                    fa_backward_args_thd = []
                    if ctx.qkv_format == "thd":
                        fa_backward_args_thd = [
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_kv,
                        ]
3521
                    if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus):
3522
                        fa_backward_kwargs["window_size"] = (-1, -1)
3523
3524
3525
                    elif _flash_attn_2_7_0_plus:
                        fa_backward_kwargs["window_size_left"] = -1
                        fa_backward_kwargs["window_size_right"] = -1
3526
3527
3528
                    if not _use_flash_attn_3:
                        fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                    flash_attn_bwd(
3529
3530
3531
3532
3533
                        dout,
                        q,
                        kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
                        out,
3534
3535
                        softmax_lse,
                        dq_,
3536
3537
3538
                        dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                        dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                        *fa_backward_args_thd,
3539
3540
                        causal=False,
                        **fa_backward_kwargs,
3541
3542
                    )

3543
3544
            if ctx.fp8:
                dq = dq_fp8[(rank + i + 1) % cp_size]
3545
3546
3547
            if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1):
                # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or
                # [sq, b, np, hn] -> [2, sq//2, b, np, hn]
3548
                dq_ = dq_.view(*dq.shape)
3549

3550
3551
3552
3553
3554
3555
3556
3557
3558
3559
3560
            if ctx.fp8:
                if i >= (cp_size - rank - 1) or not causal:
                    dq.copy_(dq_)
                else:
                    if ctx.qkv_format == "bshd":
                        dq[:, 0, ...].fill_(0)
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[0].fill_(0)
                        dq[1].copy_(dq_)
            elif causal:
3561
                if i > (cp_size - rank - 1):
3562
                    dq.add_(dq_)
3563
3564
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
3565
3566
                        dq.copy_(dq_)
                    else:
3567
3568
3569
3570
3571
3572
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
3573
                        elif ctx.qkv_format == "thd":
3574
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
3575
                elif i > 0:
3576
3577
3578
3579
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
3580
                    elif ctx.qkv_format == "thd":
3581
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
3582
                else:
3583
3584
3585
3586
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
3587
                    elif ctx.qkv_format == "thd":
3588
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
3589
3590
3591
3592
3593
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
3594

3595
            if attn_dbias is not None:
3596
                idx = (rank + i + 1) % cp_size
3597
                if i == (cp_size - 1) or not causal:
3598
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
3599
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
3600
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
3601
3602
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
3603
3604
3605
3606
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
3607
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
3608
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
3609
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
3610

3611
3612
3613
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
3614

3615
3616
3617
3618
3619
3620
3621
            if ctx.fp8:
                if i < cp_size - 1:
                    dkv = dkv_fp8_[(rank + i + 1) % cp_size]
                else:
                    dkv = dkv_fp8[(rank + i + 1) % cp_size]
            else:
                dkv = p2p_comm_buffers[(i + 1) % 2][1]
3622
            if ctx.use_fused_attention:
3623
                if ctx.qkv_format in ["bshd", "sbhd"]:
3624
3625
3626
3627
3628
3629
3630
3631
3632
3633
3634
3635
3636
3637
                    dkv_ = _combine_tensors([dk_, dv_], -2)
                elif ctx.qkv_format == "thd":
                    dkv_ = torch.cat(
                        (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
                    )  # pylint: disable=used-before-assignment
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
                # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
                dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
                dkv_ = dkv_.movedim(-3, 0)
                if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
                    # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or
                    # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(*dkv.shape)
3638

3639
3640
3641
3642
3643
3644
3645
3646
3647
3648
3649
            if ctx.fp8:
                if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
                    if ctx.qkv_format == "bshd":
                        dkv[:, :, 0, ...].copy_(dkv_)
                        dkv[:, :, 1, ...].fill_(0)
                    elif ctx.qkv_format == "sbhd":
                        dkv[:, 0, ...].copy_(dkv_)
                        dkv[:, 1, ...].fill_(0)
                else:
                    dkv.copy_(dkv_)
            elif causal:
3650
                if i == (cp_size - 1):
3651
                    if rank == 0:
3652
3653
3654
3655
3656
3657
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
3658
                        elif ctx.qkv_format == "thd":
3659
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
3660
3661
                    else:
                        dkv.add_(dkv_)
3662
3663
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
3664
3665
3666
3667
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
3668
                        elif ctx.qkv_format == "thd":
3669
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
3670
                    else:
3671
3672
3673
3674
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
3675
                        elif ctx.qkv_format == "thd":
3676
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
3677
3678
3679
3680
3681
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
3682
3683
3684
3685
3686
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

3687
        if ctx.fp8 and ctx.use_fused_attention:
3688
3689
3690
            amax_cp_bwd = amax_per_step.amax(dim=1)
            ctx.dP_quantizer.amax = amax_cp_bwd[0]
            ctx.dQKV_CP_quantizer.amax = amax_cp_bwd[1]
3691
3692
3693
3694
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
                # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
                dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
3695
3696
3697
3698
3699
3700
3701
            dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
                dq_fp8, fake_dtype=torch.float32, internal=True
            )
            dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
                dkv_fp8, fake_dtype=torch.float32, internal=True
            )
            dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]]
3702
3703
            dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]

3704
        if causal:
3705
3706
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
3707
                dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
3708
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
3709
                dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
3710
3711
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
3712
                dq = dq.view(-1, *dq.shape[-3:])
3713
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
3714
3715
                dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])

3716
3717
3718
        if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
            dq[cu_seqlens_q_padded[-1] :].fill_(0)
            dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0)
3719

3720
        if ctx.fp8 and ctx.is_input_fp8:
3721
3722
            assert torch.uint8 not in [dq.dtype, dkv.dtype]
            dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]]
3723
3724
3725
        dk, dv = dkv[0], dkv[1]

        if cp_size_a2a > 1:
3726
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device)
3727
3728
3729
3730
3731
3732
3733
3734
3735
3736
3737
3738
3739
3740
            dq, dk, dv = flash_attn_a2a_communicate(
                [dq, dk, dv],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                False,
            )
            if ctx.qkv_format == "bshd":
                dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
            elif ctx.qkv_format == "sbhd":
                dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

3741
3742
3743
        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
3744
3745
        # converting torch.uint8 to float8tensor
        if ctx.fp8 and ctx.is_input_fp8:
3746
3747
3748
            dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype)
            dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype)
            dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
3749
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
3750

3751
3752
3753
        return (
            None,
            dq,
3754
3755
            dk,
            dv,
3756
3757
3758
3759
3760
3761
3762
3763
3764
3765
3766
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
3767
            attn_dbias,
3768
3769
3770
3771
3772
            None,
            None,
            None,
            None,
            None,
3773
3774
            None,
            None,
3775
            None,
3776
            None,
3777
        )
3778
3779


3780
3781
def get_kv_seq_info_after_all_gather(
    local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal
):
    """Compute KV sequence index range and update window size after all-gather.

    Args:
        local_chunk_id: index of the local Q chunk among the ``2 * cp_size`` chunks
            that make up the full (all-gathered) sequence.
        cp_size: context-parallel world size.
        max_seqlen_q: per-chunk Q sequence length.
        max_seqlen_kv: per-chunk K/V sequence length.
        window_size: ``(left, right)`` sliding-window size where ``-1`` means
            unbounded on that side. ``None`` selects ``(-1, 0)`` if ``causal``
            else ``(-1, -1)``.
        causal: only consulted when ``window_size`` is ``None``.

    Returns:
        ``((seq_start_idx, seq_end_idx), (window_size_left, window_size_right))``:
        the half-open K/V index range into the gathered sequence, and the window
        size shifted by the offset between the end of that range and the end of
        the local Q chunk.
    """
    if window_size is None:
        window_size = (-1, 0) if causal else (-1, -1)
    left, right = window_size[0], window_size[1]

    # Absolute (exclusive) end of the local chunk, and of the whole gathered sequence.
    chunk_end = max_seqlen_kv * (local_chunk_id + 1)
    total_len = 2 * cp_size * max_seqlen_kv

    # Right edge: clamp to the sequence end; whatever is clipped off is subtracted
    # from the right window.
    if right != -1:
        kv_end = min(total_len, chunk_end + right)
        new_right = chunk_end + right - kv_end
    else:
        kv_end, new_right = total_len, -1

    # Left edge: the first Q token sits at chunk_end - max_seqlen_q; the left window
    # grows by the distance the K/V range extends past the end of the Q chunk.
    if left != -1:
        kv_begin = max(0, chunk_end - max_seqlen_q - left)
        new_left = left + kv_end - chunk_end
    else:
        kv_begin, new_left = 0, -1

    return (kv_begin, kv_end), (new_left, new_right)
3805
3806
3807
3808


class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    """
    Attention implementation with context parallelism. KV all-gather between CP ranks is exposed.
    Refer section 3.3.2 of `The Llama 3 Herd of Models <https://arxiv.org/abs/2407.21783>`_.

    Each CP rank holds two Q sequence chunks (chunk ``rank`` and chunk
    ``2 * cp_size - rank - 1``) and all-gathers K/V from every rank. The two local
    Q chunks are processed on two CUDA streams (the current stream and ``cp_stream``);
    each step only attends to the K/V index range returned by
    ``get_kv_seq_info_after_all_gather``.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        cp_group,
        cp_stream,
    ):
        # pylint: disable=missing-function-docstring
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)

        qkv_dtype = q.dtype

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        # For causal masks the K/V slice of each step ends where the local Q chunk
        # ends (see get_kv_seq_info_after_all_gather), so the fused kernel must use a
        # bottom-right aligned causal mask.
        if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
            attn_mask_type = attn_mask_type + "_bottom_right"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            use_fused_attention or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

        # Select the flash-attention entry point and the kwargs that do not vary per step.
        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
            else:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_5_7_plus and qkv_format == "thd":
                    fa_forward_kwargs["block_table"] = None
                if _flash_attn_2_6_0_plus:
                    fa_forward_kwargs["softcap"] = 0.0

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        # Sequence lengths below are per chunk: the full sequence is split into 2*cp chunks.
        max_seqlen_q = max_seqlen_q // (2 * cp_size)
        max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
        if use_fused_attention or qkv_format == "thd":
            cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
        if cu_seqlens_q_padded is not None and qkv_format == "thd":
            cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
        else:
            cu_seqlens_q_padded = None

        # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
        q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
        # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn]
        k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]]

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, cp_group)
        v_ag, _ = gather_along_first_dim(v, cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        # Reorder the gathered chunks into natural sequence order.
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        cp_stream.wait_stream(torch.cuda.current_stream())

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]

        # The two sequence chunks owned by this rank.
        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
        kv_seq_range_per_step = [None, None]
        window_size_per_step = [None, None]
        cu_seqlens_kv_per_step = [None, None]
        out_per_step = [None, None]
        softmax_lse_per_step = [None, None]
        rng_states = [None, None]
        out = torch.empty_like(q)

        # Step i launches attention for chunk i; step i also copies the result of
        # chunk i-1 into `out` on the stream that produced it.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    kv_seq_range_per_step[i], window_size_per_step[i] = (
                        get_kv_seq_info_after_all_gather(
                            local_seq_chunk_ids[i],
                            cp_size,
                            max_seqlen_q,
                            max_seqlen_kv,
                            window_size,
                            causal,
                        )
                    )
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i][0],
                        kv_seq_range_per_step[i][1],
                    )
                    max_seqlen_kv_ = seq_end_idx - seq_start_idx
                    if use_fused_attention or qkv_format == "thd":
                        cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens(
                            k.shape[1], max_seqlen_kv_, k.device
                        )
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    if use_fused_attention:
                        out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
                            is_training,
                            max_seqlen_q,
                            max_seqlen_kv_,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            qkv_dtype,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=softmax_scale,
                            dropout=dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=attn_mask_type,
                            attn_bias_type=attn_bias_type,
                            attn_bias=attn_bias,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            window_size=window_size_per_step[i],
                        )
                    else:
                        fa_forward_args_thd = []
                        if qkv_format == "thd":
                            fa_forward_args_thd = [
                                cu_seqlens_q,
                                cu_seqlens_kv_per_step[i],
                                max_seqlen_q,
                                max_seqlen_kv_,
                            ]
                        if _use_flash_attn_3 or (
                            _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                        ):
                            fa_forward_kwargs["window_size"] = window_size_per_step[i]
                        elif _flash_attn_2_7_0_plus:
                            fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
                            fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1]
                        fa_outputs = flash_attn_fwd(
                            q_,
                            k_,
                            v_,
                            *fa_forward_args_thd,
                            causal=causal,
                            **fa_forward_kwargs,
                        )
                        if not _flash_attn_2_7_0_plus:
                            out_per_step[i] = fa_outputs[4]
                            softmax_lse_per_step[i] = fa_outputs[5]
                            if not _use_flash_attn_3:
                                rng_states[i] = fa_outputs[7]
                        else:
                            out_per_step[i] = fa_outputs[0]
                            softmax_lse_per_step[i] = fa_outputs[1]
                            if not _use_flash_attn_3:
                                rng_states[i] = fa_outputs[3]

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if qkv_format == "bshd":
                        out[:, i - 1].copy_(out_per_step[i - 1])
                    elif qkv_format == "sbhd":
                        out[i - 1].copy_(out_per_step[i - 1])

        torch.cuda.current_stream().wait_stream(cp_stream)

        if use_fused_attention:
            if qkv_format == "bshd":
                # [b, 2, s//2, np, hn] -> [b, s, np, hn]
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                # [2, s//2, b, np, hn] -> [s, b, np, hn]
                out = out.view(-1, *out.shape[-3:])
        else:
            # flash-attention path returns the output flattened to [-1, np, hn]
            out = out.view(-1, *out.shape[-2:])

        # Order matters: backward unpacks these by fixed slice positions.
        ctx.save_for_backward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_q_padded,
            *cu_seqlens_kv_per_step,
            *out_per_step,
            *softmax_lse_per_step,
            *rng_states,
        )

        ctx.qkv_dtype = qkv_dtype
        ctx.kv_seq_range_per_step = kv_seq_range_per_step
        ctx.window_size_per_step = window_size_per_step
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
        return out

    @staticmethod
    def backward(ctx, dout):
        # pylint: disable=missing-function-docstring
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)

        # Unpack in the same order as save_for_backward in forward.
        (*saved_tensors,) = ctx.saved_tensors
        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
        cu_seqlens_kv_per_step = saved_tensors[5:7]
        out_per_step = saved_tensors[7:9]
        softmax_lse_per_step = saved_tensors[9:11]
        rng_states = saved_tensors[11:13]
        kv_seq_range_per_step = ctx.kv_seq_range_per_step
        window_size_per_step = ctx.window_size_per_step

        seq_dim = ctx.qkv_format.index("s")
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

        dout = dout.view(q.shape)
        dq = torch.empty_like(q)
        # dK/dV are accumulated over the full gathered sequence, then reduce-scattered.
        dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device)
        dv = torch.zeros_like(dk)
        dq_per_step = [None, None]
        dk_per_step = [None, None]
        dv_per_step = [None, None]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream]
        # synchronize dkv update across steps
        dkv_update_done = torch.cuda.Event()

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
        v_ag, _ = gather_along_first_dim(v, ctx.cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        ctx.cp_stream.wait_stream(torch.cuda.current_stream())

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]

        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
                if _flash_attn_2_6_0_plus:
                    fa_backward_kwargs["softcap"] = 0.0

        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i][0],
                        kv_seq_range_per_step[i][1],
                    )
                    max_seqlen_kv = seq_end_idx - seq_start_idx
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    out_ = out_per_step[i]
                    dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
                    if ctx.use_fused_attention:
                        aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
                            ctx.max_seqlen_q,
                            max_seqlen_kv,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            out_,
                            dout_,
                            ctx.qkv_dtype,
                            TE_DType[dout.dtype],
                            aux_ctx_tensors,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=ctx.attn_mask_type,
                            attn_bias_type=ctx.attn_bias_type,
                            window_size=window_size_per_step[i],
                            deterministic=ctx.deterministic,
                        )
                    else:
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q,
                                cu_seqlens_kv_per_step[i],
                                ctx.max_seqlen_q,
                                max_seqlen_kv,
                            ]
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[i]
                        # Mirror the forward-pass selection: FA3 takes a `window_size`
                        # tuple and must not receive the FA2 >= 2.7 split kwargs.
                        if _use_flash_attn_3 or (
                            _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus
                        ):
                            fa_backward_kwargs["window_size"] = window_size_per_step[i]
                        elif _flash_attn_2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
                            fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
                        flash_attn_bwd(
                            dout_,
                            q_,
                            k_,
                            v_,
                            out_,
                            softmax_lse_per_step[i],
                            dq_per_step[i],
                            dk_per_step[i],
                            dv_per_step[i],
                            *fa_backward_args_thd,
                            causal="causal" in ctx.attn_mask_type,
                            **fa_backward_kwargs,
                        )

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if ctx.qkv_format == "bshd":
                        dq[:, i - 1].copy_(dq_per_step[i - 1])
                    elif ctx.qkv_format == "sbhd":
                        dq[i - 1].copy_(dq_per_step[i - 1])
                    # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn]
                    dk_per_step[i - 1], dv_per_step[i - 1] = [
                        x.movedim(seq_dim, 0).contiguous()
                        for x in [dk_per_step[i - 1], dv_per_step[i - 1]]
                    ]
                    # wait until dkv update of last step is done
                    if i > 1:
                        flash_attn_streams[i - 1].wait_event(dkv_update_done)
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i - 1][0],
                        kv_seq_range_per_step[i - 1][1],
                    )
                    dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1])
                    dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1])
                    if i < len(local_seq_chunk_ids):
                        flash_attn_streams[i - 1].record_event(dkv_update_done)

        torch.cuda.current_stream().wait_stream(ctx.cp_stream)

        # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
        dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
        # Undo the chunk reordering applied to the gathered K/V.
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device)
        dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
        dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        dk = dk.view(-1, *dk.shape[-3:])
        dv = dv.view(-1, *dv.shape[-3:])
        # [cp*s, b, np, hn] -> [s, b, np, hn], summing contributions from all ranks
        dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
        dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)

        # [b, 2, s//2, np, hn] -> [b, s, np, hn] or [2, s//2, b, np, hn] -> [s, b, np, hn]
        dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
        # [s, b, np, hn] -> [b, s, np, hn] or [s, b, np, hn]
        dk = dk.movedim(0, seq_dim).contiguous()
        dv = dv.movedim(0, seq_dim).contiguous()
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")

        return (
            None,
            dq,
            dk,
            dv,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
    """
    Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO.
    Refer the paper `DeepSpeed Ulysses <https://arxiv.org/abs/2309.14509>`_.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        fp8,
        fp8_meta,
        cp_group,
        cp_stream,
4302
        quantizers,
4303
    ):
4304
        # pylint: disable=missing-function-docstring
4305
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
4306
4307
4308
4309
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
4310
        qkv_dtype = q.dtype
4311
4312
4313
4314
4315
4316
4317
4318
4319
4320
4321
4322

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            window_size == (-1, 0)
            or window_size == (-1, -1)
            or use_fused_attention
            or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
4323

4324
        flash_attn_fwd = None
4325
4326
4327
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
4328
4329
4330
4331
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
4332
4333
                fa_forward_kwargs["window_size"] = window_size
            else:
4334
4335
4336
4337
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
4338
4339
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
4340
                if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus):
4341
                    fa_forward_kwargs["window_size"] = window_size
4342
4343
4344
                elif _flash_attn_2_7_0_plus:
                    fa_forward_kwargs["window_size_left"] = window_size[0]
                    fa_forward_kwargs["window_size_right"] = window_size[1]
4345
4346
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
4347
                if _flash_attn_2_5_7_plus and qkv_format == "thd":
4348
                    fa_forward_kwargs["block_table"] = None
4349
4350
                if _flash_attn_2_6_0_plus:
                    fa_forward_kwargs["softcap"] = 0.0
4351
4352
4353
4354
4355
4356
4357
4358
4359
4360
4361
4362
4363
4364

        assert (
            q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
        ), "The number of attention heads needs to be divisible by CP size!"

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        batch_dim = qkv_format.index("b")
        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

4365
        fused_attn_backend = None
4366
4367
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
        is_input_fp8 = False
4368
4369
4370
4371
4372
4373
4374
        is_output_fp8 = False

        QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
            get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
        )
        if fp8:
            if use_fused_attention:
4375
                fused_attn_backend = FusedAttnBackend["FP8"]
4376
4377
4378
4379
                assert isinstance(k, q.__class__) and isinstance(
                    v, q.__class__
                ), "q, k, and v must have the same type."
                is_input_fp8 = isinstance(q, Float8Tensor)
4380
                is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
4381
                if is_input_fp8:
4382
                    QKV_quantizer = q._quantizer
4383
4384
4385
4386
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                    q_f16, k_f16, v_f16 = q, k, v
4387
                    q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]]
4388
                fp8_meta_kwargs = {}
4389
4390
                fp8_meta_kwargs["s_quantizer"] = S_quantizer
                fp8_meta_kwargs["o_quantizer"] = O_quantizer  # partial result quantizer
4391
4392
4393
4394
4395
4396
4397
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

4398
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device)
4399
4400
4401
4402
        q, k, v = flash_attn_a2a_communicate(
            [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
        )

4403
        if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
4404
            q_f16, k_f16, v_f16 = q, k, v
4405
            q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]]
4406
4407
4408

        batch_size = q.shape[batch_dim]
        if use_fused_attention:
4409
4410
4411
4412
4413
4414
4415
4416
4417
4418
4419
            q_part, k_part, v_part = q, k, v
            if fp8:
                q_part = QKV_quantizer.create_tensor_from_data(
                    q, fake_dtype=qkv_dtype, internal=True
                )
                k_part = QKV_quantizer.create_tensor_from_data(
                    k, fake_dtype=qkv_dtype, internal=True
                )
                v_part = QKV_quantizer.create_tensor_from_data(
                    v, fake_dtype=qkv_dtype, internal=True
                )
4420
4421
4422
4423
4424
4425
            out, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
4426
4427
4428
4429
                q_part,
                k_part,
                v_part,
                qkv_dtype,
4430
4431
4432
4433
4434
4435
4436
4437
4438
4439
4440
4441
                fused_attn_backend,
                attn_scale=softmax_scale,
                dropout=dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                attn_bias=attn_bias,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                window_size=window_size,
                **fp8_meta_kwargs,
            )
4442
4443
            if fp8:
                out = out._data
4444
        else:
4445
4446
4447
4448
4449
4450
4451
4452
            fa_forward_args_thd = []
            if qkv_format == "thd":
                fa_forward_args_thd = [
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                ]
4453
            fa_outputs = flash_attn_fwd(
4454
4455
4456
                q,
                k,
                v,
4457
                *fa_forward_args_thd,
4458
                causal=causal,
4459
                **fa_forward_kwargs,
4460
            )
4461
4462
4463
4464
4465
4466
            if not _flash_attn_2_7_0_plus:
                out, softmax_lse = fa_outputs[4], fa_outputs[5]
                rng_state = fa_outputs[7] if not _use_flash_attn_3 else None
            else:
                out, softmax_lse = fa_outputs[0], fa_outputs[1]
                rng_state = fa_outputs[3] if not _use_flash_attn_3 else None
4467
4468
            aux_ctx_tensors = [softmax_lse, rng_state]

4469
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device)
4470
4471
4472
4473
4474
4475
4476
4477
4478
4479
4480
4481
4482
        out = flash_attn_a2a_communicate(
            out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
        )

        if use_fused_attention:
            if qkv_format == "bshd":
                # [b*s, np, hn] -> [b, s, np, hn]
                out = out.view(batch_size, -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                # [s*b, np, hn] -> [s, b, np, hn]
                out = out.view(-1, batch_size, *out.shape[-2:])

        if fp8:
4483
            if is_output_fp8:
4484
4485
                out_fp8 = O_quantizer.create_tensor_from_data(
                    out, fake_dtype=qkv_dtype, internal=False
4486
4487
                )
                out_ret = out_fp8
4488
                out = out_fp8._data
4489
            else:
4490
                out_fp8 = O_quantizer.create_tensor_from_data(
4491
                    out, fake_dtype=qkv_dtype, internal=True
4492
                )
4493
                out_f16 = out_fp8.dequantize(dtype=qkv_dtype)
4494
4495
4496
4497
                out_ret = out_f16
        else:
            out_ret = out

4498
        if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
4499
            q_save, k_save, v_save, out_save = q, k, v, out
4500
4501
4502
4503
4504
4505
4506
4507
4508
        else:
            if is_input_fp8:
                q_save, k_save, v_save = q, k, v
            else:
                q_save, k_save, v_save = q_f16, k_f16, v_f16
            if is_output_fp8:
                out_save = out
            else:
                out_save = out_f16
4509

4510
        tensors_to_save, tensor_objects = prepare_for_saving(
4511
4512
4513
4514
4515
4516
4517
4518
4519
4520
            q_save,
            k_save,
            v_save,
            out_save,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *aux_ctx_tensors,
        )
4521
4522
4523
4524
4525
4526
4527
4528
4529
4530
4531
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects

        ctx.qkv_dtype = qkv_dtype
        ctx.QKV_quantizer = QKV_quantizer
        ctx.O_quantizer = O_quantizer
        ctx.S_quantizer = S_quantizer
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer

4532
4533
4534
4535
4536
4537
4538
4539
4540
4541
4542
4543
4544
4545
4546
        ctx.batch_size = batch_size
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_mask_type = attn_mask_type
        ctx.attn_bias_type = attn_bias_type
        ctx.deterministic = deterministic
        ctx.window_size = window_size
        ctx.use_fused_attention = use_fused_attention
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
4547
4548
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
4549
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
4550
4551
4552
4553
        return out_ret

    @staticmethod
    def backward(ctx, dout):
        """
        Backward pass of all-to-all (A2A) context-parallel attention.

        ``out`` and ``dout`` are exchanged with an all-to-all before the local
        attention backward (FusedAttention or FlashAttention) runs, and the
        resulting ``dq``/``dk``/``dv`` are exchanged back afterwards. Only the
        ``q``, ``k`` and ``v`` inputs receive gradients.
        """
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
        cp_size = get_distributed_world_size(ctx.cp_group)

        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *aux_ctx_tensors,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

        qkv_format = ctx.qkv_format
        qkv_layout = "_".join([qkv_format] * 3)
        causal = "causal" in ctx.attn_mask_type
        seq_dim = qkv_format.index("s")

        # Nominal dtype of the incoming gradient, captured before any FP8
        # unwrapping below replaces `dout` with its raw `_data`.
        dout_dtype = dout.dtype

        def _from_fp8_data(quantizer, data, fake_dtype):
            # Re-attach quantizer metadata to raw FP8 data as an internal tensor.
            return quantizer.create_tensor_from_data(data, fake_dtype=fake_dtype, internal=True)

        fused_attn_backend = None
        fused_attn_dqkv_dtype = None
        if ctx.fp8 and not ctx.use_fused_attention:
            assert False, "FP8 is only supported with Fused Attention!"
        elif ctx.fp8:
            fused_attn_backend = FusedAttnBackend["FP8"]
            if not ctx.is_output_fp8:
                dout = ctx.dO_quantizer(dout)
            else:
                assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                ctx.dO_quantizer = dout._quantizer
            fused_attn_dqkv_dtype = dout._fp8_dtype
            dout = dout._data
            fp8_meta_kwargs = {
                "s_quantizer": ctx.S_quantizer,
                "dp_quantizer": ctx.dP_quantizer,
                "dqkv_quantizer": ctx.dQKV_quantizer,
            }
        else:
            if ctx.fp8_meta is not None:
                if ctx.is_output_fp8:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.dO_quantizer = dout._quantizer
                    dout = dout._data
                if ctx.is_input_fp8:
                    # Saved q/k/v are raw FP8 data here; the high-precision
                    # backward needs them dequantized.
                    qkv_fp8 = [
                        _from_fp8_data(ctx.QKV_quantizer, x, ctx.qkv_dtype) for x in (q, k, v)
                    ]
                    q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in qkv_fp8]
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_dqkv_dtype = TE_DType[dout_dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if not ctx.use_fused_attention:
            out = out.view(ctx.batch_size, -1, *out.shape[-2:])
        dout = dout.view(*out.shape)

        # All-to-all of out/dout across the CP group before the local backward.
        before_attn_ids = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device)
        out, dout = flash_attn_a2a_communicate(
            [out, dout], before_attn_ids, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
        )

        if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
            # FP8 output but high-precision backward: dequantize out and dout.
            out = _from_fp8_data(ctx.O_quantizer, out, ctx.qkv_dtype).dequantize(
                dtype=ctx.qkv_dtype
            )
            dout = _from_fp8_data(ctx.dO_quantizer, dout, dout_dtype).dequantize(
                dtype=dout_dtype
            )

        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            use_varlen = ctx.qkv_format == "thd"
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
                flash_attn_bwd = _flash_attn_varlen_bwd_v3 if use_varlen else _flash_attn_bwd_v3
                fa_backward_kwargs["window_size"] = ctx.window_size
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                flash_attn_bwd = _flash_attn_varlen_bwd if use_varlen else _flash_attn_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                # flash-attn >= 2.7.0 takes the window as separate left/right kwargs.
                if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus:
                    fa_backward_kwargs["window_size"] = ctx.window_size
                elif _flash_attn_2_7_0_plus:
                    fa_backward_kwargs["window_size_left"] = ctx.window_size[0]
                    fa_backward_kwargs["window_size_right"] = ctx.window_size[1]
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
                if _flash_attn_2_6_0_plus:
                    fa_backward_kwargs["softcap"] = 0.0

        if ctx.use_fused_attention:
            q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout
            if ctx.fp8:
                q_part, k_part, v_part = [
                    _from_fp8_data(ctx.QKV_quantizer, x, ctx.qkv_dtype) for x in (q, k, v)
                ]
                out_part = _from_fp8_data(ctx.O_quantizer, out, ctx.qkv_dtype)
                dout_part = _from_fp8_data(ctx.dO_quantizer, dout, dout_dtype)
            dq, dk, dv, _ = fused_attn_bwd(
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_part,
                k_part,
                v_part,
                out_part,
                dout_part,
                ctx.qkv_dtype,
                fused_attn_dqkv_dtype,
                aux_ctx_tensors,
                fused_attn_backend,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                attn_scale=ctx.softmax_scale,
                dropout=ctx.dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=ctx.attn_mask_type,
                attn_bias_type=ctx.attn_bias_type,
                window_size=ctx.window_size,
                deterministic=ctx.deterministic,
                **fp8_meta_kwargs,
            )
            if ctx.fp8:
                # Keep raw FP8 data; quantizer metadata is re-attached after the all-to-all.
                dq, dk, dv = [x._data for x in (dq, dk, dv)]
        else:
            softmax_lse, rng_state = aux_ctx_tensors
            dq, dk, dv = [torch.empty_like(x) for x in (q, k, v)]
            fa_backward_args_thd = (
                [cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv]
                if ctx.qkv_format == "thd"
                else []
            )
            if not _use_flash_attn_3:
                fa_backward_kwargs["rng_state"] = rng_state
            flash_attn_bwd(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                *fa_backward_args_thd,
                causal=causal,
                **fa_backward_kwargs,
            )

        # Reverse all-to-all: redistribute the gradients back across the CP group.
        after_attn_ids = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device)
        dq, dk, dv = flash_attn_a2a_communicate(
            [dq, dk, dv], after_attn_ids, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
        )

        if ctx.qkv_format == "bshd":
            dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in (dq, dk, dv)]
        elif ctx.qkv_format == "sbhd":
            dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in (dq, dk, dv)]

        if ctx.fp8:
            dq, dk, dv = [
                ctx.dQKV_quantizer.create_tensor_from_data(
                    x, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
                )
                for x in (dq, dk, dv)
            ]
            if not ctx.is_input_fp8:
                dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in (dq, dk, dv)]
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")

        # Gradients only for q, k, v; every other forward input gets None.
        # NOTE(review): 25 entries are returned while the visible a2a call site
        # passes 24 inputs; PyTorch drops surplus trailing Nones — confirm
        # against the forward() signature.
        return (None, dq, dk, dv) + (None,) * 21


4796
def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_kv,
    max_seqlen_q,
    max_seqlen_kv,
    cu_seqlens_q_padded,
    cu_seqlens_kv_padded,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    cp_comm_type,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
    window_size=None,
    fp8=False,
    fp8_meta=None,
    quantizers=None,
    pad_between_seqs=False,
) -> torch.Tensor:
    """
    Attention implementation with context parallelism.

    Validates the context-parallel (CP) configuration and dispatches to the
    autograd function implementing the requested communication scheme:

    * ``"p2p"`` / ``"a2a+p2p"`` -> ``AttnFuncWithCPAndKVP2P``
    * ``"all_gather"``          -> ``AttnFuncWithCPAndKVAllGather``
    * ``"a2a"``                 -> ``AttnFuncWithCPAndQKVOA2A``

    Parameters
    ----------
    is_training, q, k, v, max_seqlen_q, max_seqlen_kv, dropout_p, softmax_scale,
    attn_mask_type, attn_bias_type, attn_bias, deterministic:
        Forwarded unchanged to the selected implementation.
    cu_seqlens_q, cu_seqlens_kv:
        Cumulative sequence lengths. ``cu_seqlens_kv`` is not forwarded for
        ``"all_gather"``.
    cu_seqlens_q_padded, cu_seqlens_kv_padded:
        Padded cumulative sequence lengths; both must be non-None for the
        ``thd`` format. ``cu_seqlens_kv_padded`` is not forwarded for
        ``"all_gather"``.
    cp_group:
        A single process group, or for ``"a2a+p2p"`` a list of exactly two
        process groups. If the first level has world size 1 the scheme falls
        back to ``"p2p"`` over the second group; if the second level has world
        size 1 it falls back to ``"a2a"`` over the first group; otherwise the
        list is passed to the P2P implementation.
    cp_global_ranks:
        Forwarded to the P2P implementation only.
    cp_stream:
        Forwarded to every implementation.
    cp_comm_type: {'p2p', 'a2a+p2p', 'all_gather', 'a2a'}
        CP communication scheme.
    qkv_format: {'bshd', 'sbhd', 'thd'}, default = 'bshd'
        ``sbhd`` requires ``use_fused_attention=True``.
    use_fused_attention: bool, default = False
        Select FusedAttention instead of FlashAttention; attention bias also
        requires FusedAttention (and a mask type without "padding").
    window_size: tuple, default = None
        ``None``, ``(-1, 0)`` and ``(-1, -1)`` mean no sliding window; any other
        value is only accepted with ``"a2a"`` or ``"all_gather"``. Not forwarded
        to the P2P implementation.
    fp8, fp8_meta, quantizers:
        FP8 settings, forwarded to the P2P and A2A implementations only
        (not passed to ``"all_gather"``).
    pad_between_seqs: bool, default = False
        Forwarded to the P2P implementation only.

    Returns
    -------
    torch.Tensor
        The output returned by the selected implementation's ``apply``.

    Raises
    ------
    AssertionError
        If the process group(s), QKV format, attention bias, padded
        cu_seqlens or sliding-window settings are incompatible with the
        requested configuration.
    ValueError
        If ``cp_comm_type`` is not supported.
    """
    if cp_comm_type == "a2a+p2p":
        assert isinstance(
            cp_group, list
        ), "Hierarchical CP implementation needs multi-level CP groups!"
        assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
        # Collapse the two-level scheme when one of its levels is trivial.
        if get_distributed_world_size(cp_group[0]) == 1:
            cp_group = cp_group[1]
            cp_comm_type = "p2p"
        elif get_distributed_world_size(cp_group[1]) == 1:
            cp_group = cp_group[0]
            cp_comm_type = "a2a"
    else:
        assert isinstance(
            cp_group, dist_group_type
        ), f"Unsupported process group for CP communication type {cp_comm_type}!"

    assert qkv_format in [
        "bshd",
        "sbhd",
        "thd",
    ], f"QKV format of {qkv_format} is not supported with context parallelism!"
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "FlashAttention does not support sbhd format!"
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        """Attention bias is only supported with FusedAttention and "causal" """
        """or "no_mask" mask types!"""
    )
    assert qkv_format != "thd" or (
        cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
    ), "cu_seqlens_padded cannot be None with context parallelism + THD format!"

    # (-1, 0) and (-1, -1) denote "no sliding window".
    sliding_window_attn = (
        window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
    )
    assert not sliding_window_attn or cp_comm_type in [
        "a2a",
        "all_gather",
    ], "The context parallel running configs cannot support sliding window attention!"

    args = [
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    ]

    if cp_comm_type in ["p2p", "a2a+p2p"]:
        args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, pad_between_seqs]
        out = AttnFuncWithCPAndKVP2P.apply(*args)
    elif cp_comm_type == "all_gather":
        # The all-gather implementation takes neither cu_seqlens_kv (index 5)
        # nor cu_seqlens_kv_padded (index 9, i.e. index 8 after the first pop).
        args.pop(5)
        args.pop(8)
        args += [window_size, cp_group, cp_stream]
        out = AttnFuncWithCPAndKVAllGather.apply(*args)
    elif cp_comm_type == "a2a":
        args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers]
        out = AttnFuncWithCPAndQKVOA2A.apply(*args)
    else:
        raise ValueError(f"Unsupported communication type: {cp_comm_type}!")

    return out


4907
4908
4909
4910
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Rotary Position Embedding (RoPE), as introduced in https://arxiv.org/abs/2104.09864.

    Produces per-position rotation angles; the inverse frequencies are kept in the
    `inv_freq` buffer.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension.
        rotary_percent: float, default = 1.0
            Fraction of `dim` that is rotated; when below 1.0 the effective
            dimension becomes `int(dim * rotary_percent)`.
        seq_len_interpolation_factor: int, default = None
            Position-interpolation factor (https://arxiv.org/abs/2306.15595). Only
            takes effect when `pretrained_max_position_embeddings` is also given.
        pretrained_max_position_embeddings: int, default = None
            pre-trained max_position_embeddings before position interpolation.
        rotary_base: float, default = 10000.0
            Base of the inverse-frequency progression
            `1 / rotary_base ** (arange(0, dim, 2) / dim)`.
        """
        super().__init__()
        effective_dim = int(dim * rotary_percent) if rotary_percent < 1.0 else dim
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
        even_indices = torch.arange(
            0, effective_dim, 2, dtype=torch.float32, device=torch.cuda.current_device()
        )
        inv_freq = 1.0 / (self.rotary_base ** (even_indices / effective_dim))
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset added to every position before the angles are computed

        Returns
        -------
        torch.Tensor
            Angles of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]`; the last
            dimension holds the same angles twice.
        """
        inv_freq = self.inv_freq
        positions = (
            torch.arange(max_seq_len, device=inv_freq.device, dtype=inv_freq.dtype) + offset
        )

        pretrained_len = self.pretrained_max_position_embeddings
        factor = self.seq_len_interpolation_factor
        if pretrained_len is not None and factor is not None:
            if max_seq_len > pretrained_len * factor:
                # dynamic linear scaling: request exceeds the interpolated range
                positions *= 1 / (max_seq_len / pretrained_len)
            else:
                # fixed linear scaling
                positions *= 1 / factor

        angles = torch.einsum("i , j -> i j", positions, inv_freq)
        # duplicate along the last dim, then [seq, dim] -> [seq, 1, 1, dim]
        emb = torch.cat((angles, angles), dim=-1)
        return emb.view(emb.shape[0], 1, 1, emb.shape[1])

4985
4986
4987
4988
4989
4990
4991
4992
4993
4994
4995
4996
4997
4998
4999
5000
5001

class FusedRoPEFunc(torch.autograd.Function):
    """
    Autograd function wrapping the fused RoPE kernels.

    The input may be laid out as `sbhd`, `bshd` or `thd`, and the RoPE tensor is
    expected to have shape (s, 1, 1, d). No `.contiguous()` copies are made, so
    arbitrary memory layouts are accepted at the cost of a possibly non-optimal
    memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        """Apply fused RoPE to `t`; `cu_seqlens`, `cp_size` and `cp_rank` are used for `thd`."""
        # The fused kernels are always handed float32 frequencies.
        freqs = freqs if freqs.dtype == torch.float32 else freqs.float()

        if tensor_format == "thd":
            result = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
        elif tensor_format == "bshd":
            # the kernel's layout flag is set and the batch/sequence axes are swapped around it
            result = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
        elif tensor_format == "sbhd":
            result = tex.fused_rope_forward(t, freqs, False)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")

        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        return result

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Return the gradient for `t`; every other forward input gets None."""
        freqs, cu_seqlens = ctx.saved_tensors
        layout = ctx.tensor_format

        if layout == "thd":
            grad_t = tex.fused_rope_thd_backward(
                grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
            )
        elif layout == "bshd":
            grad_t = tex.fused_rope_backward(grad_output.transpose(0, 1), freqs, True).transpose(
                0, 1
            )
        elif layout == "sbhd":
            grad_t = tex.fused_rope_backward(grad_output, freqs, False)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # one slot per forward input: t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank
        return (grad_t,) + (None,) * 5
5041
5042


5043
5044
5045
5046
5047
5048
5049
5050
5051
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


5052
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply a rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]` that the
        rotary positional embedding is applied to.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype
        'float', with `s2 >= s` and `d2 <= d`. Only the first `d2` channels of
        `t` are rotated; the rest pass through unchanged.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        `bshd` if `t` has shape `[bs, seq, ...]`, `sbhd` if it has shape
        `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use the fused RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1]
        and dtype torch.int32. Only valid (and then required) when
        `tensor_format` is 'thd'. Should be `cu_seqlens_padded` when cp_size > 1.
    cp_size: int, default = 1.
        Context parallel world size. Only valid when `tensor_format` is 'thd' and
        `fused` is True.
    cp_rank: int, default = 0.
        Context parallel rank. Only valid when `tensor_format` is 'thd' and
        `fused` is True.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )
    batch_first = tensor_format == "bshd"

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1 if batch_first else 0]
    # Only the first `cur_seq_len` positions of `freqs` are used.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if batch_first:
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]

    # take cos/sin in the frequencies' own precision, then cast to t's dtype
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    # cosine term plus sine term applied to the half-rotated input
    t_rot = (t_rot * cos_) + (_rotate_half(t_rot) * sin_)
    return torch.cat((t_rot, t_pass), dim=-1)


class _SplitAlongDim(torch.autograd.Function):
    """Split a tensor into chunks along one dimension.

    Forward returns the chunks produced by ``torch.split`` on ``mixed_x_layer``
    (each optionally squeezed along ``split_dim``). Float8 inputs are split on
    their raw ``_data`` buffer and rewrapped with the input's metadata.

    Backward concatenates the incoming gradients along ``split_dim``. When the
    gradients already tile a single storage back-to-back (same strides, expected
    per-chunk shape and storage offset), a view over that storage is returned
    instead of materializing a new tensor with ``torch.cat``. One gradient is
    returned per forward input (``mixed_x_layer``, ``split_dim``,
    ``split_size_or_sections``, ``squeeze``), so the function works whether or
    not ``squeeze`` is passed explicitly.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
        squeeze=False,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        # Backward needs to know whether outputs may have lost `split_dim` and what
        # the input rank was, so it can restore that dimension before concatenating.
        ctx.squeeze = squeeze
        if isinstance(mixed_x_layer, Float8TensorBase):
            ctx.input_ndim = mixed_x_layer._data.dim()
        else:
            ctx.input_ndim = mixed_x_layer.dim()

        if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
            mixed_x_layer, Float8Tensor
        ):
            return tuple(
                Float8TensorBase(
                    fp8_scale_inv=mixed_x_layer._scale_inv,
                    fp8_dtype=mixed_x_layer._fp8_dtype,
                    data=x.squeeze(split_dim) if squeeze else x,
                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
                    quantizer=mixed_x_layer._quantizer,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        if isinstance(mixed_x_layer, Float8Tensor):
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x.squeeze(split_dim) if squeeze else x,
                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        out_list = torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)
        if squeeze:
            out_list = [x.squeeze(split_dim) for x in out_list]
        return out_list

    @staticmethod
    def _restore_split_dim(grad, split_dim: int, input_ndim: int):
        """Unsqueeze ``grad`` at ``split_dim`` if forward's squeeze removed that dimension."""
        # `squeeze` only drops size-1 dims, so only some chunks may have lost a rank.
        if len(grad.shape) != input_ndim - 1:
            return grad
        dim = (split_dim + input_ndim) % input_ndim
        if isinstance(grad, Float8Tensor):
            data = grad._data.unsqueeze(dim)
            return Float8Tensor.make_like(grad, data=data, shape=data.shape)
        return grad.unsqueeze(dim)

    @staticmethod
    def _is_contiguous_split(grads, datas, split_sizes, split_dim: int) -> bool:
        """Return True if ``grads`` tile one storage back-to-back along ``split_dim``.

        Strides, shapes and storage offsets are read from ``grads``; storage identity
        is read from ``datas`` (the raw buffers, which differ from ``grads`` only for
        Float8 tensors).
        """
        strides = grads[0].stride()
        data_ptr = datas[0].untyped_storage().data_ptr()
        shape = list(grads[0].shape)
        inner_numel = np.prod(shape[split_dim + 1 :])
        for i, (tensor, data) in enumerate(zip(grads, datas)):
            expected_shape = list(shape)  # copy: do not mutate `shape`
            expected_shape[split_dim] = split_sizes[i]
            if (
                tensor.stride() != strides
                or list(tensor.shape) != expected_shape
                or data.untyped_storage().data_ptr() != data_ptr
                or tensor.storage_offset() != sum(split_sizes[:i]) * inner_numel
            ):
                return False
        return True

    @staticmethod
    def _view_as_concatenated(first_grad, first_data, split_sizes, split_dim: int):
        """Return a tensor viewing ``first_data``'s storage with the concatenated shape."""
        new_shape = list(first_grad.shape)
        new_shape[split_dim] = sum(split_sizes)
        ret = torch.Tensor().to(device=first_grad.device, dtype=first_data.dtype)
        ret.set_(
            first_data.untyped_storage(),
            first_data.storage_offset(),
            new_shape,
            first_grad.stride(),
        )
        return ret

    @staticmethod
    def backward(ctx, *grad_outputs):
        # pylint: disable=missing-function-docstring
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)

        if ctx.squeeze:
            grad_outputs = tuple(
                _SplitAlongDim._restore_split_dim(g, ctx.split_dim, ctx.input_ndim)
                for g in grad_outputs
            )

        dims = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + dims) % dims

        # Every return below yields one gradient per forward input; the trailing
        # Nones are for split_dim, split_size_or_sections and squeeze.
        if isinstance(grad_outputs[0], Float8Tensor):
            grad_datas = [x._data for x in grad_outputs]
            if _SplitAlongDim._is_contiguous_split(
                grad_outputs, grad_datas, split_sizes, split_dim
            ):
                ret = _SplitAlongDim._view_as_concatenated(
                    grad_outputs[0], grad_outputs[0]._data, split_sizes, split_dim
                )
                return (
                    Float8Tensor.make_like(grad_outputs[0], data=ret, shape=ret.shape),
                    None,
                    None,
                    None,
                )
            data = torch.cat(grad_datas, dim=split_dim)
            return (
                Float8Tensor.make_like(grad_outputs[0], data=data, shape=data.shape),
                None,
                None,
                None,
            )

        if _SplitAlongDim._is_contiguous_split(grad_outputs, grad_outputs, split_sizes, split_dim):
            ret = _SplitAlongDim._view_as_concatenated(
                grad_outputs[0], grad_outputs[0], split_sizes, split_dim
            )
            return ret, None, None, None

        return torch.cat(grad_outputs, dim=split_dim), None, None, None


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2

    Scores are computed with explicit batched matmuls in PyTorch, passed through
    ``FusedScaleMaskSoftmax`` and dropout, and multiplied with V. The output
    reshaping handles the ``sbhd`` and ``bshd`` formats; ``bshd`` inputs are
    transposed to ``sbhd`` internally.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_type: str = "self",
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_type = attention_type
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        # Enabled via NVTE_APPLY_QK_LAYER_SCALING=1 and only when layer_number is set.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        Returns the context layer as ``[sq, b, np * hn]`` for ``sbhd`` inputs and as
        ``[b, sq, np * hn]`` for ``bshd`` inputs.

        Raises
        ------
        ValueError
            If ``core_attention_bias_type`` is not one of ``no_bias``,
            ``pre_scale_bias``, ``post_scale_bias`` or ``alibi``.
        """
        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]
        batch_size, max_seqlen_q, max_seqlen_kv = (
            query_layer.shape[1],
            query_layer.shape[0],
            key_layer.shape[0],
        )

        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
            max_seqlen_q,
            max_seqlen_kv,
            attn_mask_type=attn_mask_type,
            attention_mask=attention_mask,
            window_size=window_size,
            attention_type=self.attention_type,
        )

        # QK layer scaling is only applied to FP16 inputs.
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            # Fewer K/V heads than Q heads: repeat each K/V head to match Q.
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                query_layer.shape[2] // key_layer.shape[2], dim=2
            )
            value_layer = value_layer.repeat_interleave(
                query_layer.shape[2] // value_layer.shape[2], dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # Shape of the baddbmm output buffer: [b * np, sq, sk]. The buffer is only
        # allocated in the branches that use it, so pre_scale_bias does not hold an
        # extra score-sized tensor.
        scores_shape = (output_size[0] * output_size[1], output_size[2], output_size[3])

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                # beta=0.0: the buffer's (uninitialized) contents are ignored.
                torch.empty(
                    scores_shape, dtype=query_layer.dtype, device=torch.cuda.current_device()
                ),
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            ).view(*output_size)

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            # Bias is added before scaling in this mode.
            matmul_result = matmul_result.view(*output_size) + core_attention_bias
            matmul_result *= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = get_alibi(
                    output_size[1],
                    output_size[2],
                    output_size[3],
                    actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
                    actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
                    alibi_slopes=alibi_slopes,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
            matmul_result = torch.baddbmm(
                torch.empty(
                    scores_shape, dtype=query_layer.dtype, device=torch.cuda.current_device()
                ),
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            # Bias is added after scaling in this mode.
            matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
                dtype=query_layer.dtype
            )

        else:
            raise ValueError(
                "UnfusedDotProductAttention does not support "
                f"core_attention_bias_type = {core_attention_bias_type}!"
            )

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            matmul_result, attention_mask, attn_mask_type, softmax_scale
        )

        # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
        # the columns (pad tokens from k) are already zeroed out during softmax
        if "padding" in attn_mask_type:
            attention_probs = attention_probs.masked_fill(attention_mask, 0)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(max_seqlen_q, batch_size, -1)

        if qkv_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, max_seqlen_q, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """Turn an interleaved (s, b, ...) QKV buffer into separate contiguous
    q, k, v tensors laid out as (b, s, ...).

    Forward delegates the data movement to ``tex.fa_prepare_fwd``; backward packs
    the three incoming gradients with ``tex.fa_prepare_bwd`` and splits the packed
    result back into dq, dk and dv.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # q, k and v arrive as non-contiguous views. Only the query view is passed to
        # the extension, which uses it to reach the whole QKV memory region;
        # key_layer / value_layer are still declared inputs so that backward's
        # dk / dv are routed to them.
        packed_qkv = tex.fa_prepare_fwd(query_layer)
        q_chunk, k_chunk, v_chunk = split_tensor_along_dim(packed_qkv, 0, 3)
        return tuple(torch.squeeze(chunk, 0) for chunk in (q_chunk, k_chunk, v_chunk))

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        packed_grad = tex.fa_prepare_bwd(dq, dk, dv)
        grad_q, grad_k, grad_v = split_tensor_along_dim(packed_grad, -1, 3)
        return grad_q, grad_k, grad_v


def get_qkv_layout(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qkv_format: str = "sbhd",
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of tokens in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    -------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
        Query tensor. It may be different from input `q` as we try to fit tensors to
        a supported layout.
    k: torch.Tensor
        Key tensor. It may be different from input `k` as we try to fit tensors to
        a supported layout.
    v: torch.Tensor
        Value tensor. It may be different from input `v` as we try to fit tensors to
        a supported layout.

    Raises
    ------
    AssertionError
        If any of `q`, `k`, `v` does not have stride 1 in its last dimension.
    RuntimeError
        If no supported layout matches, even after making `q`, `k`, `v` contiguous.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def run_iteratively(q, k, v):
        # check data pointers
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        # check tensor shapes (k and v may differ only in their last dimension)
        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        check_shapes_kv = shape[:-1] == v.shape[:-1]

        # check tensor strides; k/v strides are compared in units of their own last dim
        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
            sv / v.shape[-1] for sv in v.stride()[:-1]
        )

        # check tensor offsets for h3d and 3hd layouts
        prod_h_d = q.shape[-1] * q.shape[-2]
        check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v]))
        check_h3d_offsets = all(
            x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v])
        )

        # check tensor offsets for hd_h2d and hd_2hd layouts
        prod_all_dims = [np.prod(x.shape) for x in [q, k]]
        offset = prod_all_dims[0] if check_ptrs_qkv else 0
        prod_h_d = k.shape[-1] * k.shape[-2]
        check_2hd_offsets = all(
            x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v])
        )
        check_h2d_offsets = all(
            x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v])
        )

        # check tensor offsets for hd_hd_hd layouts
        check_hd_offsets_qkv = (
            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v]))
            if check_ptrs_qkv
            else all(x.storage_offset() == 0 for x in [q, k, v])
        )
        check_hd_offsets_qk = (
            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k]))
            if not check_ptrs_qkv and check_ptrs_qk
            else all(x.storage_offset() == 0 for x in [q, k])
        )
        check_hd_offsets_kv = (
            all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v]))
            if not check_ptrs_qkv and check_ptrs_kv
            else all(x.storage_offset() == 0 for x in [k, v])
        )

        if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets:
            # sb3hd, bs3hd, t3hd
            # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv
            qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
        elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets:
            # sbh3d, bsh3d, th3d
            # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv
            qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets:
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv
            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
            # have the same data pointer, i.e. check_ptrs_qkv=True
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets:
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv
            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
            # have the same data pointer, i.e. check_ptrs_qkv=True
            qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
        elif (
            check_strides_kv
            and check_shapes_kv
            and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk)
        ):
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            # three chunks of memory, q, k and v, which may be disjoint or consecutive, and
            # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
            # check_ptrs_qk=True or check_ptrs_kv=True
            qkv_layout = "_".join([qkv_format] * 3)
        else:
            qkv_layout = "not_supported"

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        raise RuntimeError("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v


def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Check if sliding window size is compliant with attention mask type.
    If not, set it to the appropriate size.

         attn_mask_type                              |   window_size
    -------------------------------------------------------------------------
    no_mask, padding, arbitrary                      | (-1, -1) or (>=0, >=0)
    causal, padding_causal                           | (-1,  0) or (>=0, 0)
    causal_bottom_right, padding_causal_bottom_right | (-1,  0) or (>=0, 0)

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type. Any type whose name contains "causal" follows the
        causal rules above.
    window_size: Tuple[int, int], default = None
        Requested (left, right) window. `None` selects the default for the mask
        type. Any two-element sequence (e.g. a list) is accepted and converted
        to a tuple.

    Returns
    -------
    window_size: Tuple[int, int]
        The checked window size. Correctable mismatches, i.e. (-1, -1) or (>=0, !=0)
        for causal masks and (-1, 0) for non-causal masks, are adjusted with a warning.

    Raises
    ------
    AssertionError
        If `window_size` cannot be reconciled with `attn_mask_type`, or if
        `attn_mask_type` is not recognized.
    """
    if window_size is not None:
        # Normalize so the comparisons against tuple literals below also work for lists.
        window_size = tuple(window_size)
    orig_window_size = window_size
    if "causal" in attn_mask_type:
        if orig_window_size is None:
            window_size = (-1, 0)
        elif orig_window_size == (-1, -1) or (
            orig_window_size[0] >= 0 and orig_window_size[1] != 0
        ):
            window_size = (orig_window_size[0], 0)
            warnings.warn(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
        elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
            # Raise explicitly instead of `assert False` so the check survives `python -O`.
            raise AssertionError(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
    elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
        if orig_window_size is None:
            window_size = (-1, -1)
        elif orig_window_size == (-1, 0):
            window_size = (-1, -1)
            warnings.warn(
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
        elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
            raise AssertionError(
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
    else:
        raise AssertionError("Invalid attn_mask_type: " + attn_mask_type)
    return window_size


class FlashAttention(torch.nn.Module):
5701
    """Dot product attention, using HazyResearch flash-attn package:
5702
    https://github.com/Dao-AILab/flash-attention
5703
5704
5705
5706
    """

    def __init__(
        self,
5707
        softmax_scale: float,
5708
5709
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
5710
5711
        attention_type: str = "self",
        layer_number: Optional[int] = None,
5712
        deterministic: bool = False,
5713
5714
5715
    ) -> None:
        super().__init__()

5716
5717
5718
5719
5720
5721
5722
        if _flash_attn_is_installed:
            assert (
                _flash_attn_version >= _flash_attn_version_required
            ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
            assert (
                _flash_attn_version <= _flash_attn_max_version
            ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
5723

5724
        self.softmax_scale = softmax_scale
5725
5726
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
5727
5728
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
5729
        self.deterministic = deterministic
5730
5731
5732
5733
        self.logger = logging.getLogger("FlashAttention")
        self.logger.setLevel(_log_level)
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)
5734
5735
5736
5737
5738
5739

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
5740
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
5741
5742
5743
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
5744
5745
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
5746
        attn_mask_type: str = "causal",
5747
        window_size: Optional[Tuple[int, int]] = None,
5748
        alibi_slopes: Optional[torch.Tensor] = None,
5749
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
5750
        cp_global_ranks: List[int] = None,
5751
        cp_stream: torch.cuda.Stream = None,
5752
        cp_comm_type: str = "p2p",
5753
5754
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
5755
        quantizers=None,
5756
5757
5758
    ) -> torch.Tensor:
        """flash-attn fprop"""

5759
5760
5761
5762
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors."
5763
5764
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
5765
        ), "FlashAttention currently only supports CUDA tensors."
5766
5767
        assert (
            qkv_layout in QKVLayouts
5768
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"
5769

5770
5771
5772
5773
5774
5775
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
5776
        context_parallel = cp_size > 1
5777

5778
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
5779

5780
5781
5782
5783
5784
5785
5786
5787
5788
5789
5790
5791
5792
        if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
            if qkv_format == "sbhd":
                # For now just 128, will make it more general in the future
                if (
                    query_layer.shape[-1] == 128
                    and query_layer.shape[0] * query_layer.shape[1] >= 512
                    and qkv_layout == "sbh3d"
                ):
                    query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                        query_layer, key_layer, value_layer
                    )
                else:
                    query_layer, key_layer, value_layer = [
5793
                        x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
5794
                    ]
5795
            if context_parallel:
5796
                query_layer, key_layer, value_layer = [
5797
5798
5799
5800
5801
                    x.contiguous() for x in (query_layer, key_layer, value_layer)
                ]
        else:
            if qkv_format == "sbhd":
                query_layer._data, key_layer._data, value_layer._data = [
5802
                    x.transpose(0, 1)
5803
5804
                    for x in (query_layer._data, key_layer._data, value_layer._data)
                ]
5805
                query_layer, key_layer, value_layer = [
5806
                    Float8Tensor.make_like(x, data=x._data, shape=x._data.shape)
5807
5808
                    for x in (query_layer, key_layer, value_layer)
                ]
5809
            if context_parallel:
5810
5811
                query_layer._data, key_layer._data, value_layer._data = [
                    x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
5812
                ]
5813

5814
        batch_size = query_layer.shape[0]
5815

5816
        if qkv_format in ["sbhd", "bshd"]:
5817
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
5818
5819
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size
5820
5821
5822

            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"
5823
5824
                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
5825
                    x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
5826
5827
5828
5829
5830
5831
5832
                    for x in [query_layer, key_layer, value_layer]
                ]

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
5833
                    if cu_seqlens_q is None:
5834
5835
5836
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
5837
5838
5839
5840
5841
5842
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
5843
5844
                    )
                else:
5845
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
5846
5847
5848
5849
5850
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
5851
5852
5853
5854
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = PackTensors.apply(indices_q, query_layer)
5855
                    key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer)
5856
            else:
5857
5858
5859
5860
5861
5862
5863
5864
5865
5866
5867
5868
5869
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
5870
5871
5872
5873
        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
5874
5875
5876
5877
5878
5879
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()
5880

5881
5882
5883
        if context_parallel and all(
            not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
        ):
5884
5885
5886
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
5887
            with self.attention_dropout_ctx():
5888
                output = attn_forward_func_with_cp(
5889
5890
5891
5892
5893
5894
5895
5896
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
5897
5898
                    cu_seqlens_q if qkv_format == "thd" else None,
                    cu_seqlens_kv if qkv_format == "thd" else None,
5899
                    self.attention_dropout if self.training else 0.0,
5900
5901
5902
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
5903
                    cp_comm_type,
5904
                    softmax_scale=self.softmax_scale,
5905
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
5906
                    attn_mask_type=attn_mask_type,
5907
                    deterministic=self.deterministic,
5908
                    window_size=window_size,
5909
                    quantizers=quantizers,
5910
                    pad_between_seqs=False,
5911
5912
                )
        else:
5913
5914

            from .cpu_offload import CPUOffloadEnabled
5915

5916
5917
5918
5919
5920
5921
            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

5922
            with self.attention_dropout_ctx():
5923
                fa_optional_forward_kwargs = {}
5924
5925
                if _flash_attn_2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = window_size
5926
5927
5928
5929
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
5930
5931
5932
5933
                fa_optional_forward_args_thd = []
                if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
                    func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
                else:
5934
5935
                    if _flash_attn_2_5_7_plus:
                        fa_optional_forward_kwargs["block_table"] = None
5936
5937
5938
5939
5940
5941
5942
5943
5944
5945
                    func = (
                        flash_attn_varlen_func
                        if not _use_flash_attn_3
                        else flash_attn_varlen_func_v3
                    )
                    fa_optional_forward_args_thd.append(cu_seqlens_q)
                    fa_optional_forward_args_thd.append(cu_seqlens_kv)
                    fa_optional_forward_args_thd.append(max_seqlen_q)
                    fa_optional_forward_args_thd.append(max_seqlen_kv)
                if _use_flash_attn_3:
5946
5947
5948
                    fa_3_optional_forward_kwargs = {}
                    fa_3_optional_forward_kwargs["window_size"] = window_size
                    fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
5949
                    if fp8:
5950
                        QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
5951
                        torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
5952
                        torch_orig_dtype = query_layer.dtype
5953
5954
5955
5956
5957
5958
5959
5960
5961
5962
5963

                        def convert_to_torch_float8(tensor, dtype):
                            out = torch.Tensor().to(device=tensor.device, dtype=dtype)
                            out.set_(
                                tensor._data.untyped_storage(),
                                tensor._data.storage_offset(),
                                tensor._data.shape,
                                tensor._data.stride(),
                            )
                            return out

5964
5965
5966
5967
5968
                        # "fp8_mha" decides outputs in fp8, while inputs are inferred from
                        # the real dtype
                        assert isinstance(key_layer, query_layer.__class__) and isinstance(
                            value_layer, query_layer.__class__
                        ), "q, k, and v must have the same type."
5969
                        if not isinstance(query_layer, Float8Tensor):
5970
                            query_layer, key_layer, value_layer = (
5971
                                QKV_quantizer(x) for x in [query_layer, key_layer, value_layer]
5972
                            )
5973
5974
                        fa_3_optional_forward_kwargs["descale_q"] = (
                            query_layer._scale_inv.unsqueeze(0)
5975
                        )
5976
5977
                        fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze(
                            0
5978
                        )
5979
5980
                        fa_3_optional_forward_kwargs["descale_v"] = (
                            value_layer._scale_inv.unsqueeze(0)
5981
                        )
5982
5983
5984
                        query_layer, key_layer, value_layer = (
                            convert_to_torch_float8(x, torch_dtype)
                            for x in [query_layer, key_layer, value_layer]
5985
                        )
5986
5987
5988
5989
5990
5991
5992
5993
5994
5995
5996
5997
5998
5999
6000
6001
6002
6003
6004
6005
6006
6007
6008
6009
6010
                    try:
                        output, _ = func(
                            query_layer,
                            key_layer,
                            value_layer,
                            *fa_optional_forward_args_thd,
                            softmax_scale=self.softmax_scale,
                            causal="causal" in attn_mask_type,
                            **fa_3_optional_forward_kwargs,
                        )
                    except TypeError as e:
                        if _flash_attn_3_0_0_beta:
                            e.args = (
                                e.args[0]
                                + ". Please update your flash-attn v3 (beta) installation as it "
                                + "may have added more supported arguments to its API. \n"
                                + _flash_attn_3_installation_steps,
                            ) + e.args[1:]
                        raise

                    if fp8:
                        output = output.to(dtype=torch_orig_dtype)
                    if fp8 and fp8_meta["recipe"].fp8_mha:
                        O_quantizer = quantizers["scaling_fwd"][META_O]
                        output = O_quantizer(output)
6011
                else:
6012
6013
6014
6015
6016
6017
6018
6019
6020
                    output = func(
                        query_layer,
                        key_layer,
                        value_layer,
                        *fa_optional_forward_args_thd,
                        self.attention_dropout if self.training else 0.0,
                        softmax_scale=self.softmax_scale,
                        causal="causal" in attn_mask_type,
                        **fa_optional_forward_kwargs,
6021
                    )
6022

6023
6024
6025
6026
6027
6028
6029
6030
6031
6032
6033
6034
6035
6036
6037
6038
6039
6040
6041
6042
6043
6044
6045
6046
6047
6048
6049
6050
6051
6052
6053
6054
6055
6056
6057
6058
6059
6060
6061
6062
6063
6064
6065
6066
6067
6068
6069
6070
6071
6072
6073
6074
6075
6076
        if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            if fp8 and fp8_meta["recipe"].fp8_mha:
                output_data = (
                    output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
                    .transpose(0, 1)
                    .contiguous()
                )
                output = Float8Tensor.make_like(
                    output,
                    data=output_data,
                    shape=output_data.shape,
                )
            else:
                output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
        elif qkv_format == "bshd":
            # (bs)hd -> bs(hd)
            output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
        elif qkv_format == "thd":
            # thd -> t(hd)
            output = output.reshape(output.shape[0], -1)

        return output.contiguous()


def _combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a particular dimension"""

    num_tensors = len(tensors)
    new_shape = list(tensors[0].shape)
    new_shape.insert(dim, num_tensors)
    if isinstance(tensors[0], Float8Tensor):
        new_stride = list(tensors[0]._data.stride())
        new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
        combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype)
        combined_tensor.set_(
            tensors[0]._data.untyped_storage(),
            tensors[0]._data.storage_offset(),
            new_shape,
            new_stride,
        )
        combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor, shape=new_shape)
    else:
        new_stride = list(tensors[0].stride())
        new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
        combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype)
        combined_tensor.set_(
            tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride
6077
6078
        )

6079
6080
    return combined_tensor

6081

6082
6083
6084
6085
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
6086
6087
6088
6089
6090
6091
6092
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
6093
6094
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
6095
6096
6097
6098
6099
6100
6101
6102
6103
6104
        q,
        k,
        v,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
6105
        window_size,
6106
6107
6108
6109
6110
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
6111
        quantizers,
6112
        deterministic,
6113
    ):
6114
        # pylint: disable=missing-function-docstring
6115
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
6116
        is_input_fp8 = False
6117
        is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False
6118
6119
6120
6121

        # FP16/BF16 attn:                  fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = True:  fake_dtype = torch.float8_e4m3fn
6122
6123
6124
6125
6126
        fake_dtype = q.dtype

        QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
            get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
        )
6127
6128
        if fp8:
            fused_attention_backend = FusedAttnBackend["FP8"]
6129
6130
6131
            assert isinstance(k, q.__class__) and isinstance(
                v, q.__class__
            ), "q, k, and v must have the same type."
6132

6133
            is_input_fp8 = isinstance(q, Float8Tensor)
6134
            q_fp8, k_fp8, v_fp8 = None, None, None
6135
            if is_input_fp8:
6136
                q_fp8, k_fp8, v_fp8 = q, k, v
6137
6138
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
6139
                qkv_group = len(qkv_layout.split("_"))
6140
6141
6142
6143
6144
6145
6146
6147
6148
6149
6150
6151
6152
6153
6154
6155
6156
6157
6158
6159
                match qkv_group:
                    case 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                        qkv_fp8 = QKV_quantizer(qkv)
                        q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True)
                    case 2:
                        q_fp8 = QKV_quantizer(q)
                        dim = qkv_layout.split("_")[1].find("2")
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                        kv_fp8 = QKV_quantizer(kv_c)
                        k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1], True)
                    case 3:
                        q_fp8 = QKV_quantizer(q)
                        k_fp8 = QKV_quantizer(k)
                        v_fp8 = QKV_quantizer(v)
                    case _:
                        raise "Invalid qkv_layout " + qkv_layout
6160
            # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
6161
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
6162
6163
6164
6165
6166
6167
6168
6169
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
6170
                fake_dtype,
6171
6172
                fused_attention_backend,
                attn_bias,
6173
6174
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
6175
6176
                S_quantizer,
                O_quantizer,
6177
6178
6179
6180
6181
6182
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
6183
                window_size,
6184
6185
                rng_gen,
            )
6186
            if is_output_fp8:
6187
                out_ret = out_fp8
6188
            else:
6189
                out_ret = out_fp8.dequantize().view(out_fp8.shape)
6190
6191
            # is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16
            # is_output_fp8 = True:  out_save.dtype = torch.float8_e4m3fn
6192
6193
            out_save = out_ret

6194
            if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
6195
                # 1: qkv packed, 2: kv packed, 3: qkv separate
6196
6197
6198
6199
6200
6201
                if is_input_fp8:
                    qkv_group = len(qkv_layout.split("_"))
                    if qkv_group == 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
6202
6203
                        qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape)
                        q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True)
6204
                    if qkv_group == 2:
6205
                        q = q.dequantize()
6206
6207
6208
                        dim = qkv_layout.split("_")[1].find("2")
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
6209
6210
                        kv_no_fp8 = kv.dequantize()
                        k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True)
6211
                    if qkv_group == 3:
6212
6213
6214
                        q = q.dequantize()
                        k = k.dequantize()
                        v = v.dequantize()
6215
                if is_output_fp8:
6216
6217
6218
                    out_save = out_fp8.dequantize()

            fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
6219
        else:
6220
            # q, k, v, out_ret: torch.float16 or torch.bfloat16
6221
            out_ret, aux_ctx_tensors = fused_attn_fwd(
6222
6223
6224
6225
6226
6227
6228
6229
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
6230
                fake_dtype,
6231
6232
                fused_attention_backend,
                attn_bias,
6233
6234
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
6235
6236
                None,  # s_quantizer
                None,  # o_quantizer
6237
6238
6239
6240
6241
6242
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
6243
                window_size,
6244
6245
                rng_gen,
            )
6246
            out_save = out_ret
6247
            fp8_tensors = (None, None, None, None)
6248

6249
6250
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

6251
        from .cpu_offload import CPUOffloadEnabled
6252

6253
        if CPUOffloadEnabled:
6254
6255
6256
6257
6258
6259
6260
            if ctx.fp8:
                tensor_list = fp8_tensors
            else:
                tensor_list = [q, k, v, out_save]

            tensor_list.extend(aux_ctx_tensors)

6261
            qkv_layout = "sbhd_sbhd_sbhd"
6262
6263
6264
6265
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

6266
6267
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
6268
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
6269
6270
        tensors_to_save, tensor_objects = prepare_for_saving(
            *fp8_tensors,
6271
6272
6273
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
6274
6275
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
6276
6277
            *aux_ctx_tensors,
        )
6278
6279
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
6280
        ctx.fp8_meta = fp8_meta
6281
6282
6283
6284
6285
6286

        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = S_quantizer

6287
6288
6289
6290
6291
6292
6293
6294
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
6295
        ctx.window_size = window_size
6296
        ctx.fused_attention_backend = (
6297
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
6298
        )
6299
        ctx.use_FAv2_bwd = use_FAv2_bwd
6300
        ctx.deterministic = deterministic
6301

6302
        return out_ret
6303
6304
6305

    @staticmethod
    def backward(ctx, d_out):
6306
        # pylint: disable=missing-function-docstring
6307
        if ctx.is_output_fp8:
6308
6309
6310
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
6311

6312
6313
6314
6315
6316
        # FP16/BF16 attn:                  fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = True:  fake_dtype = torch.float8_e5m2
        fake_dtype = d_out.dtype

6317
        d_out = d_out.contiguous()
6318
        (
6319
6320
6321
6322
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
6323
6324
6325
6326
6327
6328
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
6329
6330
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
6331
6332
6333
6334
6335
            *other_tensors,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

        aux_ctx_tensors = other_tensors

6336
6337
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
6338
        rest = [None]
6339
        if ctx.use_FAv2_bwd:
6340
            softmax_lse, rng_state = aux_ctx_tensors
6341
6342
6343
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
6344
            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)]
6345
            flash_attn_cuda_bwd(
6346
6347
6348
6349
6350
6351
6352
6353
6354
6355
6356
6357
6358
6359
6360
6361
6362
6363
6364
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
6365
            )
6366
6367
6368
            dq = dq[..., : d_out.shape[-1]]
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
6369
        else:
6370
6371
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
6372
                    if ctx.is_output_fp8:
6373
6374
                        d_out_fp8 = d_out
                    else:
6375
                        d_out_fp8 = ctx.dO_quantizer(d_out)
6376
6377
6378
                    dqkv_dtype = TE_DType[d_out_fp8._data.dtype]
                    # q_fp8, k_fp8, v_fp8, out_fp8:      torch.float8_e4m3fn
                    # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2
6379
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
6380
6381
6382
6383
6384
6385
6386
6387
6388
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
6389
6390
                        fake_dtype,
                        dqkv_dtype,
6391
                        aux_ctx_tensors,
6392
                        ctx.fused_attention_backend,
6393
6394
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
6395
6396
6397
                        ctx.S_quantizer,
                        ctx.dP_quantizer,
                        ctx.dQKV_quantizer,
6398
6399
6400
6401
6402
6403
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
6404
6405
                        ctx.window_size,
                        ctx.deterministic,
6406
                    )
6407

6408
6409
                    # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16
                    # is_input_fp8 = True:  dq, dk, dv: torch.float8_e5m2
6410
                    if not ctx.is_input_fp8:
6411
                        qkv_group = len(ctx.qkv_layout.split("_"))
6412
                        if qkv_group == 1:
6413
                            dim = ctx.qkv_layout.find("3")
6414
6415
                            dqkv_fp8_data = _combine_tensors(
                                [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim
6416
                            )
6417
6418
6419
6420
6421
                            dqkv_fp8 = dq_fp8.make_like(
                                tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape
                            )
                            dqkv = dqkv_fp8.dequantize()
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True)
6422
                        if qkv_group == 2:
6423
                            dq = dq_fp8.dequantize()
6424
6425
6426
6427
6428
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(
                                -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                            )
6429
6430
                            dkv = dkv_c_fp8.dequantize()
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1], True)
6431
                        if qkv_group == 3:
6432
6433
6434
6435
6436
                            dq = dq_fp8.dequantize()
                            dk = dk_fp8.dequantize()
                            dv = dv_fp8.dequantize()
                    else:
                        dq, dk, dv = dq_fp8, dk_fp8, dv_fp8
6437
                else:
6438
6439
                    if isinstance(d_out, QuantizedTensor):
                        d_out = d_out.dequantize()
6440
6441
                    dqkv_dtype = TE_DType[d_out.dtype]
                    # q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16
6442
                    dq, dk, dv, *rest = fused_attn_bwd(
6443
6444
6445
6446
6447
6448
6449
6450
6451
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
6452
6453
                        fake_dtype,
                        dqkv_dtype,
6454
                        aux_ctx_tensors,
6455
                        ctx.fused_attention_backend,
6456
6457
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
6458
6459
6460
6461
6462
6463
6464
6465
6466
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
6467
6468
                        ctx.window_size,
                        ctx.deterministic,
6469
                    )
6470

6471
6472
        # if no_bias or alibi, return dqkv
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
6473
6474
6475
6476
6477
6478
6479
6480
6481
6482
6483
6484
6485
6486
6487
6488
6489
6490
6491
6492
6493
6494
6495
6496
6497
6498
            return (
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                dq,
                dk,
                dv,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
6499
6500
                None,
                None,
6501
            )
6502
        # else, return (dqkv, dbias)
6503
6504
6505
6506
6507
6508
6509
6510
6511
6512
6513
6514
6515
6516
6517
6518
6519
6520
6521
6522
6523
6524
6525
6526
6527
        return (
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            dq,
            dk,
            dv,
            rest[0],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
6528
6529
            None,
            None,
6530
            None,
6531
        )
6532

6533

6534
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |

    In addition, ``forward`` accepts ``Float8Tensor`` inputs when the FP8 cuDNN
    sub-backend is selected. It dispatches to a context-parallel implementation
    when the context-parallel group spans more than one rank.

    Parameters
    ----------
    softmax_scale : float
        scale applied to the attention scores; passed through to the backend.
    attention_dropout : float, default = 0.0
        dropout probability; only applied while the module is in training mode.
    attention_dropout_ctx : Callable, default = nullcontext
        zero-argument callable returning a context manager that wraps the attention
        call. For example, an RNG-state tracker's ``fork``.
    attention_type : str, default = "self"
        "self" or "cross". It decides whether ``attention_mask`` is a single tensor or
        a (q, kv) pair when cu_seqlens are derived for padding masks.
    layer_number : Optional[int], default = None
        index of this layer; stored as 1 when ``None``.
    deterministic : bool, default = False
        forwarded to the attention implementation as its ``deterministic`` flag.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # Opt-in (env var) use of the flash-attn v2 backward kernel instead of the cuDNN
        # backward; only allowed on devices with compute capability 9.0.
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading Transformer Engine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in Transformer Engine 2.0.
            """
            # Iterate over snapshots of the key lists: removing from a list while iterating
            # over it skips the element right after each removed one, so adjacent matching
            # keys would otherwise be left behind.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
        quantizers=None,
        pad_between_seqs: bool = False,
    ) -> torch.Tensor:
        """Fused attention forward pass.

        Validates the inputs and fills in sequence-length metadata that was not provided.
        It then runs either the context-parallel path (``attn_forward_func_with_cp``), when
        the product of the ``cp_group`` world sizes is greater than 1, or ``FusedAttnFunc``.

        Parameters
        ----------
        query_layer, key_layer, value_layer : torch.Tensor
            CUDA tensors in FP16/BF16, or ``Float8Tensor`` for FP8 attention, laid out as
            described by ``qkv_layout``.
        qkv_layout : str, default = "sbh3d"
            must be in ``QKVLayouts``. The format (``sbhd``/``bshd``/``thd``) is the
            alphabetic characters of its first ``_``-separated group.
        cu_seqlens_q, cu_seqlens_kv : Optional[torch.Tensor]
            cumulative sequence lengths. For ``sbhd``/``bshd`` they are built from
            ``attention_mask`` (padding masks) or as full-length sequences (other masks)
            when ``None``. For ``thd`` they are required.
        cu_seqlens_q_padded, cu_seqlens_kv_padded : Optional[torch.Tensor]
            sequence offsets including padding. For ``thd``, if either is ``None`` both are
            replaced by ``cu_seqlens_q`` / ``cu_seqlens_kv``.
        max_seqlen_q, max_seqlen_kv : Optional[int]
            for ``sbhd``/``bshd`` they are recomputed from the tensor shapes and multiplied
            by the context-parallel size. For ``thd`` they are required.
        attn_mask_type : str, default = "causal"
            mask type; types containing "padding" need ``cu_seqlens_*`` or
            ``attention_mask`` and are rejected under context parallelism.
        attention_mask : optional
            only used to derive ``cu_seqlens_*`` for padding masks. It is a single tensor for
            self-attention and a (q, kv) tuple for cross-attention.
        window_size : Optional[Tuple[int, int]]
            sliding-window size, forwarded to the backend.
        fused_attention_backend : tex.NVTE_Fused_Attn_Backend
            selected cuDNN sub-backend; must not be ``NVTE_No_Backend``.
        core_attention_bias_type, core_attention_bias :
            bias type and tensor; "alibi" is rejected under context parallelism.
        fast_zero_fill : bool, default = True
            forwarded to ``FusedAttnFunc``.
        cp_group, cp_global_ranks, cp_stream, cp_comm_type :
            context-parallel configuration, forwarded to ``attn_forward_func_with_cp``.
        fp8, fp8_meta, quantizers :
            FP8 settings. ``fp8=True`` requires the ``NVTE_FP8`` sub-backend and a non-None
            ``fp8_meta``. Under context parallelism ``fp8_meta["recipe"].reduce_amax``
            must also be set.
        pad_between_seqs : bool, default = False
            forwarded to the context-parallel path.

        Returns
        -------
        torch.Tensor
            attention output with its last two dimensions merged into one.

        Raises
        ------
        AssertionError
            on unsupported input combinations.
        RuntimeError
            for padding masks when neither ``attention_mask`` nor ``cu_seqlens_*`` is given.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        # Total context-parallel size; a list of groups (hierarchical CP) multiplies sizes.
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        # e.g. "sbhd_sb2hd" -> "sbhd"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format in ["sbhd", "bshd"]:
            if qkv_format == "sbhd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1],
                    query_layer.shape[0],
                    key_layer.shape[0],
                )
            if qkv_format == "bshd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0],
                    query_layer.shape[1],
                    key_layer.shape[1],
                )
            # Each CP rank holds a 1/cp_size slice of the sequence; report the global length.
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size
            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                # No padding: every sequence in the batch has the full length.
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )

        if qkv_format == "thd":
            assert (
                max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
            ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"

        if qkv_format == "thd" and (cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None):
            cu_seqlens_q_padded = cu_seqlens_q
            cu_seqlens_kv_padded = cu_seqlens_kv

        # The flash-attn v2 backward is only used without bias and with the
        # F16_arbitrary_seqlen sub-backend.
        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if fp8:
            assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                " is required for FP8 attention!"
            )
            assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
            assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
                "Amax reduction across TP+CP group is necessary when using context parallelism with"
                " FP8!"
            )

        if context_parallel:
            assert (
                fp8
                or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    deterministic=self.deterministic,
                    use_fused_attention=True,
                    window_size=window_size,
                    fp8=fp8,
                    fp8_meta=fp8_meta,
                    quantizers=quantizers,
                    pad_between_seqs=pad_between_seqs,
                )
        else:
            with self.attention_dropout_ctx():
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    query_layer,
                    key_layer,
                    value_layer,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                    quantizers,
                    self.deterministic,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
6802
6803


6804
class DotProductAttention(TransformerEngineBaseModule):
6805
6806
6807
6808
6809
6810
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

6811
        Argument :attr:`attention_mask` in the `forward` call is only used when
6812
        :attr:`attn_mask_type` includes '"padding"' or `"arbitrary"`.
6813
6814
6815

    .. warning::

6816
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
6817
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
6818
6819
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
6820

6821
6822
6823
6824
6825
6826
6827
    .. note::

        Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing.
        As the FP8 attention support expands from one backend to multiple backends, the location
        of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).


6828
6829
6830
6831
    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
6832
6833
6834
    kv_channels : Union[int, Tuple[int, int]]
                the head size in key and value tensors. If the same, :attr:`kv_channels` can be
                an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
6835
6836
6837
6838
6839
6840
6841
6842
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
6843
6844
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
6845
    attn_mask_type: str, default = `causal`
6846
                   type of attention mask passed into softmax operation, options are "`no_mask`",
6847
6848
6849
6850
6851
6852
6853
6854
6855
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
                   "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
                   "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
                   are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
                   `forward` method. It is useful for cases involving compilation/tracing, e.g.
                   ONNX export, and the forward arg is useful for dynamically changing mask types,
                   e.g. a different mask for training and inference.
                   1. For "`no_mask`", no attention mask is applied.
                   2. For "`causal`", "`causal_bottom_right`", or the causal mask in
6856
                   "`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
6857
6858
6859
6860
6861
6862
6863
6864
6865
6866
6867
6868
6869
6870
                   calculates and applies an upper triangular mask to the softmax input.
                   No user input is needed. Causal masks without the "`bottom_right`" appendix align
                   the diagonal line to the top left corner of the softmax matrix. With
                   "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
                   often used in inference/KV caching.
                   3. For "`padding`", or the padding mask in "`padding_causal`" and
                   "`padding_causal_bottom_right`", users need to provide the locations of padded
                   tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
                   [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
                   in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
                   cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
                   [batch_size, 1, 1, max_seqlen_kv]).
                   4. For "`arbitrary`", users need to provide a mask that is broadcastable to
                   the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
6871
6872
6873
6874
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
6875
6876
6877
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
6878
                be overridden by :attr:`window_size` in `forward` as well.
6879
6880
    attention_type: str, default = `self`
                   type of attention, either "`self`" and "`cross`".
6881
6882
6883
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
6884
6885
6886
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
6887
               `h` the number of heads, `d` head size, and `t` the total number of tokens
6888
6889
6890
6891
6892
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
6893
               For that, please use `get_qkv_layout` to gain the layout information.
6894
6895
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
6896
                `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
6897
6898
6899
6900
6901
6902
6903
6904
6905

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
6906
    cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None`
6907
              context parallel process group.
6908
6909
6910
              ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
              List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
              and cp_group[1] are for a2a and p2p communications respectively.
6911
6912
6913
6914
6915
6916
6917
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
6918
    cp_comm_type : str, default = `p2p`
6919
                  inter-gpu communication type for context parallelism.
6920
                  Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
6921
6922
6923
6924
6925
6926
                  "p2p": Exchange KV chunks with P2P communications in ring topology.
                         P2P is async and can be overlapped with attention compute.
                  "all_gather": All-gather to get full sequence of KV before attention.
                                The all-gather is not async, and cannot be overlapped.
                  "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                         group, and gather to get full sequence of QKV.
6927
6928
6929
                  "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                  across each CP sub-group (e.g., via NVLink), then exchanging KV with
                  p2p between sub-groups (e.g., via IBLink).
6930
6931
6932
6933
6934
    """

    def __init__(
        self,
        num_attention_heads: int,
6935
        kv_channels: Union[int, Tuple[int, int]],
6936
        num_gqa_groups: Optional[int] = None,
6937
        attention_dropout: float = 0.0,
6938
        qkv_format: str = "sbhd",
6939
        attn_mask_type: str = "causal",
6940
        window_size: Optional[Tuple[int, int]] = None,
6941
6942
6943
6944
6945
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
6946
        attention_type: str = "self",
6947
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
6948
        cp_global_ranks: List[int] = None,
6949
        cp_stream: torch.cuda.Stream = None,
6950
        cp_comm_type: str = "p2p",
6951
        softmax_scale: Optional[float] = None,
6952
6953
6954
    ) -> None:
        super().__init__()

6955
        self.logger = logging.getLogger("DotProductAttention")
6956
6957
6958
        self.logger.setLevel(_log_level)
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)
6959
        self.qkv_format = qkv_format
6960
        attn_mask_type = attn_mask_type.replace(",", "_")
6961
6962
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
6963
        self.attn_mask_type = attn_mask_type
6964
        self.window_size = check_set_window_size(attn_mask_type, window_size)
6965
6966
6967
6968
6969
6970
6971
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
6972
        self.get_rng_state_tracker = get_rng_state_tracker
6973
        self.num_attention_heads = num_attention_heads
6974
        self.layer_number = 1 if layer_number is None else layer_number
6975
6976
6977
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
6978
        self.cp_comm_type = cp_comm_type
6979

6980
6981
6982
6983
6984
6985
        self.hidden_size_per_attention_head_k = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[0]
        )
        self.hidden_size_per_attention_head_v = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[1]
        )
6986

6987
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
6988
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
6989

6990
6991
6992
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
6993

6994
        self.rng_states_tracker = None
6995
6996
6997
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
6998
6999
7000
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork
7001

7002
        if softmax_scale is None:
7003
7004
7005
            softmax_scale = 1.0 / math.sqrt(
                kv_channels if isinstance(kv_channels, int) else kv_channels[0]
            )
7006

7007
7008
7009
        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
7010
        )
7011
7012
7013
7014
7015
7016
7017
7018
7019
7020
7021
7022
7023
7024
7025
7026
7027
7028
7029
        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = get_cudnn_version()
        if (8, 9, 5) <= cudnn_version < (9, 0, 0):
            if self.deterministic:
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
7030

7031
        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
7032
7033
7034
7035

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

7036
7037
7038
7039
7040
        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

7041
7042
7043
7044
7045
7046
7047
        self.flash_attention = FlashAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )
7048

7049
        # Instantiating three types since use of flash-attn and FusedAttention
7050
        # might be ruled out due to forward inputs.
7051
7052
7053
7054
7055
7056
7057
        self.fused_attention = FusedAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )
7058

7059
        self.unfused_attention = UnfusedDotProductAttention(
7060
7061
7062
7063
            softmax_scale,
            attention_type=attention_type,
            **attn_kwargs,
            layer_number=layer_number,
7064
        )
7065

7066
7067
7068
        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
7069
7070
            when loading older Transformer Engine checkpoints. Will phase out
            this hook in Transformer Engine 2.0.
7071
7072
7073
7074
7075
7076
7077
            """
            for key in incompatible_keys.missing_keys:
                if "core_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)

        self.register_load_state_dict_post_hook(remove_extra_states_check)

7078
7079
7080
7081
7082
7083
7084
7085
7086
7087
7088
7089
7090
7091
7092
7093
7094
7095
7096
7097
7098
7099
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load this module's state, redirecting to the legacy FP8 metadata location if needed.

        Transformer Engine 1.6 and 1.7 checkpoints stored FP8 attention metadata under the
        `core_attention.fused_attention._extra_state` key instead of the
        `core_attention._extra_state` key. When the incoming state dict has at least one key in
        the legacy form and none in the current form, the load prefix is extended with
        `fused_attention.` so the metadata is read from its old location. Otherwise the prefix
        is used unchanged. Please see `FP8 checkpoint compatibility
        <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_ for more details.

        All arguments other than `prefix` are forwarded untouched to the parent class
        implementation.
        """
        keys = state_dict.keys()
        has_legacy_layout = any("core_attention.fused_attention._extra_state" in k for k in keys)
        has_current_layout = any("core_attention._extra_state" in k for k in keys)

        # Only redirect when the checkpoint uses the old layout exclusively; if the current
        # key is present anywhere, the standard prefix is kept.
        effective_prefix = (
            prefix + "fused_attention."
            if has_legacy_layout and not has_current_layout
            else prefix
        )

        super()._load_from_state_dict(
            state_dict,
            effective_prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

7100
7101
7102
7103
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: torch.Tensor,
        **forward_kwargs: Any,
    ) -> torch.Tensor:
        """
        Forward method with activation checkpointing.

        Runs ``attention_func`` through ``checkpoint`` so that forward activations are
        recomputed during the backward pass instead of being kept in memory.

        Parameters
        ----------
        attention_func : Callable
            The attention callable to execute under checkpointing.
        *forward_args : torch.Tensor
            Positional arguments passed through to ``attention_func``.
        **forward_kwargs : Any
            Keyword arguments passed through to ``attention_func``.

        Returns
        -------
        torch.Tensor
            The result of ``checkpoint`` applied to ``attention_func``.
        """

        # Thin wrapper so that ``checkpoint`` receives a plain function rather than
        # ``attention_func`` itself.
        def custom_forward(*input_args, **input_kwargs):
            return attention_func(*input_args, **input_kwargs)

        # Positional arguments are listed before the checkpoint-specific keywords for
        # readability; Python binds ``*forward_args`` positionally either way.
        hidden_states = checkpoint(
            custom_forward,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

        return hidden_states

7122
7123
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Configure context parallelism for this module; call it before running the forward pass.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  process group(s) used for context parallel communication.
                  A single ProcessGroup is expected for cp_comm_type "p2p", "all_gather"
                  and "a2a". A List[ProcessGroup] is expected for cp_comm_type "a2a+p2p",
                  where cp_group[0] carries the a2a traffic and cp_group[1] the p2p traffic.
        cp_global_ranks : List[int]
                         global ranks that make up the context parallel group.
        cp_stream : torch.cuda.Stream
                   CUDA stream on which context parallel work is executed.
        cp_comm_type : str, default = `p2p`
                      inter-GPU communication scheme for context parallelism; one of
                      "p2p", "all_gather", "a2a" or "a2a+p2p".
                      "p2p": KV chunks are exchanged with P2P communication in a ring
                             topology. P2P is asynchronous and can overlap with attention
                             compute.
                      "all_gather": the full KV sequence is all-gathered before attention.
                                    The all-gather is not asynchronous and cannot be
                                    overlapped.
                      "a2a": in the style of DeepSpeed Ulysses, attention heads are scattered
                             across the CP group and the full QKV sequence is gathered.
                      "a2a+p2p": hierarchical scheme that first applies a2a to QKV within
                                 each CP sub-group (e.g., over NVLink), then exchanges KV with
                                 p2p between sub-groups (e.g., over IBLink).
        """
        # Values are stored as given; no validation of the comm type or the group
        # structure happens at this point.
        self.cp_comm_type = cp_comm_type
        self.cp_stream = cp_stream
        self.cp_global_ranks = cp_global_ranks
        self.cp_group = cp_group
7161

7162
    @no_torch_dynamo(recursive=False)
7163
7164
7165
7166
7167
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
7168
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
7169
7170
7171
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
7172
7173
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
7174
7175
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
7176
        attn_mask_type: Optional[str] = None,
7177
        window_size: Optional[Tuple[int, int]] = None,
7178
        checkpoint_core_attention: bool = False,
7179
7180
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
7181
        alibi_slopes: Optional[torch.Tensor] = None,
7182
        fast_zero_fill: bool = True,
7183
        inference_params: Optional[InferenceParams] = None,
7184
        pad_between_seqs: Optional[bool] = None,
7185
7186
7187
7188
7189
7190
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

7191
7192
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes '"padding"' or `"arbitrary"`.
7193

7194
7195
        .. note::

7196
7197
7198
7199
7200
7201
7202
7203
7204
7205
7206
7207
7208
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://arxiv.org/pdf/2305.13245.pdf>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
7209
            and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
7210
7211
7212
7213
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
7214
7215
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
7216
            optimizations in FusedAttention. When unset, Transformer Engine determines the code path
7217
7218
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
7219

7220
7221
7222
7223
7224
7225
7226
7227
7228
7229
7230
7231
7232
7233
7234
7235
7236
7237
7238
7239
7240
7241
7242
7243
7244
7245
7246
7247
7248
7249
7250
7251
7252
7253
7254
7255
7256
7257
7258
7259
7260
7261
7262
7263
7264
7265
7266
7267
7268
7269
7270
7271
7272
7273
        .. note::
            .. _cu_seqlens note:

            When training data has variable sequence lengths, users have two options.

            1. Manipulate the data and pad all sequences to the same length. Use
               :attr:`qkv_format` = {"bshd", "sbhd"} and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`
               (which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide
               the real sequence length information. For example, a batch of 3 sequences
               [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative
               sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

            2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`,
               as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed
               without any padding, and the sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

               In certain use cases, a varying number of identifier tokens are inserted between
               sequences. These tokens do not participate in the attention calculation.
               :attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified
               in such cases to correctly identify the start and end of each sequence in a batch.
               For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and
               :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13]
               for self-attention.

        .. note::
            .. _max_seqlen note:

            When :attr:`qkv_format` = {"bshd", "sbhd"}, sequences are of equal length in a batch.
            :attr:`max_seqlen_q` and :attr:`max_seqlen_kv` should be the same as the "s" dimension of
            :attr:`query_layer` and :attr:`key_layer` tensors. When unset, Transformer Engine will
            infer them as such.

            When :attr:`qkv_format` = "thd", sequences have varying lengths. :attr:`max_seqlen_q` and
            :attr:`max_seqlen_kv` should be the maximum query and key/value sequence length in a batch.
            When unset, Transformer Engine deduces them from :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`.
            This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this
            overhead, users are recommended to obtain the maximum sequence lengths from the data loaders
            and pass them in.

            - As the maximum sequence lengths, batch size, and number of tokens change from batch to batch,
              dynamic shapes need to be supported for tensor construction. FlashAttention and
              UnfusedDotProductAttention naturally do so, while FusedAttention requires parameters to be static
              to create graphs before performance heuristics analysis. To reduce the number of graphs created
              per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size,
              :attr:`max_seqlen_q`, :attr:`max_seqlen_kv`}, and for cuDNN >= 9.6, {"t" dimension of
              :attr:`query_layer`, "t" dimension of :attr:`key_layer`}.

7274
7275
7276
7277
7278
7279
7280
7281
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
7282
7283
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
7284
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
7285
7286
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
7287
7288
7289
7290
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
7291
7292
7293
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
7294
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
7295
                   with shape [batch_size + 1] and dtype torch.int32.
7296
                   See :ref:`note<cu_seqlens note>` for more details.
7297
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
7298
7299
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
7300
                   See :ref:`note<cu_seqlens note>` for more details.
7301
7302
7303
7304
7305
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for
                   `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_q_padded = cu_seqlens_q`.
7306
                   See :ref:`note<cu_seqlens note>` for more details.
7307
7308
7309
7310
7311
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_kv_padded = cu_seqlens_kv`.
7312
                   See :ref:`note<cu_seqlens note>` for more details.
7313
7314
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
7315
                      See :ref:`note<max_seqlen note>` for more details.
7316
7317
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
7318
                       See :ref:`note<max_seqlen note>` for more details.
7319
7320
7321
7322
7323
7324
7325
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
                       'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
                       'arbitrary'}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
                       are equivalent. By default, causal masks are aligned to the top left corner
                       of the softmax matrix. When "`bottom_right`" is specified in the mask type,
                       causal masks are aligned to the bottom right corner.
7326
        window_size: Optional[Tuple[int, int]], default = `None`
7327
                    Sliding window size for local attention.
7328
7329
7330
7331
7332
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
7333
        core_attention_bias_type: str, default = `no_bias`
7334
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
7335
        core_attention_bias: Optional[torch.Tensor], default = `None`
7336
7337
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
7338
7339
7340
7341
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
7342
        fast_zero_fill: bool, default = `True`
7343
                    Whether to use the fast path to set output tensors to 0 or not.
7344
7345
7346
7347
7348
7349
7350
7351
7352
7353
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
7354
7355
7356
        pad_between_seqs: Optional[bool], default = `None`
            If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
            If true, there are padding tokens between individual sequences in a packed batch.
7357
        """
7358

7359
7360
7361
7362
7363
7364
7365
7366
7367
        with self.prepare_forward(
            query_layer,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:
            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
7368
                        self.logger.warning(
7369
7370
7371
                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                            """fp8_meta["recipe"].fp8_mha=True"""
                        )
7372
7373
7374
7375
7376
7377
7378
7379
7380
7381
7382

            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
7383

7384
7385
7386
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
7387
7388
7389
            assert (
                query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
7390
7391
7392
            assert (
                key_layer.shape[:-1] == value_layer.shape[:-1]
            ), "Keys and values must have the same batch size, sequence length and number of heads!"
7393
7394
7395
7396
7397
7398
7399
7400
            assert (
                key_layer.shape[-1] == self.hidden_size_per_attention_head_k
            ), f"Keys have head_dim = {key_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_k}!"
            assert (
                value_layer.shape[-1] == self.hidden_size_per_attention_head_v
            ), f"Values have head_dim = {value_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_v}!"
7401

7402
7403
7404
            if qkv_format is None:
                qkv_format = self.qkv_format

7405
7406
7407
7408
7409
7410
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
7411
            assert (
7412
7413
7414
7415
7416
7417
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
7418

7419
7420
7421
7422
            if window_size is None:
                window_size = self.window_size
            window_size = check_set_window_size(attn_mask_type, window_size)

7423
7424
7425
7426
7427
7428
7429
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
7430

7431
7432
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
7433

7434
7435
7436
7437
7438
                # convert causal to causal_bottom_right in inference when KV-caching is in use
                # so users can run with the same attn_mask_type for training and inference
                if attn_mask_type in ["causal", "padding_causal"]:
                    attn_mask_type = attn_mask_type + "_bottom_right"

7439
7440
7441
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
7442

7443
7444
7445
7446
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]
7447

7448
7449
7450
                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)
7451

7452
7453
7454
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)
7455

7456
7457
7458
7459
7460
7461
7462
7463
7464
                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
7465

7466
7467
7468
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
7469

7470
7471
                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()
7472
7473

            assert (
7474
7475
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
7476
7477
7478
7479
            ), (
                "Keys and values must have num_gqa_group ="
                f" {self.num_gqa_groups_per_partition} heads!"
            )
7480
7481
7482
7483
7484
7485
7486
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
7487
                assert all(
7488
7489
7490
7491
7492
7493
7494
7495
7496
7497
7498
7499
7500
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
7501
                batch_size = len(cu_seqlens_q) - 1
7502
                if max_seqlen_q is None:
7503
7504
7505
7506
                    if cu_seqlens_q_padded is not None:
                        seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
                    else:
                        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
7507
                    max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
7508
                if max_seqlen_kv is None:
7509
7510
7511
7512
                    if cu_seqlens_kv_padded is not None:
                        seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
                    else:
                        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
7513
                    max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
7514

7515
7516
7517
7518
7519
7520
            cp_size = 1
            if isinstance(self.cp_group, dist_group_type):
                cp_size = get_distributed_world_size(self.cp_group)
            elif isinstance(self.cp_group, list):
                for group in self.cp_group:
                    cp_size *= get_distributed_world_size(group)
7521
7522
            context_parallel = cp_size > 1

7523
            if qkv_format in ["sbhd", "bshd"]:
7524
                assert all(
7525
7526
7527
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
7528
7529
                    max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv
7530
                    batch_size = query_layer.shape[1]
7531
                else:
7532
7533
                    max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv
7534
                    batch_size = query_layer.shape[0]
7535
7536
                max_seqlen_q *= cp_size
                max_seqlen_kv *= cp_size
7537
7538
7539
7540
7541
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
7542
                        the sequence dimension in 'query_layer'!"""
7543
7544
7545
7546
7547
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
7548
                        the sequence dimension in 'key_layer' and 'value_layer'!"""
7549
7550
7551
7552
7553
                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
7554
                        if self.attention_type == "self":
7555
7556
7557
7558
7559
7560
7561
7562
7563
7564
7565
7566
7567
7568
7569
7570
                            cu_seqlens_q = get_cu_seqlens(attention_mask)
                            cu_seqlens_kv = cu_seqlens_q
                        else:
                            cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                            cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
                    else:
                        cu_seqlens_q = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
                        cu_seqlens_kv = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
7571

7572
7573
7574
7575
7576
            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
7577
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
7578
7579
7580
                    query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                )
            else:
7581
                qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
7582
7583
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )
7584

7585
7586
7587
7588
7589
7590
7591
7592
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
7593
            bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
7594
7595
7596
7597
7598
7599
7600
7601
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
7602
                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
7603
7604
7605
7606
7607
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

7608
7609
            core_attention_bias_shape = None
            if core_attention_bias is not None:
7610
                if (
7611
7612
                    core_attention_bias.shape[0] == batch_size
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
7613
                ):
7614
7615
7616
7617
7618
7619
7620
7621
7622
7623
7624
7625
7626
7627
7628
7629
7630
                    core_attention_bias_shape = "bhss"
                elif (
                    core_attention_bias.shape[0] == 1
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
                ):
                    core_attention_bias_shape = "1hss"
                elif (
                    core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
                ):
                    core_attention_bias_shape = "b1ss"
                elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
                    core_attention_bias_shape = "11ss"
                else:
                    assert (
                        False
                    ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

7631
7632
7633
7634
7635
7636
7637
7638
7639
7640
7641
            if pad_between_seqs is None:
                if qkv_format == "thd":
                    pad_between_seqs = (
                        cu_seqlens_q_padded is not None
                        and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
                    ) or (
                        cu_seqlens_kv_padded is not None
                        and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
                    )
                else:
                    pad_between_seqs = False
7642

7643
            attention_params = AttentionParams(
7644
7645
7646
7647
7648
7649
7650
7651
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
                num_heads=query_layer.shape[-2],
                num_gqa_groups=key_layer.shape[-2],
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
7652
7653
                head_dim_qk=query_layer.shape[-1],
                head_dim_v=value_layer.shape[-1],
7654
7655
7656
7657
7658
7659
7660
7661
7662
7663
7664
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias_shape=core_attention_bias_shape,
                core_attention_bias_requires_grad=(
                    core_attention_bias.requires_grad if core_attention_bias is not None else False
                ),
                pad_between_seqs=pad_between_seqs,
                attention_dropout=self.attention_dropout,
                context_parallel=context_parallel,
7665
7666
                deterministic=self.deterministic,
                is_training=self.training,
7667
7668
7669
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
            )
7670
            global _attention_backends, _use_flash_attn_3
7671
7672
7673
7674
7675
7676
7677
            if (
                _attention_backends["attention_params"] is None
                or attention_params != _attention_backends["attention_params"]
            ):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
7678
                _use_flash_attn_3 = _flash_attn_3_is_installed
7679
7680
7681
7682
7683
7684
7685
7686
                (
                    use_flash_attention,
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
                    _,
                ) = get_attention_backend(attention_params)
                if use_flash_attention:
7687
7688
                    self.logger.info(
                        "Running with FlashAttention backend (version %s)",
7689
                        _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version,
7690
                    )
7691
7692
7693
7694
                elif use_fused_attention:
                    self.logger.info(
                        "Running with FusedAttention backend (sub-backend %s)",
                        int(fused_attention_backend),
7695
                    )
7696
7697
7698
7699
7700
7701
7702
                elif use_unfused_attention:
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]
7703

7704
7705
7706
7707
7708
7709
7710
7711
7712
7713
7714
7715
7716
7717
7718
7719
7720
7721
7722
7723
7724
7725
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
7726
                    cp_comm_type=self.cp_comm_type,
7727
7728
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
7729
7730
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
7731
                    quantizers=self.quantizers,
7732
                )
7733

7734
            if use_fused_attention:
7735
7736
                fu_core_attention_bias_type = core_attention_bias_type
                fu_core_attention_bias = core_attention_bias
7737
7738
7739
                if core_attention_bias_type == "alibi" and (
                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
                ):
7740
7741
7742
7743
7744
7745
7746
                    fu_core_attention_bias_type = "post_scale_bias"
                    _, fu_core_attention_bias = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                        bias_dtype=query_layer.dtype,
7747
                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
7748
                    )
7749
7750
7751
7752
7753
7754
7755
7756
7757
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
7758
7759
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
7760
7761
7762
7763
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
7764
                        window_size=window_size,
7765
7766
7767
7768
7769
7770
7771
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
7772
                        cp_comm_type=self.cp_comm_type,
7773
7774
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
7775
                        pad_between_seqs=pad_between_seqs,
7776
7777
                    )
                return self.fused_attention(
7778
7779
7780
7781
7782
7783
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
7784
7785
                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
7786
7787
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
7788
7789
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
7790
                    window_size=window_size,
7791
                    fused_attention_backend=fused_attention_backend,
7792
7793
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
7794
7795
7796
7797
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
7798
                    cp_comm_type=self.cp_comm_type,
7799
7800
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
7801
                    quantizers=self.quantizers,
7802
                    pad_between_seqs=pad_between_seqs,
7803
                )
7804

7805
            from .cpu_offload import CPUOffloadEnabled
7806

7807
7808
7809
7810
7811
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
7812

7813
7814
7815
7816
7817
7818
7819
7820
7821
7822
7823
7824
            if use_unfused_attention:
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
7825
                        window_size=window_size,
7826
7827
7828
7829
7830
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
7831
7832
7833
                    query_layer,
                    key_layer,
                    value_layer,
7834
7835
7836
7837
7838
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
7839
                    window_size=window_size,
7840
7841
7842
7843
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )
7844

7845
            raise ValueError("No dot product attention support for the provided inputs!")
7846
7847


7848
7849
7850
7851
7852
7853
7854
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

7855
7856
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
7857

7858
7859
7860
7861
7862
7863
7864
7865
7866
7867
7868
7869
7870
7871
7872
7873
7874
7875
7876
7877
7878
7879
7880
7881
7882
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
7883
7884
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                   'padding_causal_bottom_right','arbitrary'},
7885
                   default = `causal`
7886
7887
7888
7889
7890
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
7891
7892
7893
7894
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
7895
7896
7897
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
7898
                be overridden by :attr:`window_size` in `forward` as well.
7899
7900
7901
7902
7903
7904
7905
7906
7907
7908
7909
7910
7911
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
7912
7913
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
7914
7915
7916
7917
7918
7919
7920
7921
7922
7923
7924
7925
7926
7927
7928
7929
7930
7931
7932
7933
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
7934
          The device on which the parameters of the model will be allocated. It is the user's
7935
7936
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
7937
7938
7939
7940
7941
7942
7943
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
7944
            For that, please use `get_qkv_layout` to gain the layout information.
7945
7946
7947
7948
7949
7950
7951
7952
7953
7954
7955
7956
7957
7958
7959
7960
7961
7962
7963
7964
7965
7966
7967
7968
7969
7970
7971
7972
7973
7974
7975
7976
7977
7978
7979
7980
7981
7982
7983
7984

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
7985
7986
7987
7988
7989
7990
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
7991
7992
7993
7994
7995
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
7996
        layer_number: Optional[int] = None,
7997
        attn_mask_type: str = "causal",
7998
        window_size: Optional[Tuple[int, int]] = None,
7999
8000
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
8001
        num_gqa_groups: Optional[int] = None,
8002
8003
8004
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
8005
        params_dtype: Optional[torch.dtype] = None,
8006
        return_bias: bool = False,
8007
8008
8009
8010
8011
8012
8013
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
8014
        ub_overlap_ag: bool = False,
8015
8016
8017
8018
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
8019
        bias: bool = True,
8020
        normalization: str = "LayerNorm",
8021
        device: Union[torch.device, str] = "cuda",
8022
        qkv_format: str = "sbhd",
8023
8024
    ) -> None:
        super().__init__()
8025

8026
        self.qkv_format = qkv_format
8027
        self.attn_mask_type = attn_mask_type
8028
        self.window_size = check_set_window_size(attn_mask_type, window_size)
8029
        self.layer_number = layer_number
8030
8031
8032
8033
8034
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
8035
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
8036
        self.num_attention_heads = num_attention_heads
8037
        self.return_bias = return_bias
8038
8039
        self.cp_size = 1
        self.cp_rank = 0
8040
8041
8042
8043
8044
8045
8046

        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()
8047
8048
8049
8050
8051

        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

8052
8053
8054
        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"
8055
8056
8057
8058
8059
8060

        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
8061
8062
8063
8064
8065
8066
8067
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
8068
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
8069
8070
8071
8072

        self.hidden_size_per_attention_head = kv_channels
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
8073
8074
8075
8076
8077
8078
8079

        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
8080
            "params_dtype": self.params_dtype,
8081
            "device": device,
8082
8083
8084
8085
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

cyanguwa's avatar
cyanguwa committed
8086
        if self.attention_type == "self":
8087
8088
            parameters_split = None
            if not fuse_qkv_params:
8089
8090
8091
8092
8093
8094
8095
                parameters_split = collections.OrderedDict(
                    [
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    ]
                )
8096
8097
8098
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
8099
                    self.hidden_size_q + 2 * self.hidden_size_kv,
8100
8101
8102
8103
8104
8105
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
cyanguwa's avatar
cyanguwa committed
8106
                    parameters_split=parameters_split,
8107
8108
8109
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
Jaemin Choi's avatar
Jaemin Choi committed
8110
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
8111
                    ub_overlap_ag=ub_overlap_ag,
8112
                    normalization=normalization,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
8113
                    ub_name="qkv",
8114
8115
8116
8117
8118
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
8119
                    self.hidden_size_q + 2 * self.hidden_size_kv,
8120
8121
8122
8123
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
cyanguwa's avatar
cyanguwa committed
8124
                    parameters_split=parameters_split,
8125
8126
                    **common_gemm_kwargs,
                )
cyanguwa's avatar
cyanguwa committed
8127
        elif self.attention_type == "cross":
8128
8129
8130
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
8131
                    self.hidden_size_q,
8132
8133
8134
8135
8136
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
8137
                    parameters_split=("query",) if not fuse_qkv_params else None,
8138
8139
8140
8141
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
Jaemin Choi's avatar
Jaemin Choi committed
8142
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
8143
                    ub_overlap_ag=ub_overlap_ag,
8144
                    normalization=normalization,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
8145
                    ub_name="qkv",
8146
8147
8148
8149
8150
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
8151
                    self.hidden_size_q,
8152
8153
8154
8155
8156
8157
8158
8159
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
8160
                2 * self.hidden_size_kv,
8161
8162
8163
8164
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
8165
                parameters_split=("key", "value") if not fuse_qkv_params else None,
8166
8167
8168
8169
8170
8171
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
8172
            self.hidden_size_per_attention_head,
8173
8174
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
8175
            qkv_format=self.qkv_format,
8176
8177
8178
8179
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
8180
            layer_number=self.layer_number,
8181
            attention_type=self.attention_type,
8182
8183
8184
8185
        )

        # Linear
        self.proj = Linear(
8186
            self.hidden_size_q,
8187
8188
8189
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
8190
            return_bias=return_bias,
8191
            parallel_mode="row" if set_parallel_mode else None,
8192
8193
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
8194
            ub_name="proj",
8195
8196
8197
8198
            **common_gemm_kwargs,
        )

    def _allocate_memory(
8199
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
8200
    ) -> torch.Tensor:
8201
        """Allocates memory for KV cache."""
8202
8203
8204
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
8205
            self.num_gqa_groups_per_partition,
8206
            self.hidden_size_per_attention_head,
8207
            dtype=dtype,
8208
8209
8210
8211
            device=torch.cuda.current_device(),
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
8212
8213
8214
8215
8216
8217
8218
8219
8220
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
8221
8222
        self.tp_group = tp_group

8223
    def set_context_parallel_group(
8224
        self,
8225
        cp_group: Union[dist_group_type, List[dist_group_type], None],
8226
        cp_global_ranks: List[int],
8227
        cp_stream: torch.cuda.Stream,
8228
        cp_comm_type: str = "p2p",
8229
    ) -> None:
8230
8231
8232
8233
8234
8235
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        Parameters
        ----------
8236
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
8237
                  context parallel process group.
8238
8239
8240
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
8241
8242
8243
8244
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
8245
        cp_comm_type : str, default = `p2p`
8246
                      inter-gpu communication type for context parallelism.
8247
                      Can be "p2p" or "all_gather" or "a2a", "a2a+p2p".
8248
8249
8250
8251
8252
8253
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
8254
8255
8256
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
8257
        """
8258
8259
8260
8261
8262
8263
8264
8265
8266
8267
8268
8269
8270
8271
8272
        if isinstance(cp_group, dist_group_type):
            self.cp_size = get_distributed_world_size(cp_group)
            self.cp_rank = get_distributed_rank(cp_group)
        elif isinstance(cp_group, list):
            assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
            assert (
                cp_comm_type == "a2a+p2p"
            ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
            cp_size_a2a = get_distributed_world_size(cp_group[0])
            cp_rank_a2a = get_distributed_rank(cp_group[0])
            cp_size_p2p = get_distributed_world_size(cp_group[1])
            cp_rank_p2p = get_distributed_rank(cp_group[1])
            self.cp_size = cp_size_a2a * cp_size_p2p
            self.cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a

8273
8274
8275
8276
8277
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
8278
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)
8279

8280
8281
8282
    def forward(
        self,
        hidden_states: torch.Tensor,
8283
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
8284
        encoder_output: Optional[torch.Tensor] = None,
8285
        attn_mask_type: Optional[str] = None,
8286
        window_size: Optional[Tuple[int, int]] = None,
8287
8288
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
8289
        inference_params: Optional[InferenceParams] = None,
8290
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
8291
8292
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
8293
        alibi_slopes: Optional[torch.Tensor] = None,
8294
8295
8296
8297
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
8298
        fast_zero_fill: bool = True,
8299
        pad_between_seqs: Optional[bool] = None,
8300
    ) -> Tuple[Union[torch.Tensor, None], ...]:
8301
8302
8303
8304
8305
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

8306
8307
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
8308
8309
8310
8311
8312

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
8313
8314
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
8315
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
8316
8317
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
8318
8319
8320
8321
8322
8323
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                       'padding_causal_bottom_right','arbitrary'},
8324
                       default = `None`
8325
8326
8327
8328
                       type of attention mask passed into softmax operation. By default,
                       causal masks are aligned to the top left corner of the softmax matrix.
                       When "`bottom_right`" is specified in the mask type, causal masks are
                       aligned to the bottom right corner.
8329
8330
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
8331
8332
8333
8334
8335
8336
8337
8338
8339
8340
8341
8342
8343
8344
8345
8346
8347
8348
8349
8350
8351
8352
8353
8354
8355
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
8356
                    Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`, `alibi`}
8357
        core_attention_bias: Optional[torch.Tensor], default = `None`
8358
8359
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
8360
8361
8362
8363
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
8364
8365
8366
8367
8368
8369
8370
8371
8372
8373
8374
8375
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
8376
8377
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
8378
8379
8380
        pad_between_seqs: Optional[bool], default = `None`
            If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
            If true, there are padding tokens between individual sequences in a packed batch.
8381
        """
8382
8383
        # hidden_states: [sq, b, h]

8384
        if attn_mask_type is None:
8385
            attn_mask_type = self.attn_mask_type
8386
8387
        if window_size is None:
            window_size = self.window_size
8388
        window_size = check_set_window_size(attn_mask_type, window_size)
8389

8390
        if "padding" in attn_mask_type and attention_mask is not None:
8391
8392
            for mask in attention_mask:
                assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"
8393

8394
8395
8396
        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
8397

8398
        # =================================================
8399
        # Pre-allocate memory for key-values for inference
8400
8401
8402
        # =================================================

        if inference_params and self.layer_number is not None:
8403
8404
8405
            assert (
                self.qkv_format != "thd"
            ), "qkv_format == thd is not supported for an inference with KV-cache!"
8406
            if self.layer_number not in inference_params.key_value_memory_dict:
8407
                inf_max_seq_len = inference_params.max_sequence_length
8408
8409
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
8410
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
8411
8412
                )
                inference_value_memory = self._allocate_memory(
8413
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
8414
8415
8416
8417
8418
8419
8420
8421
8422
8423
8424
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

8425
        # ======================
8426
        # Query, Key, and Value
8427
        # ======================
8428

8429
8430
8431
8432
8433
        fp8_mha = (
            FP8GlobalStateManager.is_fp8_enabled()
            and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
        )

8434
        layernorm_output = None
cyanguwa's avatar
cyanguwa committed
8435
8436
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
8437
8438
8439
8440
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8441
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8442
8443
8444
8445
8446
8447
8448
8449
8450
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8451
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8452
8453
                )

8454
8455
8456
            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
8457
            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
8458
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
8459
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
cyanguwa's avatar
cyanguwa committed
8460
8461
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
8462
8463
8464
8465
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
cyanguwa's avatar
cyanguwa committed
8466
8467
8468
8469
8470
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
8471
                    self.hidden_size_per_attention_head,
cyanguwa's avatar
cyanguwa committed
8472
8473
8474
                )
                # split along third last dimension
                split_dim = -3
8475
8476
8477

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
8478
8479
8480
8481
8482
8483
            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
8484
8485
8486
            query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
            )
cyanguwa's avatar
cyanguwa committed
8487

8488
8489
8490
8491
8492
8493
8494
8495
8496
8497
8498
8499
            if self.qkv_format == "thd":
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
            else:
                # query: -> [sq, b, np, hn]
                # key, value: -> [sq, b, ng, hn]
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
cyanguwa's avatar
cyanguwa committed
8500
8501
        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
8502
            mixed_kv_layer = self.key_value(
cyanguwa's avatar
cyanguwa committed
8503
                encoder_output,
8504
                is_first_microbatch=is_first_microbatch,
8505
                fp8_output=fp8_mha and rotary_pos_emb is None,
8506
8507
8508
            )

            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
8509
                # [sq, b, (ng * 2 * hn)] --> [sq, b, ng, 2 * hn]
8510
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
8511
                    self.num_gqa_groups_per_partition,
8512
8513
8514
8515
8516
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
cyanguwa's avatar
cyanguwa committed
8517
                # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn]
8518
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
8519
                    2 * self.num_gqa_groups_per_partition,
8520
8521
8522
8523
8524
8525
8526
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
8527
            # mixed_kv_layer --> 2 [sk, b, ng, hn]
8528
8529
8530
8531
8532
            key_layer, value_layer = _SplitAlongDim.apply(
                mixed_kv_layer,
                split_dim,
                mixed_kv_layer.shape[split_dim] // 2,
            )
8533
8534
8535
8536
8537
8538
8539
8540
8541
            key_layer, value_layer = (
                x.reshape(
                    x.size(0),
                    x.size(1),
                    -1,
                    self.hidden_size_per_attention_head,
                )
                for x in (key_layer, value_layer)
            )
8542
8543
8544
8545
8546
8547

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8548
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8549
8550
8551
8552
8553
8554
8555
8556
8557
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8558
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8559
8560
8561
8562
8563
8564
8565
8566
8567
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

8568
8569
8570
        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================
8571

8572
        if rotary_pos_emb is not None:
8573
8574
8575
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
8576
            # duplicate the pos_emb for self attention
8577
            if not isinstance(rotary_pos_emb, tuple):
8578
                rotary_pos_emb = (rotary_pos_emb,) * 2
8579
8580

            q_pos_emb, k_pos_emb = rotary_pos_emb
8581
8582
8583
8584
8585
8586
8587

            # adjust key and value for inference
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)
8588
8589
                else:
                    raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.")
8590
8591
8592
8593
8594
8595
8596

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

8597
8598
8599
8600
8601
8602
8603
8604
8605
8606
8607
8608
8609
8610
8611
8612
8613
8614
            query_layer = apply_rotary_pos_emb(
                query_layer,
                q_pos_emb,
                self.qkv_format,
                fused=True,
                cu_seqlens=cu_seqlens_q,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
            )
            key_layer = apply_rotary_pos_emb(
                key_layer,
                k_pos_emb,
                self.qkv_format,
                fused=True,
                cu_seqlens=cu_seqlens_kv,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
            )
8615

8616
8617
8618
8619
        # ===========================
        # Core attention computation
        # ===========================

8620
8621
8622
8623
        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
8624
            qkv_format=self.qkv_format,
8625
8626
8627
8628
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
8629
8630
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
8631
            window_size=window_size,
8632
8633
8634
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
8635
            alibi_slopes=alibi_slopes,
8636
            fast_zero_fill=fast_zero_fill,
8637
            inference_params=inference_params,
8638
            pad_between_seqs=pad_between_seqs,
8639
8640
        )

8641
        # ===================
8642
        # Output. [sq, b, h]
8643
        # ===================
8644
        projection_output = self.proj(
8645
8646
            context_layer,
            is_first_microbatch=is_first_microbatch,
8647
            fp8_grad=isinstance(context_layer, QuantizedTensor),
8648
8649
        )

8650
8651
8652
8653
8654
8655
8656
8657
        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
8658
        if self.input_layernorm and self.return_layernorm_output:
8659
8660
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]