# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
from importlib.metadata import PackageNotFoundError
10
import math
11
import os
12
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
import warnings
14
import logging
15
import functools
16

17
from dataclasses import dataclass, fields
import numpy as np
19
from packaging.version import Version as PkgVersion
20
21

import torch
22
import torch.nn.functional as F
23

24
import transformer_engine_torch as tex
25
26
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
27
28
29
30
from transformer_engine.pytorch.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
)
31
32
33
34
35
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
36
37
    fused_attn_fwd,
    fused_attn_bwd,
38
39
40
41
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
42
43
44
45
46
47
48
49
50
51
52
53
54
    META_QKV,
    META_DQKV,
    META_O,
    META_DO,
    META_S,
    META_DP,
    META_O_CP,
    META_DQKV_CP,
)
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    get_fp8_te_dtype,
    get_fp8_torch_dtype,
55
)
56
from transformer_engine.pytorch.float8_tensor import Float8Tensor
57
from transformer_engine.pytorch.module import LayerNormLinear, Linear
58
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
59
60
61
62
63
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
64
    get_default_init_method,
65
66
67
68
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
69
    AttnBiasTypes,
70
    QKVLayouts,
71
    dist_group_type,
72
    TE_DType,
73
74
75
76
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
77
    get_distributed_rank,
78
    checkpoint,
79
80
81
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
82
83
    gather_along_first_dim,
    reduce_scatter_along_first_dim,
84
85
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
86
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
87
88
from transformer_engine.pytorch.graph import is_graph_capturing

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
# The effective level is only raised above WARNING when NVTE_DEBUG is set; any product
# outside {0, 1, 2} is treated as the most verbose level (DEBUG).
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

# Logger for the flash-attn detection messages emitted at import time.
# It is deliberately module-scoped rather than the root logger: a library that sets the
# root logger's level/handlers at import time overrides the application's logging
# configuration for every other library in the process.
fa_logger = logging.getLogger(__name__)
fa_logger.setLevel(_log_level)
# Only attach our handler when nothing in the logger hierarchy (including an
# application-configured root handler) would otherwise emit these records.
if not fa_logger.hasHandlers():
    fa_logger.addHandler(_stream_handler)


@functools.lru_cache(maxsize=None)
def _get_supported_versions(version_min, version_max):
    """Format an inclusive version range for log messages.

    For example, ``("2.1.1", "2.6.3")`` yields ``">= 2.1.1, <= 2.6.3"``.
    Results are memoized per ``(version_min, version_max)`` pair.
    """
    return f">= {version_min}, <= {version_max}"

# Import-time snapshots of the backend on/off switches (each defaults to enabled).
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))

# Detect flash-attn v2 in the environment
_flash_attn_is_installed = False
_flash_attn_version = PkgVersion("0")
# Inclusive range of flash-attn versions this module will import.
_flash_attn_version_required = PkgVersion("2.1.1")
_flash_attn_max_version = PkgVersion("2.6.3")
# Feature flags derived from the installed version; only set when the import succeeds.
_flash_attn_2_plus = False
_flash_attn_2_1_plus = False
_flash_attn_2_3_plus = False
_flash_attn_2_4_plus = False
_flash_attn_2_4_1_plus = False
_flash_attn_2_5_7_plus = False
_flash_attn_2_6_0_plus = False

# Entry points into flash-attn; they remain None unless a supported version is imported.
flash_attn_func = None
flash_attn_varlen_func = None
flash_attn_varlen_fwd = None
flash_attn_varlen_bwd = None
flash_attn_cuda_bwd = None

try:
    _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
    # Only mention the missing package on hardware where it could have been used.
    if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN:
        fa_logger.debug(
            "flash-attn v2 is not installed. To use, please install it by"
            """ "pip install flash-attn".""",
        )
else:
    if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_forward as flash_attn_varlen_fwd,
        )
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_backward as flash_attn_varlen_bwd,
        )
        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd

        _flash_attn_is_installed = True
        _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
        _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
        _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
        _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
        _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
        _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
        _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0")
    elif (
        torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN
    ):
        # Installed but outside the supported range: leave FA2 disabled and say why.
        fa_logger.warning(
            "Supported flash-attn versions are %s. Found flash-attn %s.",
            _get_supported_versions(
                _flash_attn_version_required,
                _flash_attn_max_version,
            ),
            _flash_attn_version,
        )

# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
_flash_attn_3_is_installed = False
_flash_attn_3_version = PkgVersion("0")
_flash_attn_3_0_0_beta = False
_use_flash_attn_3 = False
_flash_attn_3_installation_steps = """\
(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
try:
    _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
except PackageNotFoundError:
    # Only mention the missing package on hardware where it could have been used.
    if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN:
        fa_logger.debug(
            "flash-attn v3 is not installed. To use, please install it by \n%s",
            _flash_attn_3_installation_steps,
        )
else:
    try:
        from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
        from flashattn_hopper.flash_attn_interface import (
            flash_attn_varlen_func as flash_attn_varlen_func_v3,
        )
        from flashattn_hopper.flash_attn_interface import (
            _flash_attn_varlen_forward as flash_attn_varlen_fwd_v3,
        )
        from flashattn_hopper.flash_attn_interface import (
            _flash_attn_varlen_backward as flash_attn_varlen_bwd_v3,
        )
    except ImportError as err:
        # Steps (3)-(4) of the installation are manual. If they were skipped, the package
        # metadata exists but `flash_attn_interface` does not. Treat FA3 as not installed
        # (the same state as the PackageNotFoundError path) instead of making this whole
        # module fail to import.
        fa_logger.warning(
            "flashattn-hopper %s is installed but its interface could not be imported (%s). "
            "FlashAttention 3 is disabled. To use it, please install it by \n%s",
            _flash_attn_3_version,
            err,
            _flash_attn_3_installation_steps,
        )
        _flash_attn_3_version = PkgVersion("0")
    else:
        _flash_attn_3_is_installed = True
        # Every 3.0.0 pre-release, starting at 3.0.0b0 (inclusive), uses the beta API.
        _flash_attn_3_0_0_beta = (
            PkgVersion("3.0.0b0") <= _flash_attn_3_version < PkgVersion("3.0.0")
        )
        _use_flash_attn_3 = True
# Module-level record of a backend selection. The value keys mirror the results documented
# for `get_attention_backend` (use_flash_attention, use_fused_attention,
# fused_attention_backend, use_unfused_attention). NOTE(review): presumably the stored
# `attention_params` is compared against new inputs to decide whether
# `backend_selection_requires_update` should be set; confirm against the callers.
_attention_backends = {
    "attention_params": None,
    "use_flash_attention": None,
    "use_fused_attention": None,
    "fused_attention_backend": None,
    "use_unfused_attention": None,
    "backend_selection_requires_update": False,
}


@dataclass(eq=True)
class AttentionParams:
    """
    Attention parameters used to determine which backend to be used.

    Instances compare equal field-by-field (``eq=True``), so two parameter sets can be
    checked for equality directly.

    Parameters
    ----------
    qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
        Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
    qkv_dtype: torch.dtype, default = `torch.bfloat16`
        Data type of query/key/value tensors.
    qkv_layout: str, default = "sbh3d"
        Query/key/value tensor memory layout.
    batch_size: int, default = 1
        Batch size.
    num_heads: int, default = 16
        Number of attention heads in the query tensor.
    num_gqa_groups: int, default = 16
        Number of attention heads in key and value tensors.
    max_seqlen_q: int, default = 128
        Maximum sequence length of the query tensor.
    max_seqlen_kv: int, default = 128
        Maximum sequence length of the key and value tensors.
    head_dim_qk: int, default = 64
        The size of each attention head in query and key tensors.
    head_dim_v: int, default = 64
        The size of each attention head in the value tensor.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
        `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
    window_size: Tuple[int, int], default = None
        Sliding window attention size. When `None`, `get_attention_backend` derives it
        from `attn_mask_type`.
    alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
        Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
    core_attention_bias_type: str, default = `no_bias`
        Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
    core_attention_bias_shape: str, default = `1hss`
        Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
    core_attention_bias_requires_grad: bool, default = `True`
        Whether attention bias requires gradient.
    pad_between_seqs: bool, default = `False`
        Whether there is padding between sequences in a batch.
        This only applies to `qkv_format=thd`.
    attention_dropout: float, default = 0.0
        Attention dropout.
    context_parallel: bool, default = `False`
        Whether context parallelism is used or not.
    deterministic: bool, default = `False`
        Whether to run `DotProductAttention` with determinism or not.
    is_training: bool, default = `True`
        Whether in training mode (`True`) or inference mode (`False`)
    fp8: bool, default = `False`
        Whether `DotProductAttention` is in an `fp8_autocast` region.
    fp8_meta: Optional[Dict[str, Any]], default = `None`
        The FP8 metadata dictionary of `DotProductAttention`; its `"recipe"` entry is
        consulted during backend selection when `fp8` is `True`.
    """

    qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
    qkv_dtype: torch.dtype = torch.bfloat16
    qkv_layout: str = "sbh3d"
    batch_size: int = 1
    num_heads: int = 16
    num_gqa_groups: int = 16
    max_seqlen_q: int = 128
    max_seqlen_kv: int = 128
    head_dim_qk: int = 64
    head_dim_v: int = 64
    attn_mask_type: str = "no_mask"
    window_size: Union[Tuple[int, int], None] = None
    alibi_slopes_shape: Union[torch.Size, List, None] = None
    core_attention_bias_type: str = "no_bias"
    core_attention_bias_shape: str = "1hss"
    core_attention_bias_requires_grad: bool = True
    pad_between_seqs: bool = False
    attention_dropout: float = 0.0
    context_parallel: bool = False
    deterministic: bool = False
    is_training: bool = True
    fp8: bool = False
    fp8_meta: Union[Dict[str, Any], None] = None


# Module-level state for ALiBi slopes/bias, built with the same keys in the same order.
# NOTE(review): presumably the cached slopes/bias are reused across calls with matching
# head count and sequence lengths, and the `*_require_update` flags mark stale entries;
# confirm against the ALiBi helpers, which are not visible here.
_alibi_cache = dict(
    _num_heads=None,
    _alibi_slopes=None,
    _max_seqlen_q=None,
    _max_seqlen_kv=None,
    _bottom_right_alignment=True,
    _alibi_bias=None,
    _alibi_slopes_require_update=False,
    _alibi_bias_require_update=False,
)


# Public API of this module: `from <module> import *` exports only these names.
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]


def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    """Ensure the innermost dimension of `tensor` is densely packed.

    If the last stride is already 1, the very same tensor object is returned (no copy).
    Otherwise the result of `tensor.contiguous()` is returned.
    """
    if tensor.stride(-1) == 1:
        return tensor
    return tensor.contiguous()


def get_attention_backend(
    attention_params: AttentionParams = None,
):
    """
    Select the appropriate attention backend/sub-backend based on user input and runtime environment.

    Parameters
    ----------
    See `AttentionParams`.
330
331
332
333
334
335
336

    Returns
    ----------
    use_flash_attention: bool
        Whether the `FlashAttention` backend has been selected.
    use_fused_attention: bool
        Whether the `FusedAttention` backend has been selected.
337
338
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend
        If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
339
340
341
342
343
344
    use_unfused_attention: bool
        Whether the `UnfusedDotProductAttention` backend has been selected.
    available_backends: List[bool]
        All available backends that could support the provided input. A list of Booleans
        in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
    """
345
346
347
348
349
350
351
352
    qkv_type = attention_params.qkv_type
    qkv_dtype = attention_params.qkv_dtype
    qkv_layout = attention_params.qkv_layout
    batch_size = attention_params.batch_size
    num_heads = attention_params.num_heads
    num_gqa_groups = attention_params.num_gqa_groups
    max_seqlen_q = attention_params.max_seqlen_q
    max_seqlen_kv = attention_params.max_seqlen_kv
353
354
    head_dim_qk = attention_params.head_dim_qk
    head_dim_v = attention_params.head_dim_v
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
    attn_mask_type = attention_params.attn_mask_type
    window_size = attention_params.window_size
    alibi_slopes_shape = attention_params.alibi_slopes_shape
    core_attention_bias_type = attention_params.core_attention_bias_type
    core_attention_bias_shape = attention_params.core_attention_bias_shape
    core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
    pad_between_seqs = attention_params.pad_between_seqs
    attention_dropout = attention_params.attention_dropout
    context_parallel = attention_params.context_parallel
    deterministic = attention_params.deterministic
    is_training = attention_params.is_training
    fp8 = attention_params.fp8
    fp8_meta = attention_params.fp8_meta

    # Run config
370
    logger = logging.getLogger("DotProductAttention")
371
372
373
    logger.setLevel(_log_level)
    if not logger.hasHandlers():
        logger.addHandler(_stream_handler)
374
375
376
377
378
    device_compute_capability = get_device_compute_capability()
    cudnn_version = get_cudnn_version()
    run_config = {
        "transformer_engine_version": te.__version__,
        "compute_capability": "sm"
379
        + str(10 * device_compute_capability[0] + device_compute_capability[1]),
380
381
382
383
384
385
        "flash_attn_version": (
            str(_flash_attn_version) if _flash_attn_is_installed else "not installed"
        ),
        "flash_attn_3_version": (
            str(_flash_attn_3_version) if _flash_attn_3_is_installed else "not installed"
        ),
386
387
388
389
390
391
392
393
394
        "cudnn_version": ".".join([str(i) for i in cudnn_version]),
    }
    attention_params_dict = {
        field.name: getattr(attention_params, field.name) for field in fields(attention_params)
    }
    run_config.update(attention_params_dict)
    if fp8:
        run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
    logger.debug("Running with config=%s", run_config)
395

396
397
398
399
400
401
    # The following sections check if `FlashAttention` supports the provided attention params,
    # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
    # necessary for performance/functionality, a warning will be issued to prompt users to
    # install an appropriate FA version.
    global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3

402
    # Filter: Environment variables
403
404
405
406
    use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
    use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
    use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
    if not use_flash_attention and _flash_attn_is_installed:
407
408
409
410
411
412
413
414
        logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
    if not use_fused_attention:
        logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
    if not use_unfused_attention:
        logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")

    # Filter: ONNX mode
    if is_in_onnx_export_mode():
415
        if use_flash_attention and _flash_attn_is_installed:
416
417
418
419
420
421
422
423
            logger.debug("Disabling FlashAttention due to ONNX mode")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention due to ONNX mode")
        use_fused_attention = False

    # Filter: Compute capability
    if device_compute_capability < (8, 0):
424
        if use_flash_attention and _flash_attn_is_installed:
425
            logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
426
        use_flash_attention = False
427
428
429
        if use_fused_attention:
            logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
            use_fused_attention = False
430
    if device_compute_capability < (9, 0):
431
        if use_flash_attention and _flash_attn_3_is_installed:
432
            logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
433
        _use_flash_attn_3 = False
434
435

    # Filter: Data type
436
437
438
439
    if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
        torch.Tensor,
        Float8Tensor,
    ]:
440
        if use_flash_attention and _flash_attn_is_installed:
441
442
443
444
445
446
            logger.debug(
                "Disabling FlashAttention due to unsupported QKV data type. "
                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                "Found: qkv_dtype = %s.",
                qkv_dtype,
            )
447
        use_flash_attention = False
448
449
450
451
452
453
454
455
        if use_fused_attention:
            logger.debug(
                "Disabling FusedAttention due to unsupported QKV data type. "
                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                "Found: qkv_dtype = %s.",
                qkv_dtype,
            )
            use_fused_attention = False
456
457
458

    # Filter: Execution type
    if fp8 and fp8_meta["recipe"].fp8_dpa:
459
        if use_flash_attention and not _use_flash_attn_3:
460
461
            if _flash_attn_is_installed:
                logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
462
463
464
465
466
            use_flash_attention = False
        if use_flash_attention and _use_flash_attn_3 and is_training:
            logger.debug(
                "Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
            )
467
468
469
470
471
472
            use_flash_attention = False
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
            use_unfused_attention = False

    # Filter: Head dimension
473
    if use_flash_attention and head_dim_qk != head_dim_v:
474
475
        if _flash_attn_is_installed:
            logger.debug("Disabling FlashAttention as it does not support MLA.")
476
        use_flash_attention = False
477
    if use_flash_attention and (
478
479
480
        head_dim_qk > 256
        or head_dim_qk % 8 != 0
        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
481
    ):
482
483
484
485
486
487
488
489
490
491
        if _flash_attn_is_installed:
            logger.debug(
                "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
                "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
                "head_dim_qk <= 256 (>192 requires sm80/90). "
                "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
                head_dim_qk,
                head_dim_v,
                ".".join([str(i) for i in device_compute_capability]),
            )
492
        use_flash_attention = False
493
494
495
496
497
498
499
    qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
    if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
        logger.debug(
            "Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
            qkv_layout,
        )
        use_fused_attention = False
500
501
502
503
504
505
506
507

    # Filter: QKV layout
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
            use_unfused_attention = False
        if use_flash_attention and pad_between_seqs:
508
509
510
511
512
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention for qkv_format = thd when there is "
                    "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
                )
513
514
            use_flash_attention = False

515
    # Filter: Dropout
516
517
518
    if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
        logger.debug("Disabling FlashAttention 3 for dropout")
        _use_flash_attn_3 = False
519

520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
    # Filter: Context parallelism
    # qkv_format | attn_mask_type              | attn_bias_type           | supported backends
    # ----------------------------------------------------------------------------------------------------
    # bshd, sbhd | self-attention:             | no_bias, post_scale_bias | FlashAttention, FusedAttention
    #            |     no_mask, causal         |                          |
    #            | cross-attention:            |                          |
    #            |     no_mask                 |                          |
    # thd        | self-attention:             | no_bias                  | FlashAttention, FusedAttention
    #            |     padding, padding_causal |                          | if no padding between sequences,
    #            | cross-attention:            |                          | FusedAttention
    #            |     padding                 |                          | if there is padding between sequences
    # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v.
    if context_parallel and use_unfused_attention:
        logger.debug(
            "Disabling UnfusedDotProductAttention as it does not support context parallelism"
        )
        use_unfused_attention = False
    if context_parallel and use_flash_attention:
538
        if fp8 and fp8_meta["recipe"].fp8_dpa:
539
540
541
542
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with FP8"
                )
543
            use_flash_attention = False
544
        if "bottom_right" in attn_mask_type:
545
546
547
548
549
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " causal_bottom_right masking"
                )
550
551
            use_flash_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
552
553
554
555
556
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " causal masking for cross-attention"
                )
557
558
            use_flash_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
559
560
561
562
563
564
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with bias"
                    " type of %s",
                    core_attention_bias_type,
                )
565
566
            use_flash_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
567
568
569
570
571
            if _flash_attn_is_installed:
                logger.debug(
                    "Disabling FlashAttention as it does not support context parallelism with"
                    " attention bias for THD format"
                )
572
            use_flash_attention = False
573

574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
    if context_parallel and use_fused_attention:
        if "bottom_right" in attn_mask_type:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with"
                " causal_bottom_right masking"
            )
            use_fused_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with causal"
                " masking for cross-attention"
            )
            use_fused_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with bias type"
                " of %s",
                core_attention_bias_type,
            )
            use_fused_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with attention"
                " bias for THD format"
            )
            use_fused_attention = False
        elif head_dim_qk != head_dim_v:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with MLA"
            )
            use_fused_attention = False

606
    # Filter: Attention mask
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
    # attn_mask_type              | attention_mask                       | supported backends
    # ----------------------------------------------------------------------------------------
    # no_mask                     | None                                 | All
    # padding                     |                                      | All
    #     self-attention          | One tensor in shape [b, 1, 1, sq]    |
    #     cross-attention         | Tuple of two tensors in shapes       |
    #                             | [b, 1, 1, sq] and [b, 1, 1, skv]     |
    # causal                      | None                                 |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # padding_causal              | Same as "padding"                    |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # causal_bottom_right         | None                                 | All
    # padding_causal_bottom_right | Same as "padding"                    |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FlashAttention, UnfusedDotProductAttention
    # arbitrary                   | One tensor in shape broadcastable to | UnfusedDotProductAttention
    #                             | [b, h, sq, skv]                      |
626
    if attn_mask_type == "arbitrary":
627
        if use_flash_attention and _flash_attn_is_installed:
628
629
630
631
632
            logger.debug("Disabling FlashAttention for arbitrary mask")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention for arbitrary mask")
        use_fused_attention = False
633
634
    if (
        use_flash_attention
635
        and _use_flash_attn_3
636
637
638
639
640
641
642
643
644
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
            "causal mask since flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        _use_flash_attn_3 = False
645
646
647
648
649
    if (
        use_flash_attention
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
650
651
652
653
654
655
656
657
658
        if _flash_attn_2_1_plus:
            logger.warning(
                "Disabling FlashAttention as it only supports bottom-right-diagonal "
                "causal mask since flash-attn 2.1. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False
        if not _flash_attn_is_installed:
            _flash_attn_max_version = PkgVersion("2.1")
659
660
661
662
663
    if (
        use_flash_attention
        and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
        and max_seqlen_q != max_seqlen_kv
    ):
664
665
666
667
668
669
670
671
672
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.1")
        elif not _flash_attn_2_1_plus and not _use_flash_attn_3:
            logger.warning(
                "Disabling FlashAttention as it only supports top-left-diagonal "
                "causal mask before flash-attn 2.1. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False
673
674
675
676
677
678
679
680
681
    if (
        use_flash_attention
        and _use_flash_attn_3
        and fp8
        and fp8_meta["recipe"].fp8_dpa
        and "padding" in attn_mask_type
    ):
        logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
        _use_flash_attn_3 = False
682
683

    # Filter: Sliding window attention
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
    #    backend                 |      window_size       | diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | (-1, -1) or (>=0, >=0) | bottom right
    # FusedAttention             | (-1,  0) or (>=0, 0)   | top left
    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
    #                            |                        | converts window_size to an 'arbitrary' mask
    if window_size is None:
        window_size = check_set_window_size(attn_mask_type, window_size)
    else:
        if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention"
                    " for FP8"
                )
                use_fused_attention = False
            elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
                logger.debug(
                    "Disabling FusedAttention as it only supports sliding window attention "
                    "with causal mask, no dropout, and qkv_format = bshd/sbhd"
                )
                use_fused_attention = False
            elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
                "no_mask",
                "padding",
                "causal_bottom_right",
                "padding_causal_bottom_right",
            ]:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with attn_mask_type = %s for cross-attention",
                    attn_mask_type,
                )
                use_fused_attention = False
            elif "padding" in attn_mask_type:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with attn_mask_type = %s",
                    attn_mask_type,
                )
                use_fused_attention = False
725
726
727
728
729
730
731
732
733
734
735
736
737
        if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if _use_flash_attn_3:
                logger.debug(
                    "Disabling FlashAttention 3 as it does not support sliding window attention"
                )
                _use_flash_attn_3 = False
            if not _flash_attn_is_installed:
                _flash_attn_version_required = PkgVersion("2.3")
            elif not _flash_attn_2_3_plus:
                logger.debug(
                    "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
                )
                use_flash_attention = False
738
739

    # Filter: Attention bias
740
741
742
743
744
745
746
747
    #    backend                 |      bias types              | ALiBi diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | no_bias, alibi/alibi_slopes  | bottom right
    # FusedAttention             | no_bias, post_scale_bias     |
    #                            | alibi/alibi_slopes           | top left,
    #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
    # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
    #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
748
    if use_flash_attention and core_attention_bias_type == "alibi":
749
        if _use_flash_attn_3:
750
751
            logger.debug("Disabling FlashAttention 3 for ALiBi")
            _use_flash_attn_3 = False
752
753
754
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.4")
        elif not _flash_attn_2_4_plus:
755
756
            logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
            use_flash_attention = False
757

758
759
760
761
    if use_flash_attention and (
        core_attention_bias_type not in ["no_bias", "alibi"]
        or core_attention_bias_shape is not None
    ):
762
763
        if _flash_attn_is_installed:
            logger.debug("Disabling FlashAttention for pre/post_scale_bias")
764
765
766
767
768
769
770
771
        use_flash_attention = False

    fu_core_attention_bias_type = core_attention_bias_type
    fu_core_attention_bias_shape = core_attention_bias_shape
    fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
    if (
        use_fused_attention
        and core_attention_bias_type == "alibi"
772
        and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
773
774
775
    ):
        fu_core_attention_bias_type = "post_scale_bias"
        fu_core_attention_bias_requires_grad = False
776
777
778
779
780
        if alibi_slopes_shape is None:
            fu_core_attention_bias_shape = "1hss"
        elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
            fu_core_attention_bias_shape = "1hss"
        elif (
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
            len(alibi_slopes_shape) == 2
            and alibi_slopes_shape[0] == batch_size
            and alibi_slopes_shape[1] == num_heads
        ):
            fu_core_attention_bias_shape = "bhss"

    if (
        use_fused_attention
        and fu_core_attention_bias_type == "post_scale_bias"
        and fu_core_attention_bias_shape != "1hss"
    ):
        if fu_core_attention_bias_requires_grad:
            # remove this line when cuDNN adds bwd support for
            # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
            logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
            use_fused_attention = False
        else:
            # max512 backend will only support [1, h, s, s]
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

    # Filter: cuDNN support
    fused_attention_backend = None
    if use_fused_attention:
        q_type = TE_DType[qkv_dtype]
        kv_type = q_type
        if fp8 and fp8_meta["recipe"].fp8_dpa:
            q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            kv_type = q_type
        fused_attention_backend = tex.get_fused_attn_backend(
            q_type,
            kv_type,
            QKVLayout[qkv_layout],
            AttnBiasType[fu_core_attention_bias_type],
            AttnMaskType[attn_mask_type],
            attention_dropout,
            num_heads,
            num_gqa_groups,
            max_seqlen_q,
            max_seqlen_kv,
820
821
            head_dim_qk,
            head_dim_v,
822
823
            window_size[0],
            window_size[1],
824
        )
825
        if fused_attention_backend == FusedAttnBackend["No_Backend"]:
826
827
            logger.debug("Disabling FusedAttention as no backend supports the provided input")
            use_fused_attention = False
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
            fused_attention_backend = None
        if (
            use_fused_attention
            and window_size is not None
            and window_size[0] != -1
            and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
        ):
            logger.debug(
                "Disabling FusedAttention as only sub-backend %s does not support "
                "slidng window attention",
                int(fused_attention_backend),
            )
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
845
846
847
848
849
850
851
852
            and fu_core_attention_bias_type == "post_scale_bias"
            and fu_core_attention_bias_shape != "1hss"
        ):
            logger.debug(
                "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
                " [1, H, S, S] shape"
            )
            use_fused_attention = False
853
            fused_attention_backend = None
854
855
856
857
858
859
860
861
862
863
864
865
866

    # Filter: Determinism
    # backend                      | deterministic
    # ---------------------------------------------
    # FlashAttention               |
    #     flash-attn >=2.0, <2.4.1 | no
    #     flash-attn >=2.4.1       | yes
    # FusedAttention               |
    #     sub-backend 0            | yes
    #     sub-backend 1            | workspace optimization path and sm90+: yes;
    #                              | otherwise: no
    #     sub-backend 2            | no
    # UnfusedDotProductAttention   | yes
867
868
869
870
871
872
873
874
875
876
    if use_flash_attention and deterministic:
        if not _flash_attn_is_installed:
            _flash_attn_version_required = PkgVersion("2.4.1")
        elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3:
            logger.warning(
                "Disabling FlashAttention as version <2.4.1 does not support deterministic "
                "execution. To use FlashAttention with deterministic behavior, "
                "please install flash-attn >= 2.4.1."
            )
            use_flash_attention = False
877
878
879
880
881
882
883
884
885
886
887
    if use_fused_attention and deterministic:
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
            and (
                device_compute_capability < (9, 0)
                or core_attention_bias_requires_grad
                or cudnn_version < (8, 9, 5)
888
            )
889
890
891
        ):
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
892
893
894

    # All available backends
    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911

    # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
    # When `FusedAttention` does not support the provided attention params, and `FlashAttention`
    # does, we recommend users to install flash-attn if not installed already.
    if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed:
        logger.warning(
            "flash-attn may provide important feature support or performance improvement."
            " Please install flash-attn %s.",
            _get_supported_versions(
                _flash_attn_version_required,
                _flash_attn_max_version,
            ),
        )
    if use_flash_attention and not _flash_attn_is_installed:
        use_flash_attention = False
        available_backends[0] = False

912
913
914
915
916
917
918
919
920
921
922
923
    logger.debug(
        "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
        " UnfusedDotProductAttention=%s}",
        bool(available_backends[0]),
        bool(available_backends[1]),
        (
            f" (sub-backend {int(fused_attention_backend)})"
            if fused_attention_backend is not None
            else ""
        ),
        bool(available_backends[2]),
    )
924
925
926
927
928
929
930
931
932
933
934
935
936

    # Select FusedAttention for performance
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    ):
        if device_compute_capability == (9, 0):
            logger.debug(
                "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                "for performance reasons"
            )
            use_flash_attention = False
937
938
939
940
941
942
943
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["FP8"]
        and _use_flash_attn_3
    ):
        logger.debug(
944
945
            "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
            "in FP8 execution"
946
947
948
        )
        use_flash_attention = False

949
950
951
952
953
954
    # Selected backend
    if use_flash_attention:
        use_fused_attention = False
        use_unfused_attention = False
    elif use_fused_attention:
        use_unfused_attention = False
955
    selected_backend = "NoBackend"
956
957
958
959
960
961
    if use_flash_attention:
        selected_backend = "FlashAttention"
    elif use_fused_attention:
        selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
    elif use_unfused_attention:
        selected_backend = "UnfusedDotProductAttention"
962
    logger.debug("Selected backend = %s", selected_backend)
963

964
965
966
967
968
969
    global _attention_backends
    _attention_backends["use_flash_attention"] = use_flash_attention
    _attention_backends["use_fused_attention"] = use_fused_attention
    _attention_backends["fused_attention_backend"] = fused_attention_backend
    _attention_backends["use_unfused_attention"] = use_unfused_attention
    _attention_backends["backend_selection_requires_update"] = False
970
971
972
973

    return (
        use_flash_attention,
        use_fused_attention,
974
        fused_attention_backend,
975
976
977
978
979
        use_unfused_attention,
        available_backends,
    )


980
class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # Maps a layer key to a (key_memory, value_memory) pair of cache tensors.
        # Entries are populated outside this class.
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty.
        AssertionError
            If `len(batch_indices)` differs from dim 1 (the batch dimension) of a
            cached key tensor.
        """
        if not self.key_value_memory_dict:
            raise ValueError("should not swap when dict is empty")

        # Reassigning values of existing keys while iterating is safe: the dict size
        # does not change.
        for layer_number, (key_memory, value_memory) in self.key_value_memory_dict.items():
            # The cache tensors are indexed with the batch on dim 1.
            assert len(batch_indices) == key_memory.shape[1], (
                f"batch_indices has {len(batch_indices)} entries but the KV cache "
                f"batch size is {key_memory.shape[1]}"
            )
            self.key_value_memory_dict[layer_number] = (
                key_memory[:, batch_indices],
                value_memory[:, batch_indices],
            )
1024

1025

1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
@torch.no_grad()
def get_swa_mask(
    window_size: Tuple[int, int],
    max_seqlen_q: int,
    max_seqlen_kv: int,
    attn_mask_type: str = "no_mask",
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
    device: Union[str, torch.device] = "cuda",
) -> Tuple[str, torch.Tensor]:
    """
    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
    For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
    and for other mask types, the bottom right corner.

    Parameters
    ----------
    window_size: Tuple[int, int]
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
        `attn_mask_type`.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
        "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
        default = `None`
        Boolean tensor used to mask out attention softmax input (`True` = masked out).
        It is combined with the sliding window mask, so a position is masked if
        either mask excludes it.
        NOTE(review): the annotation admits a tuple, but only a single tensor can be
        combined here; a tuple would fail in `torch.logical_or`. Confirm callers
        pass a tensor.
    device: Union[str, torch.device], default = `"cuda"`
        Device on which the sliding window mask is built.

    Returns
    ----------
    attn_mask_type: str
        Always "`arbitrary`".
    attention_mask: torch.Tensor
        Combined `attention_mask` (input) and sliding window attention mask, with `True`
        marking positions that must not be attended to. The shape is
        [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None; else, the
        broadcast shape of the input `attention_mask` and [max_seqlen_q, max_seqlen_kv].
    """
    if attn_mask_type == "causal":
        # Top-left alignment: diagonal 0 is the main diagonal.
        diag_offset = 0
        unbounded = max_seqlen_q
    else:
        # Bottom-right alignment: shift the diagonal by the kv/q length difference.
        diag_offset = max_seqlen_kv - max_seqlen_q
        unbounded = max_seqlen_kv
    # A window side of -1 means "unbounded" on that side.
    left = window_size[0] if window_size[0] != -1 else unbounded
    right = window_size[1] if window_size[1] != -1 else unbounded

    allowed = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device=device)
    allowed = torch.triu(allowed, diagonal=diag_offset - left)
    allowed = torch.tril(allowed, diagonal=diag_offset + right)

    # Invert so that True means "masked out".
    mask = allowed.logical_not()
    if attention_mask is not None:
        # Mask a position if either the input mask or the window excludes it.
        mask = torch.logical_or(attention_mask, mask)
    return "arbitrary", mask


1084
1085
1086
1087
1088
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    actual_seqlens_q: Optional[torch.Tensor] = None,
    actual_seqlens_kv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
    bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate (or return from the module-level `_alibi_cache`) the ALiBi slopes and bias.

    Slopes are recomputed only when `_alibi_cache["_alibi_slopes_require_update"]` is
    set, and the bias only when `_alibi_cache["_alibi_bias_require_update"]` is set.
    Otherwise the cached tensors are returned as-is, whatever arguments are passed.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    actual_seqlens_q: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for queries, in shape [batch_size].
    actual_seqlens_kv: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for keys and values, in shape [batch_size].
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the ALiBi bias to the bottom right corner of
        the matrix (`True`) or top left (`False`).

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. Its shape is
        (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
        and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
        (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
        [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
        `actual_seqlens_q` and `actual_seqlens_kv` are not `None`.

    Raises
    ------
    ValueError
        If the cached ALiBi slopes have more than 2 dimensions.
    AssertionError
        If the bias must be rebuilt but no slopes are cached, or if exactly one of
        `actual_seqlens_q` / `actual_seqlens_kv` is provided.
    """
    global _alibi_cache
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            # Default slopes: powers of 2**(-8/n) for the largest power-of-two head
            # count n; any remaining heads get odd powers of 2**(-4/n).
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!"
        slopes = _alibi_cache["_alibi_slopes"]
        if slopes.dim() == 1:
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        elif slopes.dim() == 2:
            slopes_shape = torch.Size([*slopes.shape[:], 1, 1])
        else:
            raise ValueError("ALiBi slopes cannot exceed 2 dimensions.")

        # bias[..., i, j] = i - j (query position minus key position).
        bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
            1, 1, max_seqlen_q, 1
        ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
            1, 1, 1, max_seqlen_kv
        )
        if actual_seqlens_q is None and actual_seqlens_kv is None:
            if bottom_right_alignment:
                bias = bias + max_seqlen_kv - max_seqlen_q
        elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
            batch_size = actual_seqlens_q.shape[0]
            bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
            if bottom_right_alignment:
                # Per-sample diagonal shift based on the actual (unpadded) lengths.
                bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
        else:
            # Explicit raise (not `assert False`) so the check survives `python -O`.
            raise AssertionError(
                "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
            )

        # Negative absolute distance, scaled per head (and per batch for 2D slopes).
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    `True` entries in `mask` are padding; each sample's length is its count of
    `False` entries. The result is placed on `mask`'s device.
    """
    mask = mask.squeeze(1).squeeze(1)
    seqlens = mask.logical_not().sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # Allocate the leading zero on the mask's device so `torch.cat` never mixes devices.
    zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    return torch.cat((zero, cu_seqlens))

1197

1198
1199
1200
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the flattened indices of the valid tokens.

    `True` entries in `mask` are padding. The index tensor lists the positions
    `b * max_seqlen + s` of valid tokens in row-major order, followed by padding
    entries equal to `batch_size * max_seqlen`. Both outputs live on `mask`'s device.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape
    valid = mask.logical_not()

    cu_seqlens = valid.sum(dim=1).cumsum(dim=0).to(torch.int32)
    # Allocate the leading zero on the mask's device so `torch.cat` never mixes devices.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # nonzero() returns [num_valid, 1]; add a trailing dim to get [num_valid, 1, 1].
    indices = valid.reshape(-1).nonzero().unsqueeze(-1)
    pad_amount = bs * seqlen - indices.shape[0]
    indices = F.pad(
        input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen)
    )

    return cu_seqlens, indices


1226
1227
1228
1229
1230
1231
1232
1233
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch.

    Valid token `s` of sample `b` (with `s < seqlen_b`) maps to index
    `b * max_seqlen + s`. Indices are listed in row-major order and followed by
    padding entries equal to `batch_size * max_seqlen`. Every sample length is
    expected to be <= `max_seqlen`. The result is placed on `cu_seqlens`'s device.
    """
    bs = len(cu_seqlens) - 1
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

    # Build a [bs, max_seqlen] validity mask on-device. This avoids a per-token Python
    # loop (and device syncs) and a float32 round-trip that is inexact above 2**24.
    positions = torch.arange(max_seqlen, device=cu_seqlens.device)
    valid = positions.unsqueeze(0) < seqlens.unsqueeze(1)
    # nonzero() on the flattened mask yields int64 [num_valid, 1]; make it [num_valid, 1, 1].
    indices = valid.reshape(-1).nonzero().unsqueeze(-1)

    pad_amount = bs * max_seqlen - indices.shape[0]
    indices = F.pad(
        input=indices,
        pad=(0, 0, 0, 0, 0, pad_amount),
        mode="constant",
        value=float(bs * max_seqlen),
    )

    return indices

1248

1249
_cu_seqlens_cache = {}
1250
1251


1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length.

    """
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
    global _cu_seqlens_cache
    if (batch_size, max_seqlen) not in _cu_seqlens_cache:
        _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[(batch_size, max_seqlen)]
1272
1273


1274
@torch.compile
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Packs the given tensor using the `indices`.

    An all-zero row is appended to `tensor` along dim 0 before gathering, so any index
    equal to `tensor.shape[0]` selects zeros. Presumably this matches the padding value
    written by the index builders in this module; confirm against the callers.
    For a `Float8Tensor`, the raw `_data` is gathered and re-wrapped with `make_like`.
    """
    is_fp8 = isinstance(tensor, Float8Tensor)
    data = tensor._data if is_fp8 else tensor
    # The padding row must share the dtype of the data it is concatenated with. For FP8
    # this is the raw `_data` dtype, not the tensor's nominal dtype, to avoid promotion.
    padding_row = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=data.dtype, device=tensor.device
    )
    gather_index = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    packed = torch.gather(torch.cat((data, padding_row), dim=0), 0, gather_index)
    if is_fp8:
        return Float8Tensor.make_like(tensor, data=packed)
    return packed


1297
@torch.compile
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Packs the given 2 tensors using the `indices`.

    Each tensor is packed independently with `pack_tensor`.
    """
    t1_packed = pack_tensor(indices, t1)
    t2_packed = pack_tensor(indices, t2)
    return t1_packed, t2_packed


1311
@torch.compile
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Packs the given 3 tensors using the `indices`.

    Each tensor is packed independently with `pack_tensor`.
    """
    t1_packed = pack_tensor(indices, t1)
    t2_packed = pack_tensor(indices, t2)
    t3_packed = pack_tensor(indices, t3)
    return t1_packed, t2_packed, t3_packed


1327
@torch.compile
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`.

    Rows of `tensor` are scattered into a zero buffer with `dim0 + 1` rows at `indices`.
    The extra trailing row absorbs writes from index entries equal to `dim0` (presumably
    the padding entries; confirm against callers) and is dropped before returning.
    For a `Float8Tensor`, the raw `_data` is scattered and re-wrapped with `make_like`.
    """
    is_fp8 = isinstance(tensor, Float8Tensor)
    src = tensor._data if is_fp8 else tensor
    scatter_index = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    # `scatter_` requires the destination to have the source dtype. For FP8 that is the
    # raw `_data` dtype, not the tensor's nominal dtype.
    unpacked = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=src.dtype, device=tensor.device
    )
    unpacked.scatter_(0, scatter_index, src)
    unpacked = unpacked[0:-1, :, :]
    if is_fp8:
        return Float8Tensor.make_like(tensor, data=unpacked)
    return unpacked


1349
@torch.compile
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`.

    Each tensor is unpacked independently with `unpack_tensor`.
    """
    t1_unpacked = unpack_tensor(indices, dim0, t1)
    t2_unpacked = unpack_tensor(indices, dim0, t2)
    return t1_unpacked, t2_unpacked


1364
@torch.compile
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`.

    Each tensor is unpacked independently with `unpack_tensor`.
    """
    t1_unpacked = unpack_tensor(indices, dim0, t1)
    t2_unpacked = unpack_tensor(indices, dim0, t2)
    t3_unpacked = unpack_tensor(indices, dim0, t3)
    return t1_unpacked, t2_unpacked, t3_unpacked


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    Forward packs 1-3 tensors with ``indices`` via `pack_tensor`, `pack_2_tensors`
    or `pack_3_tensors`. Backward restores the gradients to the original row
    layout with the matching `unpack_tensor` / `unpack_2_tensors` /
    `unpack_3_tensors` helper.
    """

    @staticmethod
    def forward(
        ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # pylint: disable=missing-function-docstring
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        ctx.save_for_backward(indices)
        # Row count before packing, which backward needs to rebuild the full tensors.
        # Assumes all packed tensors share the size of dim 0 — TODO confirm.
        ctx.dim0 = tensors[0].shape[0]
        if len(tensors) == 1:
            return pack_tensor(indices, *tensors)
        if len(tensors) == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_3_tensors(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        # pylint: disable=missing-function-docstring
        (indices,) = ctx.saved_tensors
        # The leading None is the (absent) gradient for `indices`.
        if len(grad_outputs) == 1:
            return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
        if len(grad_outputs) == 2:
            return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs)
        return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs)


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    Forward restores a packed tensor to ``dim0`` rows with `unpack_tensor`.
    Backward re-packs the incoming gradient with `pack_tensor`, using the same
    ``indices``.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        ctx.save_for_backward(indices)
        return unpack_tensor(indices, dim0, tensor)

    @staticmethod
    def backward(ctx, grad_output):
        # pylint: disable=missing-function-docstring
        (indices,) = ctx.saved_tensors
        # `indices` and `dim0` receive no gradient.
        return None, None, pack_tensor(indices, grad_output)


def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Posts an asynchronous send of ``send_tensor`` to ``send_dst`` and an asynchronous
    receive into ``recv_tensor`` from ``recv_src``. The order of the two operations
    depends on rank parity: even ranks post the send first, odd ranks the receive first.

    Args:
        rank: rank of this process inside ``cp_group``; only its parity is used here.
        send_tensor: tensor sent to ``send_dst``.
        send_dst: destination rank passed to ``isend``.
        recv_tensor: pre-allocated buffer filled from ``recv_src``.
        recv_src: source rank passed to ``irecv``.
        cp_group: process group used for both operations.
        batch_p2p_comm: when truthy, issue both operations through
            ``torch.distributed.batch_isend_irecv``; otherwise call ``isend`` and
            ``irecv`` individually.

    Returns:
        List of request handles for the posted operations, which callers wait on.
    """
    send_first = rank % 2 == 0

    if batch_p2p_comm:
        # Constructing P2POp objects issues nothing; only their order in the list matters.
        send_op = torch.distributed.P2POp(
            torch.distributed.isend, send_tensor, send_dst, cp_group
        )
        recv_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_tensor, recv_src, cp_group
        )
        send_recv_ops = [send_op, recv_op] if send_first else [recv_op, send_op]
        return torch.distributed.batch_isend_irecv(send_recv_ops)

    # isend/irecv are issued immediately, so the call order itself follows rank parity.
    if send_first:
        send_req = torch.distributed.isend(send_tensor, send_dst, cp_group)
        recv_req = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
        return [send_req, recv_req]
    recv_req = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
    send_req = torch.distributed.isend(send_tensor, send_dst, cp_group)
    return [recv_req, send_req]


@jit_fuser
def flash_attn_fwd_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    movedim_src: int,
    movedim_dst: int,
):
    """Merge partial outputs of each step in Attention with context parallelism

    Accumulates, in place, ``out += out_per_step * exp(softmax_lse_per_step - softmax_lse)``.
    Before broadcasting, the correction factor's axis ``movedim_src`` is moved to
    ``movedim_dst`` and a trailing singleton dimension is appended. This lets it
    broadcast over the last axis of ``out_per_step`` (presumably the head dimension —
    confirm against callers).

    Args:
        out: accumulator, modified in place.
        out_per_step: this step's partial attention output.
        softmax_lse: merged log-sum-exp used as the normalisation reference.
        softmax_lse_per_step: this step's log-sum-exp.
        movedim_src: source axis for ``movedim`` on the correction factor.
        movedim_dst: destination axis for ``movedim`` on the correction factor.
    """
    softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(
        movedim_src, movedim_dst
    )
    softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
    out_corrected = out_per_step * softmax_lse_corrected_exp
    out.add_(out_corrected)


@jit_fuser
def flash_attn_fwd_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge softmax stats of each step in Attention with context parallelism

    Overwrites ``softmax_lse`` in place with
    ``log(exp(softmax_lse) + exp(softmax_lse_per_step))``, evaluated stably as
    ``max + log1p(exp(min - max))`` so no exponential of a large value is formed.

    Args:
        softmax_lse: running log-sum-exp; overwritten with the merged value.
        softmax_lse_per_step: this step's log-sum-exp; must broadcast to the
            shape of ``softmax_lse``.
    """
    max_scale = torch.max(softmax_lse, softmax_lse_per_step)
    min_scale = torch.min(softmax_lse, softmax_lse_per_step)
    # log1p(y) stays accurate when y = exp(min - max) is tiny (one step dominates),
    # where log(1 + y) would round 1 + y to 1 and lose the contribution.
    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
    softmax_lse.copy_(new_scale)


@jit_fuser
def get_cu_seqlens_on_cp_rank(
    cu_seqlens: torch.Tensor,
    cu_seqlens_padded_on_cp_rank: torch.Tensor,
    cp_size: int,
    cp_rank: int,
    first_half: bool,
    second_half: bool,
):
    """Compute cu_seqlens of a context parallelism rank

    Each sequence's per-rank padded length (from ``cu_seqlens_padded_on_cp_rank``) is
    split into two equal chunks of ``p`` tokens. On rank ``cp_rank``, the first chunk
    starts at global token offset ``cp_rank * p`` and the second at
    ``(2 * cp_size - cp_rank - 1) * p``. For each requested chunk, the number of real
    (unpadded) tokens it holds is ``clamp(seqlen - offset, 0, p)``. The result is the
    cumulative sum of those counts, with a leading zero.

    Args:
        cu_seqlens: cumulative (unpadded) sequence lengths of the full sequences.
        cu_seqlens_padded_on_cp_rank: cumulative padded lengths of this rank's share.
        cp_size: number of context-parallel ranks.
        cp_rank: rank whose chunks are counted.
        first_half: include the first chunk's tokens.
        second_half: include the second chunk's tokens.

    Returns:
        Tensor shaped like ``cu_seqlens`` holding this rank's cumulative lengths.
    """
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    # length of one of the two chunks held by this rank
    seqlens_padded = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
    zeros = torch.zeros_like(seqlens)
    cu_seqlens_on_cp_rank = torch.zeros_like(cu_seqlens)
    if first_half:
        seqlens_1 = seqlens - cp_rank * seqlens_padded
        seqlens_1 = seqlens_1.clamp(zeros, seqlens_padded)
        cu_seqlens_on_cp_rank[1:].add_(seqlens_1)
    if second_half:
        seqlens_2 = seqlens - (2 * cp_size - cp_rank - 1) * seqlens_padded
        seqlens_2 = seqlens_2.clamp(zeros, seqlens_padded)
        cu_seqlens_on_cp_rank[1:].add_(seqlens_2)
    cu_seqlens_on_cp_rank.cumsum_(dim=0)
    return cu_seqlens_on_cp_rank


@torch.compile
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
    """
    Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
    To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks
    before or after CP communications (e.g., all-gather, all-to-all). This function is to compute
    sequence chunk ids for reordering.

    Rank ``r`` owns chunks ``r`` and ``2 * cp_size - 1 - r``.

    Args:
        cp_size: number of context-parallel ranks.
        device: device for the returned index tensor.
        to_contiguous: if True, return
            ``[0, 2, ..., 2*cp_size-2, 2*cp_size-1, 2*cp_size-3, ..., 1]``;
            otherwise return the per-rank interleaving
            ``[0, 2*cp_size-1, 1, 2*cp_size-2, ...]``.

    Returns:
        int32 tensor of length ``2 * cp_size`` on ``device``.
    """
    # Built with vectorized ops instead of 2 * cp_size single-element writes into a
    # (possibly CUDA) tensor.
    ranks = torch.arange(cp_size, dtype=torch.int32, device=device)
    if to_contiguous:
        return torch.cat((2 * ranks, 2 * cp_size - 2 * ranks - 1))
    return torch.stack((ranks, 2 * cp_size - 1 - ranks), dim=1).flatten()


@torch.compile
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
    """Reorder sequence chunk for A2A communication.

    Before attention, the leading ``cp`` axis of the all-to-all output is folded into
    the sequence axis at ``seq_dim``. That axis is viewed as ``2 * cp_size`` chunks,
    and the chunks are permuted with ``chunk_ids_for_a2a``.

    After attention, the chunk axis at ``seq_dim`` is moved to the front, permuted
    with ``chunk_ids_for_a2a``, and split into ``[cp_size, 2, ...]`` for the
    all-to-all.
    """
    if not before_attn:
        # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
        # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
        chunks_first = torch.movedim(x, seq_dim, 0).contiguous()
        permuted = chunks_first.index_select(0, chunk_ids_for_a2a)
        # [cp*2, ...] -> [cp, 2, ...]
        return permuted.view(cp_size, 2, *permuted.shape[1:])

    # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
    # or [cp, s, b, np//cp, hn] stays [cp, s, b, np//cp, hn]
    seq_major = torch.movedim(x, 0, seq_dim).contiguous()
    leading = seq_major.shape[:seq_dim]
    trailing = seq_major.shape[seq_dim + 2 :]
    # fold (cp, s) into (cp*2, s//2) chunks
    chunked = seq_major.view(*leading, cp_size * 2, -1, *trailing)
    return chunked.index_select(seq_dim, chunk_ids_for_a2a)


def flash_attn_a2a_communicate(
    a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
    chunk_ids_for_a2a: torch.Tensor,
    seq_dim: int,
    cp_size: int,
    cp_group: dist_group_type,
    cp_stream: torch.cuda.Stream,
    before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """A2A communication for context parallelism.

    Runs one asynchronous ``all_to_all_single`` per input over ``cp_group``.

    - ``before_attn=True``: per the shape comments below, ``[b, s, np, hn]``
      (or ``[s, b, np, hn]``) becomes ``[b, cp*s, np//cp, hn]`` (or
      ``[cp*s, b, np//cp, hn]``).
    - ``before_attn=False``: ``[b, cp*s, np//cp, hn]`` (or ``[cp*s, b, np//cp, hn]``)
      becomes ``[b*s, np, hn]`` (or ``[s*b, np, hn]``).

    In both directions, sequence chunks are permuted with
    ``reorder_seq_chunks_for_a2a`` using ``chunk_ids_for_a2a``.

    The loop is software-pipelined over the inputs. On iteration ``i`` it launches
    the all-to-all for input ``i-1``, waits on and post-processes output ``i-2``,
    and pre-processes input ``i``. The before/after branches differ in the order of
    the last two steps. Post-processing runs on ``cp_stream``; before returning,
    the current stream waits on ``cp_stream``.

    Args:
        a2a_inputs: a tensor or a list of tensors. NOTE(review): when a list is
            passed, its entries are overwritten in place with the reshaped inputs.
        chunk_ids_for_a2a: chunk permutation passed to ``reorder_seq_chunks_for_a2a``.
        seq_dim: index of the sequence dimension (e.g. 1 for ``[b, s, ...]``,
            0 for ``[s, b, ...]``).
        cp_size: number of ranks in ``cp_group``.
        cp_group: process group for the all-to-all.
        cp_stream: side stream used for waiting and post-processing.
        before_attn: selects the direction of the exchange (see above).

    Returns:
        A single tensor when there is exactly one input, otherwise a list of outputs.
    """
    a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs
    a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
    if before_attn:
        for i in range(len(a2a_inputs) + 2):
            # stage 1: launch the async all-to-all for the input prepared last iteration
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            # stage 2: finish the all-to-all from two iterations ago on cp_stream
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # reorder the sequence chunks
                    x = reorder_seq_chunks_for_a2a(
                        x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
                    )
                    # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                    # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
                    a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
            # stage 3: prepare input i so that dim 0 holds one head group per rank
            if i < len(a2a_inputs):
                x = a2a_inputs[i]
                # [b, s, np, hn] -> [b, s, cp, np//cp, hn]
                # or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
                x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
                # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn]
                # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
                a2a_inputs[i] = x.movedim(-3, 0).contiguous()
    else:
        for i in range(len(a2a_inputs) + 2):
            # stage 1: launch the async all-to-all for the input prepared last iteration
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            # stage 2: prepare input i by splitting and permuting its sequence chunks
            if i < len(a2a_inputs):
                x = a2a_inputs[i]
                # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
                # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
                x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
                # reorder the sequence chunks
                a2a_inputs[i] = reorder_seq_chunks_for_a2a(
                    x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
                )
            # stage 3: finish the all-to-all from two iterations ago on cp_stream
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn]
                    # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
                    x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
                    # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn]
                    # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
                    a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
    # make subsequent work on the current stream see the cp_stream post-processing
    torch.cuda.current_stream().wait_stream(cp_stream)
    return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs


1644
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
1645
    """
1646
1647
1648
    Attention implementation with context parallelism. Exchange KV between CP ranks
    with P2P in ring topology. Split attention compute into multiple steps, and overlap
    current-step compute with next-step communication.
1649
1650
1651
1652
1653

    This implementation also supports hierarchical CP, which parallelizes attention
    heads in low-level CP groups and parallelizes sequence dimension in high-level CP
    groups. For more details, please refer to `LongVILA <https://arxiv.org/abs/2408.10188>`_
    and `USP <https://arxiv.org/abs/2405.07719>`_.
1654
1655
1656
    """

    @staticmethod
1657
1658
1659
1660
1661
1662
1663
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
1664
        cu_seqlens_kv,
1665
        max_seqlen_q,
1666
        max_seqlen_kv,
1667
1668
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
1669
1670
1671
1672
1673
1674
1675
1676
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
1677
1678
        fp8,
        fp8_meta,
1679
1680
1681
        cp_group,
        cp_global_ranks,
        cp_stream,
1682
    ):
1683
        # pylint: disable=missing-function-docstring
1684
1685
1686
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
        if isinstance(cp_group, list):
            assert (
                qkv_format != "thd"
            ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!"
            assert attn_bias_type == "no_bias", (
                f"{attn_bias_type} bias type is not supported with hierarchical CP implementation"
                " yet!"
            )
            cp_group_a2a = cp_group[0]
            cp_size_a2a = get_distributed_world_size(cp_group_a2a)
            rank_a2a = get_distributed_rank(cp_group_a2a)
            cp_group = cp_group[1]
        else:
            cp_group_a2a = None
            cp_size_a2a = 1
            rank_a2a = 0

1704
1705
        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
1706
1707
        send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
1708
1709
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1710
1711
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
1712

1713
        seq_dim = None
1714
        if qkv_format in ["bshd", "sbhd"]:
1715
            seq_dim = qkv_format.index("s")
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        else:
            qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
        pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
        max_seqlen_q = max_seqlen_q // cp_size
        max_seqlen_kv = max_seqlen_kv // cp_size
        cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
        cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
        cu_seqlens_q_per_step = [None for _ in range(cp_size)]
        cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
1728

1729
1730
1731
        fused_attn_qkv_dtype = None
        fused_attn_backend = None
        amax_per_step = None
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
        if fp8:
            if use_fused_attention:
                fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                fused_attn_qkv_dtype = fp8_dtype_forward
                fused_attn_backend = FusedAttnBackend["FP8"]
                if fp8_meta["recipe"].fp8_mha:
                    assert (
                        isinstance(q, Float8Tensor)
                        and isinstance(k, Float8Tensor)
                        and isinstance(v, Float8Tensor)
                    ), "q/k/v must be Float8Tensors for FP8 MHA!"
                    fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                else:
                    q_f16, k_f16, v_f16 = q, k, v
                    if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                        q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                    if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                        k, v = [
                            cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                            for x in [k_f16, v_f16]
                        ]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
                fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            q_f16 = q
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True)
            q, k, v = flash_attn_a2a_communicate(
                [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
            )
            if not fp8:
                q_f16 = q
            elif not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                q_f16 = q
                q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)

1785
1786
1787
        assert qkv_format == "thd" or (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"
1788
        if causal:
1789
1790
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
1791
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
1792
1793
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
1794
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
1795
1796
1797
        total_tokens_kv = None if qkv_format != "thd" else k.shape[0]
        # remove padded tokens at the end
        k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]]
1798
        if attn_bias is not None:
1799
            assert len(attn_bias.shape) == 4, (
1800
1801
1802
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
1803
1804
1805
            assert (
                attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0
            ), "Sequence length does not meet divisible requirements!"
1806
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
1807
1808
1809
1810
1811
1812
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
1813
1814
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
1815
1816
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
1817
            )
1818
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
1819
1820
1821
1822

        softmax_lse_in_packed_format = not use_fused_attention and (
            _flash_attn_2_6_0_plus or _use_flash_attn_3
        )
1823
        flash_attn_fwd = None
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
                flash_attn_fwd = flash_attn_varlen_fwd_v3
                fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
            else:
                flash_attn_fwd = flash_attn_varlen_fwd
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
                if _flash_attn_2_3_plus:
                    fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_5_7_plus:
                    fa_forward_kwargs["block_table"] = None
1839

1840
1841
1842
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
1843
        attn_bias_inputs = [None, None]
1844
1845
1846
1847
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
1848
        attn_biases = [None for _ in range(cp_size)]
1849
1850
1851
1852
1853
1854
1855

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
1856
1857
1858
1859
        if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
        else:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
1860
1861
        send_recv_reqs = [[], []]

1862
1863
        softmax_lse_ = None
        out = None
1864
        for i in range(cp_size + 1):
1865
            if i < cp_size:
1866
                with torch.cuda.stream(flash_attn_streams[i % 2]):
1867
                    # wait until KV is received
1868
                    for req in send_recv_reqs[(i + 1) % 2]:
1869
1870
                        req.wait()

1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
                    if (
                        not fp8
                        or fp8_meta["recipe"].fp8_mha
                        or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
                    ):
                        kv_inputs[i % 2] = p2p_comm_buffers[i]
                    else:
                        # KV exchange is in BF16/FP16, cast received KV in each step
                        kv_inputs[i % 2] = cast_to_fp8(
                            p2p_comm_buffers[i],
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                        )
                    if fp8 and use_fused_attention:
1898
1899
1900
1901
                        fp8_meta_kwargs["amax_s"] = amax_per_step
                        fp8_meta_kwargs["amax_s_offset"] = i
                        fp8_meta_kwargs["amax_o"] = amax_per_step
                        fp8_meta_kwargs["amax_o_offset"] = cp_size + i
1902
1903
                    if causal:
                        if i == 0:
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1916
                            if use_fused_attention:
1917
1918
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
1919
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
1920
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
1921
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
1922
                                        k.shape[0], -1, 2, *k.shape[-2:]
1923
                                    )
1924
1925
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
1926
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
1927
1928
1929
1930
                                    # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        -1, k.shape[2], 2, *k.shape[-2:]
                                    )
1931
                                elif qkv_format == "thd":
1932
                                    q_inputs[i % 2] = q
1933
1934
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1935
1936
1937
1938
1939
1940
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
1941
                                    ).contiguous()
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
1970
                                )
1971
1972
1973
1974
1975
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
1976
1977
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
1978
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1979
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
1980
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
1981
                                fa_outputs = flash_attn_fwd(
1982
1983
1984
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
1985
1986
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1987
                                    max_seqlen_q,
1988
                                    max_seqlen_kv,
1989
                                    causal=True,
1990
                                    **fa_forward_kwargs,
1991
                                )
1992
1993
1994
1995
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
                                if not _use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
1996
                        elif i <= rank:
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    False,
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
2014
                            if use_fused_attention:
2015
2016
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
2017
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
2018
2019
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous()
2020
2021
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
2022
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
2023
2024
                                    # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous()
2025
                                elif qkv_format == "thd":
2026
                                    q_inputs[i % 2] = q
2027
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
2028
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
2029
                                        kv_inputs[i % 2], cu_seqlens_kv_padded, 0
2030
                                    )
2031
2032
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2033
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv // 2,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=(
                                        None
                                        if cu_seqlens_kv_padded is None
                                        else cu_seqlens_kv_padded // 2
                                    ),
                                    **fp8_meta_kwargs,
2066
                                )
2067
2068
2069
2070
2071
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2072
2073
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
2074
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
2075
2076
                                if qkv_format == "thd":
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
2077
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
2078
                                        kv_inputs[i % 2], cu_seqlens_kv_padded, 0
2079
                                    )
2080
2081
                                else:
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
2082
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
2083
                                # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
2084
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
2085
2086
2087
                                if _use_flash_attn_3 or _flash_attn_2_3_plus:
                                    fa_forward_kwargs["window_size"] = (-1, -1)
                                fa_outputs = flash_attn_fwd(
2088
2089
2090
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
2091
2092
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
2093
                                    max_seqlen_q,
2094
                                    max_seqlen_kv // 2,
2095
                                    causal=False,
2096
                                    **fa_forward_kwargs,
2097
                                )
2098
2099
2100
2101
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
                                if not _use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
2102
                        else:
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    True,
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2120
                            if use_fused_attention:
2121
2122
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
2123
                                    q_inputs[i % 2] = q[:, 1, ...].contiguous()
2124
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
2125
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
2126
                                        k.shape[0], -1, 2, *k.shape[-2:]
2127
                                    )
2128
2129
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
2130
                                    q_inputs[i % 2] = q[1].contiguous()
2131
2132
2133
2134
                                    # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        -1, k.shape[2], 2, *k.shape[-2:]
                                    )
2135
2136
                                elif qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
2137
2138
2139
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(
                                        q, cu_seqlens_q_padded, 1
                                    )
2140
2141
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2142
2143
2144
2145
2146
2147
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
2148
                                    ).contiguous()
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q // 2,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=(
                                        None
                                        if cu_seqlens_q_padded is None
                                        else cu_seqlens_q_padded // 2
                                    ),
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
2181
                                )
2182
2183
2184
2185
2186
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2187
                            else:
2188
2189
                                if qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
2190
2191
2192
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(
                                        q, cu_seqlens_q_padded, 1
                                    )
2193
2194
                                else:
                                    # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
2195
                                    q_inputs[i % 2] = (
2196
                                        q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
2197
                                    )
2198
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
2199
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
2200
2201
2202
                                if _use_flash_attn_3 or _flash_attn_2_3_plus:
                                    fa_forward_kwargs["window_size"] = (-1, -1)
                                fa_outputs = flash_attn_fwd(
2203
2204
2205
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
2206
2207
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
2208
                                    max_seqlen_q // 2,
2209
                                    max_seqlen_kv,
2210
                                    causal=False,
2211
                                    **fa_forward_kwargs,
2212
                                )
2213
2214
2215
2216
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
                                if not _use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
2217
                    else:
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
                        if pad_between_seqs_q:
                            cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                            )
                        else:
                            cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                        if pad_between_seqs_kv:
                            cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_kv,
                                cu_seqlens_kv_padded,
                                cp_size,
                                (rank - i) % cp_size,
                                True,
                                True,
                            )
                        else:
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2235
                        if use_fused_attention:
2236
2237
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
2238
2239
2240
2241
2242
2243
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
2244
                                ).contiguous()
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
                            out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                is_training,
                                max_seqlen_q,
                                max_seqlen_kv,
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
                                q,
                                (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                ),
                                (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                ),
                                fused_attn_qkv_dtype,
                                fused_attn_backend,
                                attn_scale=softmax_scale,
                                dropout=dropout_p,
                                qkv_layout=qkv_layout,
                                attn_mask_type=attn_mask_type,
                                attn_bias_type=attn_bias_type,
                                attn_bias=attn_bias_inputs[i % 2],
                                cu_seqlens_q_padded=cu_seqlens_q_padded,
                                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                **fp8_meta_kwargs,
2273
                            )
2274
2275
2276
2277
2278
                            if fp8:
                                softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                            else:
                                softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                attn_biases[i] = rest[0] if len(rest) > 0 else None
2279
                        else:
2280
                            # [b, sq, np, hn] -> [b*sq, np, hn]
2281
                            q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
2282
                            # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
2283
                            kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
2284
                            fa_outputs = flash_attn_fwd(
2285
2286
2287
                                q_inputs[i % 2],
                                kv_inputs[i % 2][0],
                                kv_inputs[i % 2][1],
2288
2289
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
2290
                                max_seqlen_q,
2291
                                max_seqlen_kv,
2292
                                causal=False,
2293
                                **fa_forward_kwargs,
2294
                            )
2295
2296
2297
2298
                            out_per_step[i] = fa_outputs[4]
                            softmax_lse_per_step[i] = fa_outputs[5]
                            if not _use_flash_attn_3:
                                rng_states[i] = fa_outputs[7]
2299
2300
2301
2302

            if i > 0:
                # wait until the fwd results correction of the last step is done
                if i > 1:
2303
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
2304

2305
2306
                if use_fused_attention:
                    # [b, np, sq, 1] -> [b, np, sq]
2307
                    softmax_lse_per_step[i - 1].squeeze_(-1)
2308
2309
2310
2311
2312
                if qkv_format != "thd" and softmax_lse_in_packed_format:
                    # [np, t] -> [np, b, sq]
                    softmax_lse_per_step[i - 1] = softmax_lse_per_step[i - 1].view(
                        q.shape[-2], q.shape[0], -1
                    )
2313

2314
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
2315
2316
2317
2318
2319
2320
2321
2322
                    if fp8:
                        out_per_step[i - 1] = cast_from_fp8(
                            out_per_step[i - 1],
                            fp8_meta["scaling_fwd"],
                            META_O_CP,
                            fp8_dtype_forward,
                            TE_DType[torch.float32],
                        )
2323
                    if i == 1:
2324
                        out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
2325
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
2326
                        if causal and qkv_format != "thd":
2327
2328
                            # [b, np, sq] -> [b, np, 2, sq//2] lse not in packed format
                            # [np, b, sq] -> [np, b, 2, sq//2] lse in packed format
2329
                            softmax_lse_ = softmax_lse.view(
2330
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
2331
                            )
2332
2333
2334
2335
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
2336
                    else:
2337
                        if qkv_format == "thd":
2338
                            tex.thd_second_half_lse_correction(
2339
2340
2341
                                softmax_lse,
                                softmax_lse_per_step[i - 1],
                                cu_seqlens_q_padded,
2342
                                softmax_lse_in_packed_format,
2343
                            )
2344
                        else:
2345
2346
2347
                            flash_attn_fwd_softmax_lse_correction(
                                softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1]
                            )
2348
2349

                if i < cp_size:
2350
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
2351
2352
2353
2354
2355

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        softmax_lse = softmax_lse.to(torch.float)
        for i in range(cp_size):
2356
            out_ = None
2357
            if qkv_format == "bshd":
2358
2359
2360
                out_per_step[i] = out_per_step[i].view(
                    out.shape[0], -1, *out.shape[-2:]
                )  # pylint: disable=used-before-assignment
2361
2362
2363
2364
                out_ = out[:, 1, ...]
            elif qkv_format == "sbhd":
                out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
                out_ = out[1]
2365

2366
            if i <= rank or not causal:
2367
                if qkv_format in ["bshd", "sbhd"]:
2368
2369
2370
2371
2372
                    flash_attn_fwd_out_correction(
                        out.view(*out_per_step[i].shape),
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2373
2374
                        0 if softmax_lse_in_packed_format else 2,
                        2 if softmax_lse_in_packed_format else seq_dim,
2375
                    )
2376
                elif qkv_format == "thd":
2377
2378
2379
2380
2381
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2382
                        cu_seqlens_q_padded,
2383
                        False,
2384
                        softmax_lse_in_packed_format,
2385
                    )
2386
            else:
2387
                if qkv_format in ["bshd", "sbhd"]:
2388
2389
2390
2391
2392
                    flash_attn_fwd_out_correction(
                        out_,
                        out_per_step[i],
                        softmax_lse_[..., 1, :],
                        softmax_lse_per_step[i],
2393
2394
                        0 if softmax_lse_in_packed_format else 2,
                        2 if softmax_lse_in_packed_format else seq_dim,
2395
                    )
2396
                elif qkv_format == "thd":
2397
2398
2399
2400
2401
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2402
                        cu_seqlens_q_padded,
2403
                        True,
2404
                        softmax_lse_in_packed_format,
2405
                    )
2406

2407
2408
2409
        if qkv_format != "thd" and softmax_lse_in_packed_format:
            # [np, b, sq] -> [np, t]
            softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1)
2410
        kv = p2p_comm_buffers[-1]
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
        if qkv_format == "bshd":
            out = out.view(out.shape[0], -1, *out.shape[-2:])
            ctx.batch_size = out.shape[0]
        elif qkv_format == "sbhd":
            out = out.view(-1, *out.shape[-3:])
            ctx.batch_size = out.shape[1]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
            out = flash_attn_a2a_communicate(
                out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
            )
            if use_fused_attention:
                if qkv_format == "bshd":
                    # [b*s, np, hn] -> [b, s, np, hn]
                    out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                elif qkv_format == "sbhd":
                    # [s*b, np, hn] -> [s, b, np, hn]
                    out = out.view(-1, ctx.batch_size, *out.shape[-2:])
        elif not use_fused_attention:
2431
            out = out.view(-1, *out.shape[-2:])
2432

2433
2434
2435
2436
2437
        if fp8 and use_fused_attention:
            amax_cp_fwd = amax_per_step.amax(dim=1)
            fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0]
            fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]

2438
        out_fp8 = None
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
        out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype)
        if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
            out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)

        if fp8 and fp8_meta["recipe"].fp8_mha:
            out_ret = Float8Tensor(
                data=out_fp8,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_O,
                fp8_dtype=fp8_dtype_forward,
                dtype=q_fp8.dtype,
            )
        else:
            out_ret = out_f16

        if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
            q_save, kv_save, out_save = q, kv, out_fp8
            fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
            fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
        elif fp8 and fp8_meta["recipe"].fp8_mha:
2460
2461
2462
2463
2464
2465
2466
2467
            q_fp8 = Float8Tensor(
                data=q,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_QKV,
                fp8_dtype=fp8_dtype_forward,
                dtype=q_fp8.dtype,
            )
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
            kv_fp8 = Float8Tensor(
                data=kv,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_QKV,
                fp8_dtype=fp8_dtype_forward,
                dtype=k_fp8.dtype,
            )
            q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None
        else:
2479
            q_f16 = q_f16.view(q.shape)
2480
2481
2482
            q_save, kv_save, out_save = q_f16, kv, out_f16
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None

2483
        ctx.save_for_backward(
2484
2485
2486
            q_save,
            kv_save,
            out_save,
2487
            softmax_lse,
2488
2489
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
2490
2491
            fp8_fwd_scales,
            fp8_fwd_scale_invs,
2492
2493
            *cu_seqlens_q_per_step,
            *cu_seqlens_kv_per_step,
2494
2495
            *rng_states,
            *attn_biases,
2496
        )
2497
2498
2499
        ctx.cp_group_a2a = cp_group_a2a
        ctx.cp_size_a2a = cp_size_a2a
        ctx.rank_a2a = rank_a2a
2500
2501
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
2502
        ctx.cp_stream = cp_stream
2503
        ctx.dropout_p = dropout_p
2504
        ctx.total_tokens_kv = total_tokens_kv
2505
        ctx.max_seqlen_q = max_seqlen_q
2506
        ctx.max_seqlen_kv = max_seqlen_kv
2507
        ctx.softmax_scale = softmax_scale
2508
        ctx.qkv_format = qkv_format
2509
        ctx.attn_mask_type = attn_mask_type
2510
2511
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
2512
        ctx.deterministic = deterministic
2513
        ctx.use_fused_attention = use_fused_attention
2514
2515
2516
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
        return out_ret
2517
2518
2519

    @staticmethod
    def backward(ctx, dout):
2520
        # pylint: disable=missing-function-docstring
2521
2522
2523
        cp_size_a2a = ctx.cp_size_a2a
        rank_a2a = ctx.rank_a2a

2524
2525
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
2526
2527
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
2528
2529
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

2530
        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
2531
2532
2533
2534
2535
        (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
        cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
        cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
        rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
        attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
2536

2537
2538
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
2539
2540

        seq_dim = None
2541
        if ctx.qkv_format in ["bshd", "sbhd"]:
2542
            seq_dim = ctx.qkv_format.index("s")
2543
2544
2545
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
        else:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
2546

2547
        if attn_biases[0] is not None:
2548
2549
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
2550
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
2551
2552
2553
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
2554
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
2555
2556
2557
            )
        else:
            attn_dbias = None
2558
            attn_dbias_ = None
2559

2560
2561
2562
2563
        softmax_lse_in_packed_format = not ctx.use_fused_attention and (
            _flash_attn_2_6_0_plus or _use_flash_attn_3
        )

2564
        if causal:
2565
            if ctx.qkv_format == "thd" or softmax_lse_in_packed_format:
2566
                softmax_lse_ = tex.thd_read_second_half_lse(
2567
                    softmax_lse, cu_seqlens_q_padded, softmax_lse_in_packed_format
2568
                )
2569
2570
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
2571
2572
2573
                softmax_lse_ = softmax_lse.view(
                    *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
                )
2574
2575
2576
2577
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
                if ctx.use_fused_attention:
                    # [b, np, sq//2] -> [b, np, sq//2, 1]
                    softmax_lse_.unsqueeze_(-1)
2578
2579
2580
        if ctx.use_fused_attention:
            # [b, np, sq] -> [b, np, sq, 1]
            softmax_lse.unsqueeze_(-1)
2581

2582
        dout_dtype = dout.dtype
2583
2584
2585
2586
2587
        fused_attn_backend = None
        fused_attn_qkv_dtype = None
        fused_attn_dqkv_dtype = None
        amax_per_step = None
        dout_fp8_dtype = None
2588
2589
        if ctx.fp8:
            if ctx.use_fused_attention:
2590
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
2591
                fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
2592
                fused_attn_qkv_dtype = fp8_dtype_forward
2593
2594
2595
2596
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
                fused_attn_dqkv_dtype = fp8_dtype_backward
                fused_attn_backend = FusedAttnBackend["FP8"]
                dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
                dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
                dkv_fp8_ = torch.empty_like(dkv_fp8)
                if ctx.fp8_meta["recipe"].fp8_mha:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
                    dout = dout._data
                else:
                    dout = cast_to_fp8(
                        dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                    )
                p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
                fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
                fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
                fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
                fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
                fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
                fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
                fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP]
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
2621
2622
2623
2624
2625
2626
2627
                q, kv = [x.from_float8(x.dtype) for x in [q, kv]]
                if cp_size_a2a == 1:
                    dout = dout.from_float8(dout_dtype)
                else:
                    dout_fp8_dtype = dout._fp8_dtype
                    dout_scale_inv = dout._scale_inv
                    dout = dout._data
2628
2629
2630
2631
2632
2633
2634
2635
2636
2637
2638
            dq = torch.empty_like(q)
            if ctx.qkv_format == "thd" and causal:
                dq[cu_seqlens_q_padded[-1] :].fill_(0)
            p2p_comm_buffers = [
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            ]
            p2p_comm_buffers[0][0].copy_(kv)
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
2639
                fused_attn_dqkv_dtype = TE_DType[dout_dtype]
2640
2641
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
2654
2655
2656
2657
        if cp_size_a2a > 1:
            if not ctx.use_fused_attention:
                out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                dout = dout.view(*out.shape)
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True)
            out, dout = flash_attn_a2a_communicate(
                [out, dout],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                True,
            )
            if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
                dout = cast_from_fp8(
2658
2659
2660
2661
2662
2663
                    dout,
                    None,
                    None,
                    dout_fp8_dtype,
                    TE_DType[dout_dtype],
                    scale_inv=dout_scale_inv,  # pylint: disable=used-before-assignment
2664
2665
                )

2666
2667
2668
2669
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        send_recv_reqs = []

2670
        flash_attn_bwd = None
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
                flash_attn_bwd = flash_attn_varlen_bwd_v3
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                flash_attn_bwd = flash_attn_varlen_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
2683

2684
2685
2686
2687
2688
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

2689
2690
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
            if ctx.fp8:
                if i < cp_size - 1:
                    send_recv_reqs = flash_attn_p2p_communicate(
                        rank,
                        send_tensor[0],
                        send_dst,
                        recv_tensor[0],
                        recv_src,
                        ctx.cp_group,
                        batch_p2p_comm,
                    )
                else:
                    dkv_a2a_req = torch.distributed.all_to_all_single(
                        dkv_fp8,
                        dkv_fp8_,
                        group=ctx.cp_group,
                        async_op=True,
                    )
                    send_recv_reqs = [dkv_a2a_req]
            else:
                if i == 0:
                    send_tensor = send_tensor[0]
                    recv_tensor = recv_tensor[0]
                if i == (cp_size - 1):
                    send_tensor = send_tensor[1]
                    recv_tensor = recv_tensor[1]
                send_recv_reqs = flash_attn_p2p_communicate(
                    rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
                )
2720

2721
            kv = p2p_comm_buffers[i % 2][0]
2722
            dk_, dv_ = None, None
2723
2724
2725
            if ctx.fp8 and ctx.use_fused_attention:
                fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i]
                fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i]
2726
            # In reversed order of fwd
2727
            if causal:
2728
                if i == (cp_size - 1):
2729
                    if ctx.use_fused_attention:
2730
2731
2732
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
2733
2734
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
2735
2736
2737
2738
2739
2740
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
2741
2742
                            # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                            kv_ = kv.view(-1, *kv.shape[-4:])
2743
2744
2745
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
2746
2747
                        elif ctx.qkv_format == "thd":
                            q_, kv_, out_, dout_ = q, kv, out, dout
2748
2749
2750
2751
2752
2753
2754
2755
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2756
                        if attn_dbias is not None:
2757
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2758
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2759
                            ctx.max_seqlen_q,
2760
2761
2762
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2763
                            q_,
2764
2765
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2766
2767
                            out_,
                            dout_,
2768
2769
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2770
                            aux_ctx_tensors,
2771
                            fused_attn_backend,
2772
2773
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2774
2775
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2776
                            qkv_layout=qkv_layout,
2777
                            attn_mask_type=ctx.attn_mask_type,
2778
                            attn_bias_type=ctx.attn_bias_type,
2779
2780
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2781
2782
2783
2784
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
2785
                        dq_ = torch.zeros_like(q_)
2786
2787
2788
2789
2790
2791
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
2792
2793
2794
2795
2796
                        if _use_flash_attn_3 or _flash_attn_2_3_plus:
                            fa_backward_kwargs["window_size"] = (-1, 0)
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
2797
2798
2799
2800
2801
2802
2803
2804
2805
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2806
2807
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2808
                            ctx.max_seqlen_q,
2809
                            ctx.max_seqlen_kv,
2810
2811
                            causal=True,
                            **fa_backward_kwargs,
2812
                        )
2813
                elif i >= (cp_size - rank - 1):
2814
                    if ctx.use_fused_attention:
2815
2816
2817
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
2818
2819
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                            kv_ = kv[:, 0, ...].contiguous()
2820
2821
2822
2823
2824
2825
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
2826
2827
                            # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                            kv_ = kv[0].contiguous()
2828
2829
2830
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
2831
2832
2833
                        elif ctx.qkv_format == "thd":
                            q_, out_, dout_ = q, out, dout
                            # [2, t, np, hn] -> [2, t/2, np, hn]
2834
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
2835
2836
2837
2838
2839
2840
2841
2842
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2843
                        if attn_dbias is not None:
2844
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2845
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2846
                            ctx.max_seqlen_q,
2847
2848
2849
                            ctx.max_seqlen_kv // 2,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2850
                            q_,
2851
2852
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2853
2854
                            out_,
                            dout_,
2855
2856
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2857
                            aux_ctx_tensors,
2858
                            fused_attn_backend,
2859
2860
2861
2862
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=(
                                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
                            ),
2863
2864
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2865
                            qkv_layout=qkv_layout,
2866
                            attn_mask_type="padding" if padding else "no_mask",
2867
                            attn_bias_type=ctx.attn_bias_type,
2868
2869
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2870
2871
2872
2873
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
2874
                        dq_ = torch.zeros_like(q_)
2875
2876
                        if ctx.qkv_format == "thd":
                            # [2, t, np, hn] -> [2, t/2, np, hn]
2877
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
2878
2879
2880
                        else:
                            # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
2881
2882
2883
2884
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
2885
2886
2887
2888
2889
                        if _use_flash_attn_3 or _flash_attn_2_3_plus:
                            fa_backward_kwargs["window_size"] = (-1, -1)
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
2890
2891
2892
2893
2894
2895
2896
2897
2898
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2899
2900
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2901
                            ctx.max_seqlen_q,
2902
                            ctx.max_seqlen_kv // 2,
2903
2904
                            causal=False,
                            **fa_backward_kwargs,
2905
2906
2907
                        )
                else:
                    if ctx.use_fused_attention:
2908
2909
2910
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous()
2911
2912
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
2913
2914
2915
2916
2917
2918
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous()
                            dout_ = dout[:, 1, ...].contiguous()
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            q_ = q[1].contiguous()
2919
2920
                            # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                            kv_ = kv.view(-1, *kv.shape[-4:])
2921
2922
2923
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            out_ = out[1].contiguous()
                            dout_ = dout[1].contiguous()
2924
2925
                        elif ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
2926
2927
2928
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
2929
                            kv_ = kv
2930
2931
2932
2933
2934
2935
2936
2937
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse_,
                                softmax_lse_,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
2938
                        if attn_dbias is not None:
2939
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2940
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2941
                            ctx.max_seqlen_q // 2,
2942
2943
2944
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2945
                            q_,
2946
2947
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2948
2949
                            out_,
                            dout_,
2950
2951
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2952
                            aux_ctx_tensors,
2953
                            fused_attn_backend,
2954
2955
2956
2957
                            cu_seqlens_q_padded=(
                                None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
                            ),
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2958
2959
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2960
                            qkv_layout=qkv_layout,
2961
                            attn_mask_type="padding" if padding else "no_mask",
2962
                            attn_bias_type=ctx.attn_bias_type,
2963
2964
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2965
2966
                        )
                    else:
2967
2968
                        if ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
2969
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
2970
2971
2972
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
2973
                        dq_ = torch.zeros_like(q_)
2974
2975
2976
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
2977
                        if ctx.qkv_format == "thd":
2978
2979
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
2980
2981
2982
2983
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                            dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
2984
2985
2986
2987
2988
                        if _use_flash_attn_3 or _flash_attn_2_3_plus:
                            fa_backward_kwargs["window_size"] = (-1, -1)
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
2989
2990
2991
2992
2993
2994
2995
2996
2997
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse_,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2998
2999
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
3000
                            ctx.max_seqlen_q // 2,
3001
                            ctx.max_seqlen_kv,
3002
3003
                            causal=False,
                            **fa_backward_kwargs,
3004
3005
3006
                        )
            else:
                if ctx.use_fused_attention:
3007
3008
3009
3010
                    if ctx.fp8:
                        aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]]
                    else:
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
3011
                    if attn_dbias is not None:
3012
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
3013
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
3014
                        ctx.max_seqlen_q,
3015
3016
3017
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
3018
                        q,
3019
3020
                        kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
3021
3022
                        out,
                        dout,
3023
3024
                        fused_attn_qkv_dtype,
                        fused_attn_dqkv_dtype,
3025
                        aux_ctx_tensors,
3026
                        fused_attn_backend,
3027
3028
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
3029
3030
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
3031
                        qkv_layout=qkv_layout,
3032
                        attn_mask_type=ctx.attn_mask_type,
3033
                        attn_bias_type=ctx.attn_bias_type,
3034
3035
                        deterministic=ctx.deterministic,
                        **fp8_meta_kwargs,
3036
3037
3038
                    )
                else:
                    # [b, sq, np, hn] -> [b*sq, np, hn]
3039
                    q_ = q.view(-1, *q.shape[-2:])
3040
                    dq_ = torch.zeros_like(q_)
3041
                    # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
3042
3043
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
3044
                    # [b, sq, np, hn] -> [b*sq, np, hn]
3045
3046
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
3047
3048
3049
3050
3051
                    if _use_flash_attn_3 or _flash_attn_2_3_plus:
                        fa_backward_kwargs["window_size"] = (-1, -1)
                    if not _use_flash_attn_3:
                        fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                    flash_attn_bwd(
3052
3053
3054
3055
3056
3057
3058
3059
3060
                        dout_,
                        q_,
                        kv_[0],
                        kv_[1],
                        out_,
                        softmax_lse,
                        dq_,
                        dkv_[0],
                        dkv_[1],
3061
3062
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
3063
                        ctx.max_seqlen_q,
3064
                        ctx.max_seqlen_kv,
3065
3066
                        causal=False,
                        **fa_backward_kwargs,
3067
3068
                    )

3069
3070
            if ctx.fp8:
                dq = dq_fp8[(rank + i + 1) % cp_size]
3071
            if i >= (cp_size - rank - 1) or not causal:
3072
3073
3074
3075
                # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                dq_ = dq_.view(*dq.shape)
            else:
3076
3077
3078
3079
3080
3081
                if ctx.qkv_format == "bshd":
                    # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                    dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
                elif ctx.qkv_format == "sbhd":
                    # [b*sq//2, np, hn] -> [sq//2, b, np, hn]
                    dq_ = dq_.view(-1, *dq.shape[-3:])
3082

3083
3084
3085
3086
3087
3088
3089
3090
3091
3092
3093
            if ctx.fp8:
                if i >= (cp_size - rank - 1) or not causal:
                    dq.copy_(dq_)
                else:
                    if ctx.qkv_format == "bshd":
                        dq[:, 0, ...].fill_(0)
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[0].fill_(0)
                        dq[1].copy_(dq_)
            elif causal:
3094
                if i > (cp_size - rank - 1):
3095
                    dq.add_(dq_)
3096
3097
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
3098
3099
                        dq.copy_(dq_)
                    else:
3100
3101
3102
3103
3104
3105
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
3106
                        elif ctx.qkv_format == "thd":
3107
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
3108
                elif i > 0:
3109
3110
3111
3112
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
3113
                    elif ctx.qkv_format == "thd":
3114
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
3115
                else:
3116
3117
3118
3119
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
3120
                    elif ctx.qkv_format == "thd":
3121
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
3122
3123
3124
3125
3126
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
3127

3128
            if attn_dbias is not None:
3129
                idx = (rank + i + 1) % cp_size
3130
                if i == (cp_size - 1) or not causal:
3131
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
3132
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
3133
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
3134
3135
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
3136
3137
3138
3139
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
3140
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
3141
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
3142
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
3143

3144
3145
3146
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
3147

3148
3149
3150
3151
3152
3153
3154
            if ctx.fp8:
                if i < cp_size - 1:
                    dkv = dkv_fp8_[(rank + i + 1) % cp_size]
                else:
                    dkv = dkv_fp8[(rank + i + 1) % cp_size]
            else:
                dkv = p2p_comm_buffers[(i + 1) % 2][1]
3155
            if ctx.use_fused_attention:
3156
3157
3158
                dkv_ = torch.cat(
                    (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
                )  # pylint: disable=used-before-assignment
3159
3160
3161
3162
                if ctx.qkv_format in ["bshd", "sbhd"]:
                    # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
                    # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
                    dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
3163
            if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
3164
3165
3166
3167
3168
3169
                if ctx.qkv_format == "bshd":
                    # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                    dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
                elif ctx.qkv_format == "sbhd":
                    # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:])
3170
3171
3172
3173
            else:
                # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
                # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                dkv_ = dkv_.view(*dkv.shape)
3174

3175
3176
3177
3178
3179
3180
3181
3182
3183
3184
3185
            if ctx.fp8:
                if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
                    if ctx.qkv_format == "bshd":
                        dkv[:, :, 0, ...].copy_(dkv_)
                        dkv[:, :, 1, ...].fill_(0)
                    elif ctx.qkv_format == "sbhd":
                        dkv[:, 0, ...].copy_(dkv_)
                        dkv[:, 1, ...].fill_(0)
                else:
                    dkv.copy_(dkv_)
            elif causal:
3186
                if i == (cp_size - 1):
3187
                    if rank == 0:
3188
3189
3190
3191
3192
3193
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
3194
                        elif ctx.qkv_format == "thd":
3195
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
3196
3197
                    else:
                        dkv.add_(dkv_)
3198
3199
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
3200
3201
3202
3203
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
3204
                        elif ctx.qkv_format == "thd":
3205
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
3206
                    else:
3207
3208
3209
3210
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
3211
                        elif ctx.qkv_format == "thd":
3212
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
3213
3214
3215
3216
3217
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
3218
3219
3220
3221
3222
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

3223
3224
3225
3226
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
        if ctx.fp8 and ctx.use_fused_attention:
            amax_cp_bwd = amax_per_step.amax(dim=1)
            ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0]
            ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1]
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
                # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
                dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
            dq, dkv = [
                cast_from_fp8(
                    x,
                    ctx.fp8_meta["scaling_bwd"],
                    META_DQKV_CP,
                    fp8_dtype_backward,
                    TE_DType[torch.float32],
                )
                for x in [dq_fp8, dkv_fp8]
            ]
            dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]

3243
        if causal:
3244
3245
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
3246
                dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
3247
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
3248
                dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
3249
3250
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
3251
                dq = dq.view(-1, *dq.shape[-3:])
3252
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
3253
3254
3255
3256
3257
3258
3259
3260
3261
                dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])

        if ctx.qkv_format == "thd":
            dkv_ = torch.empty(
                2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device
            )
            dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv)
            dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
            dkv = dkv_
3262

3263
3264
3265
3266
3267
        if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
            dq, dkv = [
                cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward)
                for x in [dq, dkv]
            ]
3268
3269
3270
3271
3272
3273
3274
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284
3285
3286
        dk, dv = dkv[0], dkv[1]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False)
            dq, dk, dv = flash_attn_a2a_communicate(
                [dq, dk, dv],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                False,
            )
            if ctx.qkv_format == "bshd":
                dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
            elif ctx.qkv_format == "sbhd":
                dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

        if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
3287
3288
3289
3290
3291
3292
3293
3294
3295
            dq, dk, dv = [
                Float8Tensor(
                    data=x,
                    fp8_meta=ctx.fp8_meta,
                    fp8_meta_forward=False,
                    fp8_meta_index=META_DQKV,
                    fp8_dtype=fp8_dtype_backward,
                    dtype=dout_dtype,
                )
3296
                for x in [dq, dk, dv]
3297
3298
            ]

3299
3300
3301
3302
        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)

3303
3304
3305
        return (
            None,
            dq,
3306
3307
            dk,
            dv,
3308
3309
3310
3311
3312
3313
3314
3315
3316
3317
3318
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
3319
            attn_dbias,
3320
3321
3322
3323
3324
            None,
            None,
            None,
            None,
            None,
3325
3326
            None,
            None,
3327
        )
3328
3329


def get_kv_seq_info_after_all_gather(
    local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal
):
    """Compute KV sequence index range and update window size after all-gather.

    After the KV all-gather, the full KV sequence is made of ``2 * cp_size`` chunks of
    ``max_seqlen_kv`` tokens. For the query chunk ``local_chunk_id`` this returns the slice of
    the gathered KV sequence that the chunk may attend to under ``window_size``, together with
    the window sizes re-expressed relative to that slice.

    NOTE(review): the query chunk is assumed to end at token
    ``(local_chunk_id + 1) * max_seqlen_kv`` and to span ``max_seqlen_q`` tokens, i.e. query
    and KV chunks are assumed to have the same length -- confirm for cross-attention use.

    Args:
        local_chunk_id: index of the query chunk among the ``2 * cp_size`` sequence chunks.
        cp_size: context-parallel world size.
        max_seqlen_q: number of query tokens in one chunk.
        max_seqlen_kv: number of KV tokens in one chunk.
        window_size: ``(left, right)`` sliding-window sizes; ``-1`` on a side means unbounded
            on that side. ``None`` defaults to ``(-1, 0)`` if ``causal`` else ``(-1, -1)``.
        causal: whether the mask is causal; only consulted when ``window_size`` is None.

    Returns:
        ``((seq_start_idx, seq_end_idx), (window_size_left, window_size_right))`` where
        ``[seq_start_idx, seq_end_idx)`` is the KV range in the gathered sequence and the
        window sizes are the ones to hand to the attention kernel for that range. The
        re-basing shifts the attention diagonal by ``seq_end_idx - local_chunk_end_idx``,
        which matches a bottom-right aligned mask (presumably why causal FusedAttention
        callers switch to a ``*_bottom_right`` mask type).
    """
    # End (exclusive) of this query chunk in the gathered sequence.
    local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv
    # Total length of the gathered KV sequence: two chunks per CP rank.
    full_seq_end_idx = max_seqlen_kv * cp_size * 2

    if window_size is None:
        window_size = (-1, 0) if causal else (-1, -1)

    if window_size[1] == -1:
        # Unbounded on the right: keep KV up to the end of the full sequence.
        seq_end_idx = full_seq_end_idx
        window_size_right = -1
    else:
        # Last KV position reachable from this chunk, clipped to the sequence end.
        seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1])
        # 0 unless the slice end was clipped by the end of the sequence.
        window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx

    if window_size[0] == -1:
        # Unbounded on the left: keep KV from the start of the sequence.
        seq_start_idx = 0
        window_size_left = -1
    else:
        # First KV position reachable from the first query of this chunk, clipped at 0.
        seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0])
        # Moving the slice end by (seq_end_idx - local_chunk_end_idx) shifts the
        # bottom-right diagonal, so the left window grows by the same amount.
        window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx

    return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right)


class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    """
    Attention implementation with context parallelism. KV all-gather between CP ranks is exposed.
    Refer section 3.3.2 of `The Llama 3 Herd of Models <https://arxiv.org/abs/2407.21783>`_.

    Each CP rank owns two query chunks of the sequence, with chunk ids ``rank`` and
    ``2 * cp_size - rank - 1`` out of ``2 * cp_size`` chunks. K and V are all-gathered over the
    CP group and reordered with ``get_seq_chunk_ids_for_reordering`` (presumably into global
    sequence order, since the result is then sliced by sequence index). Each local query chunk
    then attends to the KV range computed by ``get_kv_seq_info_after_all_gather``. The two
    per-chunk attention calls are issued on two CUDA streams (the current stream and
    ``cp_stream``) so they can overlap.

    In backward, K/V are gathered again. Per-chunk dK/dV are accumulated into full-length
    buffers, reordered back, and reduce-scattered across the CP group.

    Only ``bshd``/``sbhd`` formats are supported, without padding masks or attention bias.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        cp_group,
        cp_stream,
    ):
        # pylint: disable=missing-function-docstring
        # Returns the attention output. With FusedAttention it keeps the input layout
        # ([b, s, np, hn] or [s, b, np, hn]); with FlashAttention it is flattened to
        # [tokens, np, hn] (see the reshape near the end of this method).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        # Each step's KV slice can be longer than the query chunk and (for causal) ends at the
        # chunk's last token, so the causal diagonal must be aligned bottom-right.
        if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
            attn_mask_type = attn_mask_type + "_bottom_right"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        # The FlashAttention path below always passes `window_size`, so this check applies
        # even when no sliding window is requested.
        assert (
            use_fused_attention or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
                flash_attn_fwd = flash_attn_varlen_fwd_v3
            else:
                flash_attn_fwd = flash_attn_varlen_fwd
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_5_7_plus:
                    fa_forward_kwargs["block_table"] = None

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        # Convert lengths/offsets from the full sequence to a single chunk
        # (the full sequence is split into 2 * cp_size chunks).
        max_seqlen_q = max_seqlen_q // (2 * cp_size)
        max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
        cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
        cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)

        # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
        q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
        # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn]
        k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]]

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, cp_group)
        v_ag, _ = gather_along_first_dim(v, cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        # cp_stream must see the gathered/reordered KV before it runs step 1.
        cp_stream.wait_stream(torch.cuda.current_stream())

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
        kv_seq_range_per_step = [None, None]
        window_size_per_step = [None, None]
        cu_seqlens_kv_per_step = [None, None]
        out_per_step = [None, None]
        softmax_lse_per_step = [None, None]
        rng_states = [None, None]
        out = torch.empty_like(q)

        # Pipelined loop: iteration i launches attention for local chunk i (if any) on
        # stream i, and copies the result of chunk i-1 into `out` on stream i-1.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    kv_seq_range_per_step[i], window_size_per_step[i] = (
                        get_kv_seq_info_after_all_gather(
                            local_seq_chunk_ids[i],
                            cp_size,
                            max_seqlen_q,
                            max_seqlen_kv,
                            window_size,
                            causal,
                        )
                    )
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i][0],
                        kv_seq_range_per_step[i][1],
                    )
                    max_seqlen_kv_ = seq_end_idx - seq_start_idx
                    cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens(
                        k.shape[1], max_seqlen_kv_, k.device
                    )
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    if use_fused_attention:
                        out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
                            is_training,
                            max_seqlen_q,
                            max_seqlen_kv_,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            TE_DType[q.dtype],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=softmax_scale,
                            dropout=dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=attn_mask_type,
                            attn_bias_type=attn_bias_type,
                            attn_bias=attn_bias,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            window_size=window_size_per_step[i],
                        )
                    else:
                        # FlashAttention varlen API takes [tokens, np, hn] inputs.
                        q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
                        fa_outputs = flash_attn_fwd(
                            q_,
                            k_,
                            v_,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            max_seqlen_q,
                            max_seqlen_kv_,
                            causal=causal,
                            window_size=window_size_per_step[i],
                            **fa_forward_kwargs,
                        )
                        out_per_step[i] = fa_outputs[4]
                        softmax_lse_per_step[i] = fa_outputs[5]
                        if not _use_flash_attn_3:
                            rng_states[i] = fa_outputs[7]

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if qkv_format == "bshd":
                        out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape))
                    elif qkv_format == "sbhd":
                        out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape))

        # Join cp_stream back into the current stream before `out` is used.
        torch.cuda.current_stream().wait_stream(cp_stream)

        if use_fused_attention:
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
        else:
            out = out.view(-1, *out.shape[-2:])

        # Layout of saved tensors (backward relies on these exact positions):
        # [0:5] q, k, v, cu_seqlens_q, cu_seqlens_q_padded; [5:7] cu_seqlens_kv per step;
        # [7:9] out per step; [9:11] softmax_lse per step; [11:13] rng_states per step.
        ctx.save_for_backward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_q_padded,
            *cu_seqlens_kv_per_step,
            *out_per_step,
            *softmax_lse_per_step,
            *rng_states,
        )
        ctx.kv_seq_range_per_step = kv_seq_range_per_step
        ctx.window_size_per_step = window_size_per_step
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
        return out

    @staticmethod
    def backward(ctx, dout):
        # pylint: disable=missing-function-docstring
        # Returns one gradient per forward input: dq/dk/dv for q/k/v, None elsewhere.
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)

        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5]
        cu_seqlens_kv_per_step = ctx.saved_tensors[5:7]
        out_per_step = ctx.saved_tensors[7:9]
        softmax_lse_per_step = ctx.saved_tensors[9:11]
        rng_states = ctx.saved_tensors[11:13]
        kv_seq_range_per_step = ctx.kv_seq_range_per_step
        window_size_per_step = ctx.window_size_per_step

        seq_dim = ctx.qkv_format.index("s")
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

        dout = dout.view(q.shape)
        dq = torch.empty_like(q)
        # Full-length (gathered) dK/dV accumulators, zeroed because the per-step KV ranges
        # only cover part of the sequence and are accumulated with add_.
        dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device)
        dv = torch.zeros_like(dk)
        dq_per_step = [None, None]
        dk_per_step = [None, None]
        dv_per_step = [None, None]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream]
        # synchronize dkv update across steps
        dkv_update_done = torch.cuda.Event()

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
        v_ag, _ = gather_along_first_dim(v, ctx.cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        # cp_stream must see the gathered KV and the zeroed dk/dv before step 1.
        ctx.cp_stream.wait_stream(torch.cuda.current_stream())

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]

        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
                flash_attn_bwd = flash_attn_varlen_bwd_v3
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                flash_attn_bwd = flash_attn_varlen_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic

        # Same pipelining as forward: iteration i computes gradients of local chunk i on
        # stream i and folds the gradients of chunk i-1 into dq/dk/dv on stream i-1.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i][0],
                        kv_seq_range_per_step[i][1],
                    )
                    max_seqlen_kv = seq_end_idx - seq_start_idx
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    out_ = out_per_step[i]
                    dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
                    if ctx.use_fused_attention:
                        aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
                            ctx.max_seqlen_q,
                            max_seqlen_kv,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[dout.dtype],
                            aux_ctx_tensors,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=ctx.attn_mask_type,
                            attn_bias_type=ctx.attn_bias_type,
                            window_size=window_size_per_step[i],
                            deterministic=ctx.deterministic,
                        )
                    else:
                        # NOTE(review): this is the batch size only for bshd; for sbhd
                        # k_.shape[0] is s_range (the later view still restores
                        # [s_range, b, ...] in that case). Presumably the FlashAttention
                        # caller only passes bshd here -- confirm.
                        batch_size = k_.shape[0]
                        q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
                        if not _use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[i]
                        flash_attn_bwd(
                            dout_,
                            q_,
                            k_,
                            v_,
                            out_,
                            softmax_lse_per_step[i],
                            dq_per_step[i],
                            dk_per_step[i],
                            dv_per_step[i],
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            ctx.max_seqlen_q,
                            max_seqlen_kv,
                            causal="causal" in ctx.attn_mask_type,
                            window_size=window_size_per_step[i],
                            **fa_backward_kwargs,
                        )
                        # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                        # (NOTE(review): dq[:, i] assumes the bshd layout.)
                        dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape)
                        # [b*s_range, np, hn] -> [b, s_range, np, hn]
                        dk_per_step[i], dv_per_step[i] = [
                            x.view(batch_size, -1, *x.shape[-2:])
                            for x in [dk_per_step[i], dv_per_step[i]]
                        ]

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if ctx.qkv_format == "bshd":
                        dq[:, i - 1].copy_(dq_per_step[i - 1])
                    elif ctx.qkv_format == "sbhd":
                        dq[i - 1].copy_(dq_per_step[i - 1])
                    # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn]
                    dk_per_step[i - 1], dv_per_step[i - 1] = [
                        x.movedim(seq_dim, 0).contiguous()
                        for x in [dk_per_step[i - 1], dv_per_step[i - 1]]
                    ]
                    # wait until dkv update of last step is done: the KV ranges of the two
                    # steps can overlap, and both accumulate into dk/dv with add_.
                    if i > 1:
                        flash_attn_streams[i - 1].wait_event(dkv_update_done)
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i - 1][0],
                        kv_seq_range_per_step[i - 1][1],
                    )
                    dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1])
                    dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1])
                    if i < len(local_seq_chunk_ids):
                        flash_attn_streams[i - 1].record_event(dkv_update_done)

        # Join cp_stream back into the current stream before dk/dv are reordered.
        torch.cuda.current_stream().wait_stream(ctx.cp_stream)

        # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
        dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
        # Undo the forward chunk reordering before reduce-scattering back to each rank.
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False)
        dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
        dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        dk = dk.view(-1, *dk.shape[-3:])
        dv = dv.view(-1, *dv.shape[-3:])
        dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
        dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)

        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] or [2, sq//2, b, np, hn] -> [sq, b, np, hn]
        dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
        # [s, b, np, hn] -> original k/v layout
        dk = dk.movedim(0, seq_dim).contiguous()
        dv = dv.movedim(0, seq_dim).contiguous()

        # One entry per forward input (19): is_training, q, k, v, then 15 non-differentiable
        # arguments (cu_seqlens_q ... cp_stream).
        return (
            None,
            dq,
            dk,
            dv,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
    """
    Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO.
    Refer the paper `DeepSpeed Ulysses <https://arxiv.org/abs/2309.14509>`_.

    Before the local attention call, q/k/v are exchanged with an all-to-all (A2A) so that
    each rank works on the gathered sequence (``cp*s``) for ``np//cp`` attention heads; the
    output is exchanged back with a second A2A. ``backward`` mirrors this for ``out``/``dout``
    and for the gradients dq/dk/dv.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        fp8,
        fp8_meta,
        cp_group,
        cp_stream,
    ):
        """
        Run A2A context-parallel attention forward.

        ``qkv_format`` must be "bshd" or "sbhd" ("thd" is rejected); padding masks and
        attention bias are not supported. The number of heads of q and k must be divisible
        by the CP size and the local sequence length must be even.

        Returns the attention output: a ``Float8Tensor`` when ``fp8`` is set and the recipe
        has ``fp8_mha`` enabled, otherwise a regular tensor (cast back to the input dtype
        when ``fp8`` is set).
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            window_size == (-1, 0)
            or window_size == (-1, -1)
            or use_fused_attention
            or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

        # NVTE_FP8_DPA_BWD (default "1") decides whether backward also runs in FP8 (it sets
        # ctx.fp8 below). Read it once so all decisions in this call see the same value.
        fp8_dpa_bwd = bool(fp8) and bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
        # FP8 MHA: q/k/v come in (and out goes out) as Float8Tensor wrappers.
        fp8_mha = bool(fp8) and bool(fp8_meta["recipe"].fp8_mha)

        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if _use_flash_attn_3:
                flash_attn_fwd = flash_attn_varlen_fwd_v3
                fa_forward_kwargs["window_size"] = window_size
            else:
                flash_attn_fwd = flash_attn_varlen_fwd
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
                # Extra keyword arguments are gated on the installed flash-attn version.
                if _flash_attn_2_3_plus:
                    fa_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_5_7_plus:
                    fa_forward_kwargs["block_table"] = None

        assert (
            q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
        ), "The number of attention heads needs to be divisible by CP size!"

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        batch_dim = qkv_format.index("b")
        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        fused_attn_backend = None
        fused_attn_qkv_dtype = None
        if fp8:
            if use_fused_attention:
                fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                fused_attn_qkv_dtype = fp8_dtype_forward
                fused_attn_backend = FusedAttnBackend["FP8"]
                if fp8_mha:
                    assert (
                        isinstance(q, Float8Tensor)
                        and isinstance(k, Float8Tensor)
                        and isinstance(v, Float8Tensor)
                    ), "q/k/v must be Float8Tensors for FP8 MHA!"
                    fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                elif fp8_dpa_bwd:
                    # FP8 backward: cast before the A2A so the exchange moves FP8 data and
                    # the FP8 tensors can be saved for backward.
                    q_f16, k_f16, v_f16 = q, k, v
                    q, k, v = [
                        cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                        for x in [q_f16, k_f16, v_f16]
                    ]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
                fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_o_offset"] = META_O
                fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history
                fp8_meta_kwargs["amax_s_offset"] = META_S
                fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history
                fp8_meta_kwargs["amax_o_offset"] = META_O
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True)
        q, k, v = flash_attn_a2a_communicate(
            [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
        )

        if fp8 and not fp8_mha and not fp8_dpa_bwd:
            # FP8 forward only: the A2A moved high-precision data; keep it (post-A2A) for
            # the high-precision backward and cast a copy to FP8 for the forward kernel.
            q_f16, k_f16, v_f16 = q, k, v
            q, k, v = [
                cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                for x in [q_f16, k_f16, v_f16]
            ]

        batch_size = q.shape[batch_dim]
        if use_fused_attention:
            out, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                fused_attn_qkv_dtype,
                fused_attn_backend,
                attn_scale=softmax_scale,
                dropout=dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                attn_bias=attn_bias,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                window_size=window_size,
                **fp8_meta_kwargs,
            )
        else:
            # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn]
            q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]]
            fa_outputs = flash_attn_fwd(
                q,
                k,
                v,
                cu_seqlens_q,
                cu_seqlens_kv,
                max_seqlen_q,
                max_seqlen_kv,
                causal=causal,
                **fa_forward_kwargs,
            )
            out, softmax_lse = fa_outputs[4], fa_outputs[5]
            rng_state = fa_outputs[7] if not _use_flash_attn_3 else None
            aux_ctx_tensors = [softmax_lse, rng_state]
            # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn]
            out = out.view(batch_size, -1, *out.shape[-2:])

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False)
        out = flash_attn_a2a_communicate(
            out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
        )

        if use_fused_attention:
            if qkv_format == "bshd":
                # [b*s, np, hn] -> [b, s, np, hn]
                out = out.view(batch_size, -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                # [s*b, np, hn] -> [s, b, np, hn]
                out = out.view(-1, batch_size, *out.shape[-2:])

        if fp8:
            if fp8_mha:
                out_fp8 = Float8Tensor(
                    data=out,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q_fp8.dtype,
                )
                out = out_fp8._data
                out_ret = out_fp8
            else:
                out_f16 = cast_from_fp8(
                    out,
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    TE_DType[q_f16.dtype],
                )
                out_ret = out_f16
        else:
            out_ret = out

        # Choose what backward consumes: raw FP8 data for an FP8 backward; otherwise
        # high-precision tensors (or Float8Tensor wrappers for FP8 MHA, which backward
        # dequantizes).
        if fp8:
            if fp8_dpa_bwd:
                q_save, k_save, v_save, out_save = q, k, v, out
            elif fp8_mha:
                q_fp8, k_fp8, v_fp8 = [
                    Float8Tensor(
                        data=x,
                        fp8_meta=fp8_meta,
                        fp8_meta_forward=True,
                        fp8_meta_index=META_QKV,
                        fp8_dtype=fp8_dtype_forward,
                        dtype=out_fp8.dtype,
                    )
                    for x in [q, k, v]
                ]
                q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8
            else:
                q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16
        else:
            q_save, k_save, v_save, out_save = q, k, v, out

        if fp8_dpa_bwd:
            # Snapshot the forward scales so backward uses the values from this step.
            fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
            fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
        else:
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None

        ctx.save_for_backward(
            q_save,
            k_save,
            v_save,
            out_save,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            fp8_fwd_scales,
            fp8_fwd_scale_invs,
            *aux_ctx_tensors,
        )
        ctx.batch_size = batch_size
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_mask_type = attn_mask_type
        ctx.attn_bias_type = attn_bias_type
        ctx.deterministic = deterministic
        ctx.window_size = window_size
        ctx.use_fused_attention = use_fused_attention
        # ctx.fp8: backward runs in FP8. ctx.fp8_mha: forward ran FP8 MHA, so dout arrives
        # as a Float8Tensor (and, without an FP8 backward, saved q/k/v/out are wrappers).
        ctx.fp8 = fp8_dpa_bwd
        ctx.fp8_mha = fp8_mha
        ctx.fp8_meta = fp8_meta
        return out_ret

    @staticmethod
    def backward(ctx, dout):
        """
        Backward of A2A context-parallel attention.

        ``out``/``dout`` go through the same forward-direction A2A that q/k/v went through
        in ``forward``, the local attention backward runs (FusedAttention or flash-attn),
        and dq/dk/dv are sent back with the reverse A2A. Only q, k and v get gradients;
        every other forward input receives ``None``.
        """
        cp_size = get_distributed_world_size(ctx.cp_group)

        saved = ctx.saved_tensors
        q, k, v, out = saved[:4]
        cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved[4:8]
        fp8_fwd_scales, fp8_fwd_scale_invs = saved[8:10]
        aux_ctx_tensors = saved[10:]

        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
        causal = "causal" in ctx.attn_mask_type
        seq_dim = ctx.qkv_format.index("s")

        fused_attn_backend = None
        fused_attn_dqkv_dtype = None
        fused_attn_qkv_dtype = None
        if ctx.fp8:
            if ctx.use_fused_attention:
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
                fused_attn_qkv_dtype = fp8_dtype_forward
                fused_attn_dqkv_dtype = fp8_dtype_backward
                fused_attn_backend = FusedAttnBackend["FP8"]
                if ctx.fp8_mha:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
                    dout_fp8 = dout
                    dout = dout_fp8._data
                else:
                    dout_f16 = dout
                    dout = cast_to_fp8(
                        dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                    )
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
                fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
                fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
                fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
                fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
                fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
                fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
                fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV]
                fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP]
                fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][
                    META_DQKV
                ]
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if ctx.fp8_mha:
                # FP8 MHA forward with a high-precision backward: forward saved Float8Tensor
                # wrappers, so dequantize them (and dout) before running backward.
                # Keyed on the forward's flag rather than the current recipe, so a non-FP8
                # forward never reaches `from_float8` on plain tensors.
                assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]]
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
                fused_attn_dqkv_dtype = TE_DType[dout.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if not ctx.use_fused_attention:
            out = out.view(ctx.batch_size, -1, *out.shape[-2:])
        dout = dout.view(*out.shape)

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True)
        out, dout = flash_attn_a2a_communicate(
            [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
        )

        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if _use_flash_attn_3:
                flash_attn_bwd = flash_attn_varlen_bwd_v3
                fa_backward_kwargs["window_size"] = ctx.window_size
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                flash_attn_bwd = flash_attn_varlen_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                # Extra keyword arguments are gated on the installed flash-attn version.
                if _flash_attn_2_3_plus:
                    fa_backward_kwargs["window_size"] = ctx.window_size
                if _flash_attn_2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if _flash_attn_2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic

        if ctx.use_fused_attention:
            dq, dk, dv, _ = fused_attn_bwd(
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                out,
                dout,
                fused_attn_qkv_dtype,
                fused_attn_dqkv_dtype,
                aux_ctx_tensors,
                fused_attn_backend,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                attn_scale=ctx.softmax_scale,
                dropout=ctx.dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=ctx.attn_mask_type,
                attn_bias_type=ctx.attn_bias_type,
                window_size=ctx.window_size,
                deterministic=ctx.deterministic,
                **fp8_meta_kwargs,
            )
        else:
            softmax_lse, rng_state = aux_ctx_tensors
            out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]]
            # flash-attn writes the gradients into these preallocated buffers.
            dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
            if not _use_flash_attn_3:
                fa_backward_kwargs["rng_state"] = rng_state
            flash_attn_bwd(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                causal=causal,
                **fa_backward_kwargs,
            )
            dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False)
        dq, dk, dv = flash_attn_a2a_communicate(
            [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
        )

        if ctx.qkv_format == "bshd":
            dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
        elif ctx.qkv_format == "sbhd":
            dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

        if ctx.fp8:
            if ctx.fp8_mha:
                dq, dk, dv = [
                    Float8Tensor(
                        data=x,
                        fp8_meta=ctx.fp8_meta,
                        fp8_meta_forward=False,
                        fp8_meta_index=META_DQKV,
                        fp8_dtype=fp8_dtype_backward,
                        dtype=dout_fp8.dtype,
                    )
                    for x in [dq, dk, dv]
                ]
            else:
                dq, dk, dv = [
                    cast_from_fp8(
                        x,
                        ctx.fp8_meta["scaling_bwd"],
                        META_DQKV,
                        fp8_dtype_backward,
                        TE_DType[dout_f16.dtype],
                    )
                    for x in [dq, dk, dv]
                ]

        # One gradient slot per forward input (23 inputs after ctx): is_training, then
        # q/k/v, then 19 non-differentiable arguments.
        return (None, dq, dk, dv) + (None,) * 19


4249
def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_kv,
    max_seqlen_q,
    max_seqlen_kv,
    cu_seqlens_q_padded,
    cu_seqlens_kv_padded,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    cp_comm_type,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
    window_size=None,
    fp8=False,
    fp8_meta=None,
) -> torch.Tensor:
    """
    Attention implementation with context parallelism.

    Validates the configuration and dispatches on ``cp_comm_type``:

    - ``"p2p"`` / ``"a2a+p2p"``: ``AttnFuncWithCPAndKVP2P``
    - ``"all_gather"``: ``AttnFuncWithCPAndKVAllGather`` (``cu_seqlens_kv``,
      ``cu_seqlens_kv_padded``, ``fp8`` and ``fp8_meta`` are not forwarded)
    - ``"a2a"``: ``AttnFuncWithCPAndQKVOA2A``

    For ``"a2a+p2p"``, ``cp_group`` must be a list of exactly two process groups. If the
    first level has world size 1 the call falls back to ``"p2p"`` on the second group; if
    the second level has world size 1 it falls back to ``"a2a"`` on the first group.

    Returns
    -------
    The attention output produced by the selected implementation.

    Raises
    ------
    AssertionError
        if the process group, format, mask type, bias or sliding-window configuration is
        not supported.
    ValueError
        if ``cp_comm_type`` is not one of the types above.
    """

    if cp_comm_type == "a2a+p2p":
        assert isinstance(
            cp_group, list
        ), "Hierarchical CP implementation needs multi-level CP groups!"
        assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
        # Collapse a degenerate hierarchy to the single-level implementation.
        if get_distributed_world_size(cp_group[0]) == 1:
            cp_group = cp_group[1]
            cp_comm_type = "p2p"
        elif get_distributed_world_size(cp_group[1]) == 1:
            cp_group = cp_group[0]
            cp_comm_type = "a2a"
    else:
        assert isinstance(
            cp_group, dist_group_type
        ), f"Unsupported process group for CP communication type {cp_comm_type}!"

    assert qkv_format in [
        "bshd",
        "sbhd",
        "thd",
    ], f"QKV format of {qkv_format} is not supported with context parallelism!"
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "FlashAttention does not support sbhd format!"
    assert (
        qkv_format != "thd"
        or not use_fused_attention
        or attn_mask_type in ["padding", "padding_causal"]
    ), (
        f"Context parallelism is not supported for {attn_mask_type} mask type and "
        f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!"
    )
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        """Attention bias is only supported with FusedAttention and "causal" """
        """or "no_mask" mask types!"""
    )
    assert (
        cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
    ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"

    # (-1, 0) and (-1, -1) are the two non-sliding window settings accepted here.
    sliding_window_attn = (
        window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
    )
    assert (
        not sliding_window_attn
        or cp_comm_type == "a2a"
        or (cp_comm_type == "all_gather" and not use_fused_attention)
    ), "The context parallel running configs cannot support sliding window attention!"

    # Positional arguments shared by every implementation, in their forward() order.
    args = [
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    ]

    if cp_comm_type in ["p2p", "a2a+p2p"]:
        args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream]
        out = AttnFuncWithCPAndKVP2P.apply(*args)
    elif cp_comm_type == "all_gather":
        # The all-gather implementation takes neither cu_seqlens_kv (index 5) nor
        # cu_seqlens_kv_padded (index 9). Delete the higher index first so the lower one
        # does not shift.
        del args[9]
        del args[5]
        args += [window_size, cp_group, cp_stream]
        out = AttnFuncWithCPAndKVAllGather.apply(*args)
    elif cp_comm_type == "a2a":
        args += [window_size, fp8, fp8_meta, cp_group, cp_stream]
        out = AttnFuncWithCPAndQKVOA2A.apply(*args)
    else:
        raise ValueError(f"Unsupported communication type: {cp_comm_type}!")

    return out


4367
4368
4369
4370
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
            When < 1.0, only ``int(dim * rotary_percent)`` dimensions are used.
        seq_len_interpolation_factor: int, default = None
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595. Only applied when
            `pretrained_max_position_embeddings` is also set.
        pretrained_max_position_embeddings: int, default = None
            pre-trained max_position_embeddings before position interpolation
        rotary_base: float, default = 10000.0
            Base of the inverse-frequency progression:
            ``inv_freq[i] = 1 / rotary_base ** (2 * i / dim)``.
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
        # One inverse frequency per pair of dimensions (ceil(dim / 2) entries), allocated
        # in float32 on the current CUDA device.
        inv_freq = 1.0 / (
            self.rotary_base
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            Angles ``position * inv_freq`` of shape ``[max_seq_len, 1, 1, 2 * len(inv_freq)]``,
            with the per-position angles repeated twice along the last dimension.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        # Position interpolation needs both settings; with either one missing, positions
        # are used unscaled.
        if (
            self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None
        ):
            if (
                max_seq_len
                > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
            ):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        # Outer product: freqs[i, j] = seq[i] * inv_freq[j].
        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        # Repeat the angles as [freqs, freqs] (two halves, not interleaved).
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))

4445
4446
4447
4448
4449
4450
4451
4452
4453
4454
4455
4456
4457
4458
4459
4460
4461

class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        """
        Apply RoPE to ``t`` with the fused ``tex`` kernels.

        ``freqs`` is converted to float32 if needed. ``bshd`` input is transposed to
        ``sbhd`` for the kernel and the result is transposed back. ``thd`` input uses
        ``cu_seqlens`` and forwards ``cp_size``/``cp_rank`` to the kernel.

        Raises ValueError for any other ``tensor_format``.
        """
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        # cu_seqlens may be None for the non-thd formats; save_for_backward accepts None.
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """
        Gradient of the fused RoPE with respect to ``t``.

        Dispatches on the format recorded in ``forward``. Returns one slot per forward
        input: the gradient for ``t`` followed by ``None`` for ``freqs``, ``tensor_format``,
        ``cu_seqlens``, ``cp_size`` and ``cp_rank``.
        """
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(
                grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
            )
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        return grad_input, None, None, None, None, None
4501
4502


4503
4504
4505
4506
4507
4508
4509
4510
4511
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


4512
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
        Should be `cu_seqlens_padded` when cp_size > 1.
    cp_size: int, default = 1.
        Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
    cp_rank: int, default = 0.
        Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.

    Returns
    -------
    torch.Tensor
        When `fused` is True, the output of `FusedRoPEFunc`. Otherwise a tensor of the
        same shape as `t` whose first `d2` channels (last dim) are rotated and whose
        remaining channels are copied through unchanged.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    seq_dim = 1 if tensor_format == "bshd" else 0
    max_seq_len, cur_seq_len = freqs.shape[0], t.shape[seq_dim]

    # Only the first `cur_seq_len` positions of `freqs` are used for the running input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]

    # Evaluate cos/sin in the dtype of `freqs` and only then cast to `t.dtype`, for precision.
    cos_, sin_ = (trig(freqs).to(t.dtype) for trig in (torch.cos, torch.sin))

    rot_dim = freqs.shape[-1]
    # Ideally rot_dim == d so `t_pass` is empty and every channel is rotated.
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # Cosine term on the raw channels plus sine term on the half-rotated channels.
    t_rot = t_rot * cos_ + _rotate_half(t_rot) * sin_
    return torch.cat((t_rot, t_pass), dim=-1)


class _SplitAlongDim(torch.autograd.Function):
    """Split a tensor along one dimension, with a copy-free backward when possible.

    `forward` behaves like `torch.split`. A `Float8Tensor` input is split on its raw
    `_data` and every piece is re-wrapped with `Float8Tensor.make_like`.

    `backward` concatenates the incoming gradients along the split dimension. If the
    gradients already sit back-to-back in one storage with identical strides, the
    result is a view over that storage (built with `Tensor.set_`) instead of a fresh
    `torch.cat` allocation.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        if isinstance(mixed_x_layer, Float8Tensor):
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)

    @staticmethod
    def _concat_as_view(
        tensors: List[torch.Tensor],
        split_sizes: List[int],
        split_dim: int,
    ) -> Optional[torch.Tensor]:
        """Return `torch.cat(tensors, dim=split_dim)` as a view of existing storage, or None.

        This succeeds only under four conditions:
        - every tensor shares the first tensor's storage and strides;
        - tensor `i` has size `split_sizes[i]` along `split_dim` and the first tensor's
          size in every other dim;
        - `split_dim` is non-negative (the caller normalizes it);
        - tensor `i` starts exactly `sum(split_sizes[:i]) * stride[split_dim]` elements
          after the first tensor.

        Then index `j` along `split_dim` of the returned view addresses the same
        element as the matching index of the chunk containing `j`. Offsets are checked
        with the real stride, not `prod(trailing shape)`, so the view is never built
        over a mismatched layout.
        """
        first = tensors[0]
        strides = first.stride()
        storage_ptr = first.untyped_storage().data_ptr()
        base_offset = first.storage_offset()
        step = strides[split_dim]
        expected_shape = list(first.shape)
        start = 0
        for size, tensor in zip(split_sizes, tensors):
            expected_shape[split_dim] = size
            if (
                tensor.stride() != strides
                or list(tensor.shape) != expected_shape
                or tensor.untyped_storage().data_ptr() != storage_ptr
                or tensor.storage_offset() != base_offset + start * step
            ):
                return None
            start += size

        concat_shape = list(first.shape)
        concat_shape[split_dim] = start  # == sum(split_sizes)
        out = torch.Tensor().to(device=first.device, dtype=first.dtype)
        out.set_(first.untyped_storage(), base_offset, concat_shape, strides)
        return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        # pylint: disable=missing-function-docstring
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            # With an int split size torch.split may emit a smaller last chunk; the view
            # check then fails on the shape and torch.cat below handles it.
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + dims) % dims

        # Float8 gradients are concatenated on their raw data, using the raw data's own
        # strides/offsets throughout, and re-wrapped at the end.
        is_fp8 = isinstance(grad_outputs[0], Float8Tensor)
        tensors = [g._data for g in grad_outputs] if is_fp8 else list(grad_outputs)

        grad_input = _SplitAlongDim._concat_as_view(tensors, split_sizes, split_dim)
        if grad_input is None:
            grad_input = torch.cat(tensors, dim=split_dim)
        if is_fp8:
            grad_input = Float8Tensor.make_like(grad_outputs[0], data=grad_input)
        return grad_input, None, None


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_type: str = "self",
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_type = attention_type
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def _prepare_padding_mask(
        self,
        attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        attn_mask_type: str,
        batch_size: int,
        max_seqlen_q: int,
        max_seqlen_kv: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build the combined `[b, 1, sq, skv]` boolean mask for padding mask types.

        `True` marks positions that must not be attended to. For self attention the
        single `[b, 1, 1, sq]` padding mask is applied to both queries and keys. For
        cross attention, `attention_mask` is a `(q_mask, kv_mask)` pair. For
        `padding_causal` and `padding_causal_bottom_right`, the causal constraint is
        OR-ed in as well.

        Returns the combined mask, plus the per-sequence counts of unmasked query and
        key/value positions, taken from the padding-only mask.
        """
        if self.attention_type == "self":
            assert attention_mask.shape == (
                batch_size,
                1,
                1,
                max_seqlen_q,
            ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!"
            attention_mask = torch.logical_or(
                attention_mask.squeeze(1).unsqueeze(3), attention_mask
            )
        else:
            assert (
                len(attention_mask) == 2
                and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q)
                and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv)
            ), (
                "attention_mask should be a tuple of two tensors with shapes "
                "[b, 1, 1, sq] and [b, 1, 1, skv]!"
            )
            attention_mask = torch.logical_or(
                attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
            )

        # NOTE(review): these counts read column 0 / row 0 of the mask, so they assume the
        # first query and key positions are never padding (e.g. right padding).
        valid = attention_mask.squeeze(1).logical_not()
        actual_seqlens_q = valid[:, :, 0].sum(dim=1)
        actual_seqlens_kv = valid[:, 0, :].sum(dim=1)

        # relative[..., i, j] = i - j; negative entries lie above the top-left diagonal.
        relative = torch.arange(max_seqlen_q, dtype=torch.int32, device=device).view(
            1, 1, max_seqlen_q, 1
        ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device=device).view(
            1, 1, 1, max_seqlen_kv
        )
        if attn_mask_type == "padding_causal":
            attention_mask = torch.logical_or(
                torch.where(relative.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0),
                attention_mask,
            )
        if attn_mask_type == "padding_causal_bottom_right":
            # Shift the diagonal per sequence by (kv_len - q_len) so it is bottom-right aligned.
            attention_mask = torch.logical_or(
                torch.where(
                    relative.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
                    + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
                    < 0,
                    1,
                    0,
                ),
                attention_mask,
            )
        return attention_mask, actual_seqlens_q, actual_seqlens_kv

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop.

        Computes attention scores with batched matmuls (optionally with a pre-/post-scale
        bias or ALiBi), applies masked softmax and dropout, then multiplies by V.
        `bshd` inputs are processed in `sbhd` order internally. The output has heads
        folded into the last dim: `[sq, b, h*d]` for `sbhd`, `[b, sq, h*d]` for `bshd`.

        Raises
        ------
        ValueError
            If `core_attention_bias_type` is not one of `no_bias`, `pre_scale_bias`,
            `post_scale_bias` or `alibi`.
        """
        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]
        batch_size, max_seqlen_q, max_seqlen_kv = (
            query_layer.shape[1],
            query_layer.shape[0],
            key_layer.shape[0],
        )
        # Allocate auxiliary tensors on the inputs' device rather than the *current* CUDA
        # device, which can differ when the module runs on a non-current GPU.
        device = query_layer.device

        actual_seqlens_q, actual_seqlens_kv = None, None
        if "padding" in attn_mask_type:
            attention_mask, actual_seqlens_q, actual_seqlens_kv = self._prepare_padding_mask(
                attention_mask, attn_mask_type, batch_size, max_seqlen_q, max_seqlen_kv, device
            )

        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            # GQA/MQA: repeat each KV head so that every query head has a partner.
            key_layer = key_layer.repeat_interleave(
                query_layer.shape[2] // key_layer.shape[2], dim=2
            )
            value_layer = value_layer.repeat_interleave(
                query_layer.shape[2] // value_layer.shape[2], dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            device=device,
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            ).view(*output_size)

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            matmul_result = matmul_result.view(*output_size) + core_attention_bias
            matmul_result *= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                # actual_seqlens_* are None unless a padding mask type was requested.
                _, core_attention_bias = get_alibi(
                    output_size[1],
                    output_size[2],
                    output_size[3],
                    actual_seqlens_q=actual_seqlens_q,
                    actual_seqlens_kv=actual_seqlens_kv,
                    alibi_slopes=alibi_slopes,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
                dtype=query_layer.dtype
            )

        else:
            # Otherwise `matmul_result` would be the uninitialized buffer from torch.empty.
            raise ValueError(
                f"Unsupported core_attention_bias_type: {core_attention_bias_type}!"
            )

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            matmul_result, attention_mask, attn_mask_type, softmax_scale
        )

        # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
        # the columns (pad tokens from k) are already zeroed out during softmax
        if "padding" in attn_mask_type:
            attention_probs = attention_probs.masked_fill(attention_mask, 0)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(max_seqlen_q, batch_size, -1)

        if qkv_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, max_seqlen_q, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """Turn interleaved QKV in (s, b, ...) layout into separate q, k, v tensors.

    Forward passes the shared QKV memory to `tex.fa_prepare_fwd`, which produces
    separate contiguous q, k, v tensors in (b, s, ...) layout. Backward packs the three
    gradients with `tex.fa_prepare_bwd` and splits the result into three tensors along
    its last dimension.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # The inputs are non-contiguous views into one QKV buffer, so `query_layer`
        # alone is enough to reach the whole memory region.
        packed = tex.fa_prepare_fwd(query_layer)
        q_chunk, k_chunk, v_chunk = split_tensor_along_dim(packed, 0, 3)
        # Drop the leading size-1 split dimension from each chunk.
        return tuple(torch.squeeze(chunk, 0) for chunk in (q_chunk, k_chunk, v_chunk))

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        grad_q, grad_k, grad_v = split_tensor_along_dim(tex.fa_prepare_bwd(dq, dk, dv), -1, 3)
        return grad_q, grad_k, grad_v


def get_qkv_layout(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qkv_format: str = "sbhd",
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    Inspects the storages, shapes, strides and storage offsets of `q`, `k` and `v` to
    determine how they are laid out in memory. If no supported layout is found, the
    tensors are made contiguous and detection is run one more time.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of tokens in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    -------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
        Query tensor. It may be different from input `q` as we try to fit tensors to
        a supported layout.
    k: torch.Tensor
        Key tensor. It may be different from input `k` as we try to fit tensors to
        a supported layout.
    v: torch.Tensor
        Value tensor. It may be different from input `v` as we try to fit tensors to
        a supported layout.

    Raises
    ------
    AssertionError
        If any of `q`, `k`, `v` does not have stride 1 in its last dimension.
    RuntimeError
        If the layout is still unsupported after making the tensors contiguous.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def run_iteratively(q, k, v):
        """Return the layout name of `q`, `k`, `v` as given, or "not_supported"."""
        # check data pointers
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        # check tensor shapes
        check_shapes_qkv = all(q.shape == x.shape for x in [q, k, v])
        check_shapes_kv = k.shape[:-1] == v.shape[:-1]

        # check tensor strides; k and v strides are each divided by their own head size
        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
            sv / v.shape[-1] for sv in v.stride()[:-1]
        )

        # check tensor offsets for h3d and 3hd layouts
        prod_h_d = q.shape[-1] * q.shape[-2]
        check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v]))
        check_h3d_offsets = all(
            x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v])
        )

        # check tensor offsets for hd_h2d and hd_2hd layouts
        prod_all_dims = [x.numel() for x in [q, k]]
        offset = prod_all_dims[0] if check_ptrs_qkv else 0
        prod_h_d = k.shape[-1] * k.shape[-2]
        check_2hd_offsets = all(
            x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v])
        )
        check_h2d_offsets = all(
            x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v])
        )

        # check tensor offsets for hd_hd_hd layouts
        check_hd_offsets_qkv = (
            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v]))
            if check_ptrs_qkv
            else all(x.storage_offset() == 0 for x in [q, k, v])
        )
        check_hd_offsets_qk = (
            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k]))
            if not check_ptrs_qkv and check_ptrs_qk
            else all(x.storage_offset() == 0 for x in [q, k])
        )
        check_hd_offsets_kv = (
            all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v]))
            if not check_ptrs_qkv and check_ptrs_kv
            else all(x.storage_offset() == 0 for x in [k, v])
        )

        if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets:
            # sb3hd, bs3hd, t3hd
            # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv
            qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
        elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets:
            # sbh3d, bsh3d, th3d
            # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv
            qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets:
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv
            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
            # have the same data pointer, i.e. check_ptrs_qkv=True
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets:
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv
            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
            # have the same data pointer, i.e. check_ptrs_qkv=True
            qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
        elif (
            check_strides_kv
            and check_shapes_kv
            and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk)
        ):
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            # three chunks of memory, q, k and v, which may be disjoint or consecutive, and
            # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
            # check_ptrs_qk=True or check_ptrs_kv=True
            qkv_layout = "_".join([qkv_format] * 3)
        else:
            qkv_layout = "not_supported"

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        raise RuntimeError("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v


def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Check if sliding window size is compliant with attention mask type.
    If not, set it to the appropriate size.

         attn_mask_type                              |   window_size
    -------------------------------------------------------------------------
    no_mask, padding, arbitrary                      | (-1, -1) or (>=0, >=0)
    causal, padding_causal                           | (-1,  0) or (>=0, 0)
    causal_bottom_right, padding_causal_bottom_right | (-1,  0) or (>=0, 0)

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type. Any type containing "causal" is treated as causal.
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size `(left, right)`. Any two-element sequence (e.g. a list)
        is accepted and normalized to a tuple. `None` selects the default for
        `attn_mask_type`: `(-1, 0)` for causal types, `(-1, -1)` otherwise.

    Returns
    -------
    window_size: Tuple[int, int]
        The validated window size. Fixable mismatches are corrected with a warning.
        For causal types, `(-1, -1)` or `(>=0, !=0)` gets its right bound set to `0`.
        For `no_mask`, `padding` and `arbitrary`, `(-1, 0)` becomes `(-1, -1)`.

    Raises
    ------
    AssertionError
        If `window_size` is invalid for `attn_mask_type`, or `attn_mask_type` is not
        recognized. Raised explicitly (not via `assert`) so the check survives `python -O`.
    """
    if window_size is not None:
        # Normalize so that list inputs compare equal to the tuple literals below.
        window_size = tuple(window_size)

    if "causal" in attn_mask_type:
        message = "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
        if window_size is None:
            return (-1, 0)
        if window_size == (-1, -1) or (window_size[0] >= 0 and window_size[1] != 0):
            warnings.warn(message)
            return (window_size[0], 0)
        if window_size != (-1, 0) and (window_size[0] < 0 or window_size[1] != 0):
            raise AssertionError(message)
        return window_size

    if attn_mask_type in ["no_mask", "padding", "arbitrary"]:
        message = (
            "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
        )
        if window_size is None:
            return (-1, -1)
        if window_size == (-1, 0):
            warnings.warn(message)
            return (-1, -1)
        if window_size != (-1, -1) and (window_size[0] < 0 or window_size[1] < 0):
            raise AssertionError(message)
        return window_size

    raise AssertionError("Invalid attn_mask_type: " + attn_mask_type)


class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention

    Dispatches to one of these paths:

    * the context-parallel kernel (``attn_forward_func_with_cp``) when the
      context-parallel world size is > 1 and no input is a ``Float8Tensor``;
    * flash-attn v3 (``flash_attn_func_v3`` / ``flash_attn_varlen_func_v3``)
      when ``_use_flash_attn_3`` is set, including the FP8 path;
    * flash-attn v2 (``flash_attn_func`` / ``flash_attn_varlen_func``) otherwise.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        softmax_scale: float
            Scale forwarded as ``softmax_scale`` to the attention kernels.
        attention_dropout: float, default = 0.0
            Dropout probability; only used while ``self.training`` is True.
        attention_dropout_ctx: Callable, default = contextlib.nullcontext
            Zero-argument callable returning a context manager that wraps
            every kernel call in ``forward``.
        attention_type: str, default = "self"
            ``"self"`` or cross attention. Selects how padding masks/cu_seqlens
            are interpreted.
        layer_number: Optional[int], default = None
            Stored as ``self.layer_number``; defaults to 1 when None.
        deterministic: bool, default = False
            Forwarded to the kernels that accept a ``deterministic`` flag.
        """
        super().__init__()

        # Refuse to run with a flash-attn build outside the supported window.
        if _flash_attn_is_installed:
            assert (
                _flash_attn_version >= _flash_attn_version_required
            ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
            assert (
                _flash_attn_version <= _flash_attn_max_version
            ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        # Shared, module-level logger; only attach the handler once.
        self.logger = logging.getLogger("FlashAttention")
        self.logger.setLevel(_log_level)
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """flash-attn fprop

        The qkv format ("sbhd", "bshd" or "thd") is derived from the letters of
        the first ``_``-separated component of ``qkv_layout``
        (e.g. "sbh3d" -> "sbhd"). "sbhd" inputs are transposed to batch-first
        before calling flash-attn.

        For "padding" mask types with "sbhd"/"bshd", the batch is packed into
        variable-length form, using ``cu_seqlens_*`` when given and otherwise
        deriving them from ``attention_mask``. In that case ``attention_mask``
        is a single tensor for self attention and a (q, kv) pair otherwise.
        The output is unpacked again afterwards.

        Returns a contiguous tensor whose trailing (head, dim) dimensions are
        flattened:

        * "sbhd" -> [s, b, -1]
        * "bshd" -> [b, s, -1]
        * "thd" -> [t, -1]

        When ``fp8`` and ``fp8_meta["recipe"].fp8_mha`` are set, the output is a
        ``Float8Tensor``.
        """

        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        # Context-parallel world size: product over all groups when a list is given.
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        # Keep only the letters of the q component, e.g. "sbh3d" -> "sbhd", "t3hd" -> "thd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
            if qkv_format == "sbhd":
                # For now just 128, will make it more general in the future
                if (
                    query_layer.shape[-1] == 128
                    and query_layer.shape[0] * query_layer.shape[1] >= 512
                    and qkv_layout == "sbh3d"
                ):
                    query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                        query_layer, key_layer, value_layer
                    )
                else:
                    # sbhd -> bshd (a view; no copy)
                    query_layer, key_layer, value_layer = [
                        x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
                    ]
            if context_parallel:
                query_layer, key_layer, value_layer = [
                    x.contiguous() for x in (query_layer, key_layer, value_layer)
                ]
        else:
            # Float8Tensor inputs: operate on the raw FP8 payload, then rewrap.
            if qkv_format == "sbhd":
                query_layer._data, key_layer._data, value_layer._data = [
                    x.transpose(0, 1)
                    for x in (query_layer._data, key_layer._data, value_layer._data)
                ]
                query_layer, key_layer, value_layer = [
                    Float8Tensor.make_like(x, data=x._data)
                    for x in (query_layer, key_layer, value_layer)
                ]
            if context_parallel:
                query_layer._data, key_layer._data, value_layer._data = [
                    x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
                ]

        batch_size = query_layer.shape[0]

        if qkv_format in ["sbhd", "bshd"]:
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
            # Local shard length * cp_size; the output reshape below divides it back out.
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size

            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer)
            else:
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            # Derive max lengths from cu_seqlens when not given (device->host sync via .item()).
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel and all(
            not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
        ):
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    # sbhd inputs were transposed to batch-first above
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic,
                    window_size=window_size,
                )
        else:

            from .cpu_offload import CPUOffloadEnabled

            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            with self.attention_dropout_ctx():
                # Only pass keyword arguments the installed flash-attn version accepts.
                fa_optional_forward_kwargs = {}
                if _flash_attn_2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                fa_optional_forward_args_thd = []
                if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
                    func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
                else:
                    # Variable-length (packed) path: needs cu_seqlens and max lengths.
                    if _flash_attn_2_5_7_plus:
                        fa_optional_forward_kwargs["block_table"] = None
                    func = (
                        flash_attn_varlen_func
                        if not _use_flash_attn_3
                        else flash_attn_varlen_func_v3
                    )
                    fa_optional_forward_args_thd.append(cu_seqlens_q)
                    fa_optional_forward_args_thd.append(cu_seqlens_kv)
                    fa_optional_forward_args_thd.append(max_seqlen_q)
                    fa_optional_forward_args_thd.append(max_seqlen_kv)
                if _use_flash_attn_3:
                    fa_3_optional_forward_kwargs = {}
                    fa_3_optional_forward_kwargs["window_size"] = window_size
                    fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
                    # Nominal (high-precision) dtype, captured before any FP8 conversion.
                    activation_dtype = query_layer.dtype
                    if fp8:
                        fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                        torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)

                        def convert_to_torch_float8(tensor, dtype):
                            # Zero-copy: reinterpret the Float8Tensor payload as a torch FP8 tensor.
                            out = torch.Tensor().to(device=tensor.device, dtype=dtype)
                            out.set_(
                                tensor._data.untyped_storage(),
                                tensor._data.storage_offset(),
                                tensor._data.shape,
                                tensor._data.stride(),
                            )
                            return out

                        if fp8_meta["recipe"].fp8_mha:
                            assert all(
                                isinstance(x, Float8Tensor)
                                for x in [query_layer, key_layer, value_layer]
                            ), "q/k/v must be Float8Tensors for FP8 MHA."
                            fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
                        else:
                            query_layer, key_layer, value_layer = (
                                Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward)
                                for x in [query_layer, key_layer, value_layer]
                            )
                        fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv
                        fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv
                        fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv
                        query_layer, key_layer, value_layer = (
                            convert_to_torch_float8(x, torch_dtype)
                            for x in [query_layer, key_layer, value_layer]
                        )
                    try:
                        output, _ = func(
                            query_layer,
                            key_layer,
                            value_layer,
                            *fa_optional_forward_args_thd,
                            softmax_scale=self.softmax_scale,
                            causal="causal" in attn_mask_type,
                            **fa_3_optional_forward_kwargs,
                        )
                    except TypeError as e:
                        # Augment the message with upgrade hints for the v3 beta, but only
                        # when there is a string message to extend; never mask the
                        # original TypeError with an IndexError/TypeError of our own.
                        if _flash_attn_3_0_0_beta and e.args and isinstance(e.args[0], str):
                            e.args = (
                                e.args[0]
                                + ". Please update your flash-attn v3 (beta) installation as it "
                                + "may have added more supported arguments to its API. \n"
                                + _flash_attn_3_installation_steps,
                            ) + e.args[1:]
                        raise

                    if fp8 and fp8_meta["recipe"].fp8_mha:
                        output = cast_to_fp8(
                            output,
                            fp8_meta["scaling_fwd"],
                            META_O,
                            fp8_dtype_forward,
                        )
                        output = Float8Tensor(
                            data=output,
                            fp8_meta=fp8_meta,
                            fp8_meta_forward=True,
                            fp8_meta_index=META_O,
                            fp8_dtype=fp8_dtype_forward,
                            dtype=activation_dtype,
                        )
                else:
                    output = func(
                        query_layer,
                        key_layer,
                        value_layer,
                        *fa_optional_forward_args_thd,
                        # dropout only while training
                        self.attention_dropout if self.training else 0.0,
                        softmax_scale=self.softmax_scale,
                        causal="causal" in attn_mask_type,
                        **fa_optional_forward_kwargs,
                    )

        if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
            # Scatter packed tokens back into the padded [b * s, ...] layout.
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            if fp8 and fp8_meta["recipe"].fp8_mha:
                output = Float8Tensor.make_like(
                    output,
                    data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
                    .transpose(0, 1)
                    .contiguous(),
                )
            else:
                output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
        elif qkv_format == "bshd":
            # (bs)hd -> bs(hd)
            output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
        elif qkv_format == "thd":
            # thd -> t(hd)
            output = output.reshape(output.shape[0], -1)

        return output.contiguous()


def _combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a particular dimension.

    Returns a zero-copy view over the storage of ``tensors[0]``. The view has a
    new dimension of size ``len(tensors)`` inserted at ``dim``. No data is
    copied, and only ``tensors[0]`` is inspected.

    The stride of the new dimension is the stride of dimension ``dim - 1``
    divided by ``len(tensors)``. This presumes the inputs are equally spaced,
    interleaved slices of one buffer (e.g. q/k/v split from a packed tensor).
    The other tensors' storage is not validated.
    NOTE(review): ``dim == 0`` would read the last stride; callers presumably
    pass ``dim >= 1``.

    ``Float8Tensor`` inputs are combined over their raw FP8 payload and
    rewrapped with ``Float8Tensor.make_like``.
    """
    num_tensors = len(tensors)
    first = tensors[0]
    is_fp8 = isinstance(first, Float8Tensor)
    # The storage-owning tensor: the FP8 payload for Float8Tensor, else the tensor itself.
    data = first._data if is_fp8 else first

    new_shape = list(first.shape)
    new_shape.insert(dim, num_tensors)
    new_stride = list(first.stride())
    # Integer division keeps the stride exact (int(a / n) goes through a float).
    new_stride.insert(dim, new_stride[dim - 1] // num_tensors)

    combined_tensor = torch.Tensor().to(device=first.device, dtype=data.dtype)
    combined_tensor.set_(data.untyped_storage(), data.storage_offset(), new_shape, new_stride)
    if is_fp8:
        combined_tensor = Float8Tensor.make_like(first, data=combined_tensor)

    return combined_tensor


class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input

    Wraps ``fused_attn_fwd_qkvpacked`` / ``fused_attn_bwd_qkvpacked``. Backward
    can instead use ``flash_attn_cuda_bwd`` when ``use_FAv2_bwd`` is set.

    On the FP8 path, inputs may be ``Float8Tensor`` (FP8 MHA) or
    high-precision tensors that are quantized here. Setting the environment
    variable ``NVTE_FP8_DPA_BWD=0`` makes backward run in high precision on
    dequantized copies of qkv/out.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen,
        cu_seqlens,
        cu_seqlens_padded,
        qkv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        # pylint: disable=missing-function-docstring
        # Nominal torch dtype of qkv (for a Float8Tensor, its high-precision dtype),
        # captured before qkv may be replaced by a dequantized copy below.
        qkv_nominal_dtype = qkv.dtype
        is_input_fp8 = False
        # A Float8Tensor output is only produced on the FP8 path; outside it the
        # output is a plain tensor, so backward must not expect an FP8 gradient.
        is_output_fp8 = bool(fp8) and fp8_meta["recipe"].fp8_mha
        if fp8:
            is_input_fp8 = isinstance(qkv, Float8Tensor)
            if is_input_fp8:
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.split("_"))
            assert (
                qkv_group == 1
            ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}."
            if is_input_fp8:
                qkv_fp8 = qkv._data
            else:
                # Quantize over a 2D view, then restore the original shape.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv_fp8 = cast_to_fp8(
                    qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(qkv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
                META_QKV,  # d_scale_qkv_offset
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
                META_S,  # d_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_s
                META_S,  # q_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_o
                META_O,  # q_scale_o_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_s
                META_S,  # amax_s_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_o
                META_O,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if is_output_fp8:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=qkv.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret
            if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # High-precision backward: save dequantized qkv/out instead of FP8 ones.
                if is_input_fp8:
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv = cast_from_fp8(
                        qkv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[qkv.dtype],
                    ).view(qkv.shape)
                if is_output_fp8:
                    out_save = cast_from_fp8(
                        out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                        fp8_meta["scaling_fwd"],
                        META_O,
                        fp8_dtype_forward,
                        qkv_dtype,
                    ).view(out_fp8.shape)
            # Snapshot the forward scales: backward must use the values from this step.
            fp8_tensors = (
                qkv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                None,  # d_scale_qkv
                0,  # d_scale_qkv_offset
                None,  # d_scale_s
                0,  # d_scale_s_offset
                None,  # q_scale_s
                0,  # q_scale_s_offset
                None,  # q_scale_o
                0,  # q_scale_o_offset
                None,  # amax_s
                0,  # amax_s_offset
                None,  # amax_o
                0,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            fp8_tensors = (None, None, None, None)
            out_save = out_ret

        # Backward runs in FP8 only if forward did and NVTE_FP8_DPA_BWD is not 0.
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
        ctx.qkv_nominal_dtype = qkv_nominal_dtype
        # FP8 backward only needs the FP8 copies held in fp8_tensors.
        qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
        ctx.save_for_backward(
            *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        # pylint: disable=missing-function-docstring
        if ctx.is_output_fp8:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            qkv,
            out,
            cu_seqlens,
            cu_seqlens_padded,
            qkv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        # rest[0] holds dbias when the fused backward returns one.
        rest = [None]
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dqkv = torch.empty_like(qkv)
            d_out, q, k, v, out = [
                maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out)
            ]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dqkv[:, 0],
                dqkv[:, 1],
                dqkv[:, 2],
                cu_seqlens,
                cu_seqlens,
                ctx.max_seqlen,
                ctx.max_seqlen,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            dqkv = dqkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.is_output_fp8:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.is_input_fp8:
                        dqkv = Float8Tensor(
                            data=dqkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            # d_out_f8tensor only exists when the output was FP8;
                            # otherwise fall back to qkv's nominal dtype from forward.
                            dtype=(
                                d_out_f8tensor.dtype
                                if ctx.is_output_fp8
                                else ctx.qkv_nominal_dtype
                            ),
                        )
                    else:
                        dqkv_c_fp8 = dqkv_fp8.view(
                            -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                        )
                        dqkv = cast_from_fp8(
                            dqkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dqkv_fp8.shape)
                else:
                    # FP8 forward + high-precision backward: dequantize the FP8 gradient.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(qkv.dtype)
                    dqkv, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # One gradient per forward input (20 in total, ctx excluded): qkv is
        # input #4 and attn_bias is input #6; every other input gets None.
        # no_bias/alibi have no learnable bias, so dbias is None for them.
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        return (None, None, None, None, dqkv, None, dbias) + (None,) * 13


class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input.

    Q is passed as its own tensor while K and V are packed into a single ``kv``
    tensor. The forward pass runs ``fused_attn_fwd_kvpacked`` (in FP8 when ``fp8``
    is set). The backward pass runs ``flash_attn_cuda_bwd`` when ``use_FAv2_bwd``
    is set and ``fused_attn_bwd_kvpacked`` otherwise.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        q,
        kv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run the fused attention forward pass and stash what backward needs.

        Args:
            is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
            cu_seqlens_q_padded, cu_seqlens_kv_padded: passed to the attention
                kernels; the cu_seqlens* tensors are also saved for backward.
            q, kv: query tensor and packed key/value tensor. With ``fp8`` they must
                be of the same type. If they are ``Float8Tensor`` their raw FP8 data
                is used directly. Otherwise they are cast to FP8 here, and the
                layout must then be of the ``hd_2hd``/``hd_h2d`` family
                (e.g. ``sbhd_sb2hd``).
            qkv_dtype: TE dtype given to the kernels in the non-FP8 path and used
                as the target type when FP8 results are cast back.
            attn_bias: bias tensor; backward returns its gradient unless
                ``attn_bias_type`` is "no_bias" or "alibi".
            fused_attention_backend: backend for the non-FP8 forward; replaced by
                ``FusedAttnBackend["FP8"]`` when ``fp8`` is set.
            use_FAv2_bwd: run ``flash_attn_cuda_bwd`` in backward instead of
                ``fused_attn_bwd_kvpacked``.
            fp8, fp8_meta: enable FP8 and provide its scaling metadata and recipe.
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type,
            attn_mask_type, window_size, rng_gen, deterministic: passed through to
                the forward and/or backward kernels.

        Returns:
            The attention output. It is a ``Float8Tensor`` when ``fp8`` is set and
            ``fp8_meta["recipe"].fp8_mha`` is true, otherwise a high-precision
            tensor.

        Environment:
            NVTE_FP8_DPA_BWD: when "0", the FP8 forward still runs but backward
            takes the non-FP8 path. Dequantized q/kv/out are then saved for it.
        """
        is_input_fp8 = False
        is_output_fp8 = fp8_meta["recipe"].fp8_mha
        # dtype reported by q before any re-casting below. Backward uses it for FP8
        # gradients when d_out is not a Float8Tensor.
        nominal_dtype = q.dtype
        if fp8:
            assert isinstance(kv, q.__class__), "q and kv must have the same type."
            is_input_fp8 = isinstance(q, Float8Tensor)
            if is_input_fp8:
                # Inputs are already quantized: reuse their scale_inv for QKV.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if is_input_fp8:
                q_fp8, kv_fp8 = q._data, kv._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                assert qkv_group == 2, (
                    "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
                    f"but found {qkv_layout}."
                )
                q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view(
                    q.shape
                )
                # Flatten the last three dims of kv into a 2D view for the cast.
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv_fp8 = cast_to_fp8(
                    kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(kv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                kv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
                META_QKV,  # d_scale_qkv_offset
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
                META_S,  # d_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_s
                META_S,  # q_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_o
                META_O,  # q_scale_o_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_s
                META_S,  # amax_s_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_o
                META_O,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if is_output_fp8:
                # FP8 MHA: hand the FP8 output downstream without dequantizing.
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret
            if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # Backward will not run in FP8, so it needs high-precision
                # q/kv/out; dequantize whatever is still in FP8.
                if is_input_fp8:
                    q = cast_from_fp8(
                        q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[q.dtype],
                    ).view(q.shape)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv = cast_from_fp8(
                        kv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[kv.dtype],
                    ).view(kv.shape)
                if is_output_fp8:
                    out_save = cast_from_fp8(
                        out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                        fp8_meta["scaling_fwd"],
                        META_O,
                        fp8_dtype_forward,
                        qkv_dtype,
                    ).view(out_fp8.shape)
            # Snapshot the forward scales so backward uses the values from this
            # step.
            fp8_tensors = (
                q_fp8,
                kv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                kv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                None,  # d_scale_qkv
                0,  # d_scale_qkv_offset
                None,  # d_scale_s
                0,  # d_scale_s_offset
                None,  # q_scale_s
                0,  # q_scale_s_offset
                None,  # q_scale_o
                0,  # q_scale_o_offset
                None,  # amax_s
                0,  # amax_s_offset
                None,  # amax_o
                0,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None)

        # Backward runs in FP8 only if forward did and NVTE_FP8_DPA_BWD allows it.
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
        # The FP8 backward reads the FP8 copies in fp8_tensors instead of q/kv/out.
        qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
        # The saved order must match the unpacking at the top of backward().
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        # A non-FP8 backward always uses the F16 arbitrary-seqlen fused backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic
        ctx.nominal_dtype = nominal_dtype

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for q, kv and, for a learnable bias, attn_bias.

        ``d_out`` must be a ``Float8Tensor`` if forward returned one. Gradients
        are returned positionally, one per ``forward`` input after ``ctx``.
        """
        if ctx.is_output_fp8:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            q,
            kv,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            kv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        # Extra outputs of the fused backward; rest[0] is the bias gradient.
        rest = [None]
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            # NOTE(review): K/V are taken from index 0/1 of dim 1 of kv, i.e. this
            # assumes the packed K/V dim is dim 1 (e.g. t2hd) - confirm for others.
            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dkv[:, 0],
                dkv[:, 1],
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Trim the last dim to d_out's size (presumably undoes head-dim padding).
            dq = dq[..., : d_out.shape[-1]]
            dkv = dkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.is_output_fp8:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        kv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.is_input_fp8:
                        # FP8 inputs get FP8 gradients. d_out_f8tensor only exists
                        # when the output was FP8, so otherwise fall back to q's
                        # dtype recorded in forward.
                        grad_dtype = (
                            d_out_f8tensor.dtype if ctx.is_output_fp8 else ctx.nominal_dtype
                        )
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=grad_dtype,
                        )
                        dkv = Float8Tensor(
                            data=dkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=grad_dtype,
                        )
                    else:
                        dq = cast_from_fp8(
                            dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dq_fp8.shape)
                        dkv_c_fp8 = dkv_fp8.view(
                            -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                        )
                        dkv = cast_from_fp8(
                            dkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dkv_fp8.shape)
                else:
                    # FP8 output but non-FP8 backward: dequantize d_out first.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dkv, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        kv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # A bias gradient is only returned for learnable bias types.
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        # One entry per forward() input, in signature order.
        return (
            None,  # is_training
            None,  # max_seqlen_q
            None,  # max_seqlen_kv
            None,  # cu_seqlens_q
            None,  # cu_seqlens_kv
            None,  # cu_seqlens_q_padded
            None,  # cu_seqlens_kv_padded
            dq,  # q
            dkv,  # kv
            None,  # qkv_dtype
            dbias,  # attn_bias
            None,  # attn_scale
            None,  # dropout_p
            None,  # fast_zero_fill
            None,  # qkv_layout
            None,  # attn_bias_type
            None,  # attn_mask_type
            None,  # window_size
            None,  # rng_gen
            None,  # fused_attention_backend
            None,  # use_FAv2_bwd
            None,  # fp8
            None,  # fp8_meta
            None,  # deterministic
        )

6393

6394
6395
6396
6397
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
6398
6399
6400
6401
6402
6403
6404
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
6405
6406
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
6407
6408
6409
6410
6411
6412
6413
6414
6415
6416
6417
        q,
        k,
        v,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
6418
        window_size,
6419
6420
6421
6422
6423
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
6424
        deterministic,
6425
    ):
6426
        # pylint: disable=missing-function-docstring
6427
6428
        is_input_fp8 = False
        is_output_fp8 = fp8_meta["recipe"].fp8_mha
6429
6430
6431
        if fp8:
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
6432
6433
6434
6435
6436
            assert isinstance(k, q.__class__) and isinstance(
                v, q.__class__
            ), "q, k, and v must have the same type."
            is_input_fp8 = isinstance(q, Float8Tensor)
            if is_input_fp8:
6437
6438
6439
6440
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
6441
                qkv_group = len(qkv_layout.split("_"))
6442
                if qkv_group == 1:
6443
6444
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
6445
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
6446
6447
6448
6449
                    qkv_fp8 = cast_to_fp8(
                        qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(qkv.shape)
                    q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1])
6450
6451
                    q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
                if qkv_group == 2:
6452
6453
6454
6455
6456
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    dim = qkv_layout.split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
6457
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
6458
6459
6460
6461
                    kv_fp8 = cast_to_fp8(
                        kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(kv.shape)
                    k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1])
6462
6463
                    k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
                if qkv_group == 3:
6464
6465
6466
6467
6468
6469
6470
6471
6472
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    k_fp8 = cast_to_fp8(
                        k, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(k.shape)
                    v_fp8 = cast_to_fp8(
                        v, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(v.shape)
6473
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
6474
6475
6476
6477
6478
6479
6480
6481
6482
6483
6484
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
6485
6486
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
6487
6488
6489
6490
6491
6492
6493
6494
6495
6496
6497
6498
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
                META_QKV,  # d_scale_qkv_offset
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
                META_S,  # d_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_s
                META_S,  # q_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_o
                META_O,  # q_scale_o_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_s
                META_S,  # amax_s_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_o
                META_O,  # amax_o_offset
6499
6500
6501
6502
6503
6504
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
6505
                window_size,
6506
6507
                rng_gen,
            )
6508
            if is_output_fp8:
6509
6510
                out_ret = Float8Tensor(
                    data=out_fp8,
6511
6512
6513
6514
6515
6516
6517
6518
6519
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
6520
6521
6522
6523
6524
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
6525
6526
            out_save = out_ret

6527
            if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
6528
                # 1: qkv packed, 2: kv packed, 3: qkv separate
6529
6530
6531
6532
6533
6534
6535
6536
6537
6538
6539
6540
6541
6542
6543
6544
6545
6546
6547
6548
6549
6550
6551
6552
6553
6554
6555
6556
6557
6558
6559
6560
6561
6562
6563
6564
6565
6566
6567
6568
6569
6570
6571
6572
6573
6574
6575
6576
6577
6578
6579
6580
6581
6582
6583
6584
6585
6586
6587
6588
                if is_input_fp8:
                    qkv_group = len(qkv_layout.split("_"))
                    if qkv_group == 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                        qkv_no_fp8 = cast_from_fp8(
                            qkv_c._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[qkv.dtype],
                        ).view(qkv.shape)
                        q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1])
                        q, k, v = [x.squeeze(dim) for x in [q, k, v]]
                    if qkv_group == 2:
                        q = cast_from_fp8(
                            q._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[q.dtype],
                        ).view(q.shape)
                        dim = qkv_layout.split("_")[1].find("2")
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                        kv_no_fp8 = cast_from_fp8(
                            kv_c._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[kv.dtype],
                        ).view(kv.shape)
                        k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1])
                        k, v = [x.squeeze(dim) for x in [k, v]]
                    if qkv_group == 3:
                        q = cast_from_fp8(
                            q._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[q.dtype],
                        ).view(q.shape)
                        k = cast_from_fp8(
                            k._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[k.dtype],
                        ).view(k.shape)
                        v = cast_from_fp8(
                            v._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[v.dtype],
                        ).view(v.shape)
                if is_output_fp8:
                    out_save = cast_from_fp8(
                        out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
6589
                        fp8_meta["scaling_fwd"],
6590
                        META_O,
6591
                        fp8_dtype_forward,
6592
6593
                        qkv_dtype,
                    ).view(out_fp8.shape)
6594
6595
6596
6597
6598
6599

            fp8_tensors = (
                q_fp8,
                k_fp8,
                v_fp8,
                out_fp8,
6600
                fp8_meta["scaling_fwd"].scale.clone(),
6601
6602
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
6603
6604
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd(
6605
6606
6607
6608
6609
6610
6611
6612
6613
6614
6615
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
6616
6617
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
6618
6619
6620
6621
6622
6623
6624
6625
6626
6627
6628
6629
                None,  # d_scale_qkv
                0,  # d_scale_qkv_offset
                None,  # d_scale_s
                0,  # d_scale_s_offset
                None,  # q_scale_s
                0,  # q_scale_s_offset
                None,  # q_scale_o
                0,  # q_scale_o_offset
                None,  # amax_s
                0,  # amax_s_offset
                None,  # amax_o
                0,  # amax_o_offset
6630
6631
6632
6633
6634
6635
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
6636
                window_size,
6637
6638
                rng_gen,
            )
6639
6640
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None, None)
6641

6642
6643
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

6644
        from .cpu_offload import CPUOffloadEnabled
6645

6646
        if CPUOffloadEnabled:
6647
6648
6649
6650
6651
6652
6653
            if ctx.fp8:
                tensor_list = fp8_tensors
            else:
                tensor_list = [q, k, v, out_save]

            tensor_list.extend(aux_ctx_tensors)

6654
            qkv_layout = "sbhd_sbhd_sbhd"
6655
6656
6657
6658
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

6659
6660
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
6661
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
6662
6663
6664
6665
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
6666
6667
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
6668
6669
6670
            *fp8_tensors,
            *aux_ctx_tensors,
        )
6671
        ctx.fp8_meta = fp8_meta
6672
6673
6674
6675
6676
6677
6678
6679
6680
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
6681
        ctx.window_size = window_size
6682
        ctx.fused_attention_backend = (
6683
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
6684
        )
6685
        ctx.use_FAv2_bwd = use_FAv2_bwd
6686
        ctx.deterministic = deterministic
6687

6688
        return out_ret
6689
6690
6691

    @staticmethod
    def backward(ctx, d_out):
        """Backward pass of fused attention.

        Three code paths compute dQ/dK/dV:

        * ``ctx.use_FAv2_bwd``: the FlashAttention-2 backward kernel
          (``flash_attn_cuda_bwd``) using the saved ``softmax_lse``/``rng_state``.
        * ``ctx.fp8``: the cuDNN FP8 backward (``fused_attn_bwd`` on FP8 data),
          returning either ``Float8Tensor`` grads (when the forward inputs were
          ``Float8Tensor``) or grads dequantized to ``ctx.qkv_dtype``.
        * otherwise: the cuDNN F16/BF16 backward (``fused_attn_bwd``).

        Args:
            ctx: autograd context populated by the forward pass.
            d_out: gradient of the attention output. Must be a ``Float8Tensor``
                when the forward produced an FP8 output (``ctx.is_output_fp8``).

        Returns:
            A tuple of gradients aligned positionally with the forward inputs
            (the argument order used by ``FusedAttnFunc.apply`` in
            ``FusedAttention.forward``). Positions 7-9 carry dq/dk/dv. Position 11
            (the attention bias) carries the bias gradient unless
            ``ctx.attn_bias_type`` is ``"no_bias"`` or ``"alibi"``. All other
            entries are None.
        """
        if ctx.is_output_fp8:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data
            # Nominal (high-precision) dtype reported by the FP8 gradient wrapper.
            grad_dtype = d_out_f8tensor.dtype
        else:
            # Without an FP8 output there is no Float8Tensor wrapper to consult.
            # The incoming gradient's own dtype is the nominal dtype. Previously
            # the FP8-input path referenced the unbound `d_out_f8tensor` here.
            grad_dtype = d_out.dtype
        d_out = d_out.contiguous()
        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        # NOTE(review): the backward kernels presumably require the first aux
        # tensor (softmax stats) to be contiguous. Confirm against the kernels.
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        # Extra outputs of fused_attn_bwd beyond dq/dk/dv; rest[0] is returned
        # as the bias gradient.
        rest = [None]
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Trim the last dimension back to the gradient's head size
            # (presumably undoing kernel-side padding; verify).
            dq = dq[..., : d_out.shape[-1]]
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.is_output_fp8:
                        # The gradient is already FP8; reuse its scale_inv for dO.
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        # Quantize dO via a 2D view, then restore its shape.
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

                    if ctx.is_input_fp8:
                        # FP8 inputs get FP8 gradients, wrapped with the dQKV scaling metadata.
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=grad_dtype,
                        )
                        dk = Float8Tensor(
                            data=dk_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=grad_dtype,
                        )
                        dv = Float8Tensor(
                            data=dv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=grad_dtype,
                        )
                    else:
                        # Dequantize to ctx.qkv_dtype. Grads that share storage in the
                        # layout (e.g. "3" or "2" packed) are dequantized together and split.
                        qkv_group = len(ctx.qkv_layout.split("_"))
                        if qkv_group == 1:
                            dim = ctx.qkv_layout.find("3")
                            dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim)
                            dqkv_c_fp8 = dqkv_fp8.view(
                                -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                            )
                            dqkv = cast_from_fp8(
                                dqkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dqkv_fp8.shape)
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1])
                            dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
                        elif qkv_group == 2:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dq_fp8.shape)
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(
                                -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                            )
                            dkv = cast_from_fp8(
                                dkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dkv_fp8.shape)
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1])
                            dk, dv = [x.squeeze(dim) for x in [dk, dv]]
                        elif qkv_group == 3:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dq_fp8.shape)
                            dk = cast_from_fp8(
                                dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dk_fp8.shape)
                            dv = cast_from_fp8(
                                dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dv_fp8.shape)
                else:
                    # A uint8 d_out here can only be the raw data of an FP8
                    # gradient, so dequantize it for the F16/BF16 kernel.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # No bias gradient for "no_bias"/"alibi"; otherwise return rest[0] as dbias.
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        # 27 entries: 7 leading Nones, dq/dk/dv, None, dbias, 15 trailing Nones.
        return (None,) * 7 + (dq, dk, dv) + (None, dbias) + (None,) * 15
6990

6991

6992
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    FP8 attention (``fp8=True`` in :meth:`forward`) additionally requires the
    FP8 cuDNN sub-backend (``tex.NVTE_Fused_Attn_Backend.NVTE_FP8``).

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        # Context-manager factory entered around every kernel launch in forward().
        self.attention_dropout_ctx = attention_dropout_ctx
        # "self" or "cross": selects how cu_seqlens are derived from attention_mask.
        self.attention_type = attention_type
        # Opt-in to the FlashAttention-2 backward kernel, only on compute capability 9.0.
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading Transformer Engine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in Transformer Engine 2.0.
            """
            # Iterate over snapshots. Removing from a list while iterating it
            # skips the element after each removed one, so consecutive matching
            # keys (e.g. one per layer) would otherwise be left behind. The
            # removals still mutate the original lists in place, as the hook
            # contract requires.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Fused attention forward pass.

        Validates the inputs and derives ``qkv_format`` from the first segment
        of ``qkv_layout``. For "sbhd"/"bshd" it infers the batch size and the
        max sequence lengths from tensor shapes, scaled by the context-parallel
        size. It then builds ``cu_seqlens_q/kv`` from ``attention_mask`` for
        padding masks, or full-length ones otherwise. For "thd" the caller must
        supply ``max_seqlen_q/kv`` and ``cu_seqlens_q/kv``.

        Dispatches to ``attn_forward_func_with_cp`` when context parallelism is
        active (combined ``cp_group`` world size > 1), otherwise to
        ``FusedAttnFunc.apply``.

        Returns:
            The attention output with its last two dimensions flattened into one.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        # Total context-parallel size: product of world sizes over all CP groups.
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        # e.g. "sbh3d" -> "sbhd", "bshd_bs2hd" -> "bshd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format in ["sbhd", "bshd"]:
            if qkv_format == "sbhd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1],
                    query_layer.shape[0],
                    key_layer.shape[0],
                )
            if qkv_format == "bshd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0],
                    query_layer.shape[1],
                    key_layer.shape[1],
                )
            # Local shapes hold 1/cp_size of each sequence; scale to the full length.
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size
            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        # Cross attention: attention_mask is a (q_mask, kv_mask) pair.
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )

        if qkv_format == "thd":
            assert (
                max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
            ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"

        # NOTE(review): if only one of the padded cu_seqlens is given, both are
        # replaced by the unpadded ones. Confirm callers always pass both or neither.
        if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
            cu_seqlens_q_padded = cu_seqlens_q
            cu_seqlens_kv_padded = cu_seqlens_kv

        qkv_dtype = TE_DType[query_layer.dtype]

        # The FAv2 backward is only used without bias and on the arbitrary-seqlen backend.
        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if fp8:
            assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                " is required for FP8 attention!"
            )
            assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
            assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
                "Amax reduction across TP+CP group is necessary when using context parallelism with"
                " FP8!"
            )

        if context_parallel:
            assert (
                fp8
                or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    deterministic=self.deterministic,
                    use_fused_attention=True,
                    window_size=window_size,
                    fp8=fp8,
                    fp8_meta=fp8_meta,
                )
        else:
            with self.attention_dropout_ctx():
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_dtype,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                    self.deterministic,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
7258
7259


7260
class DotProductAttention(TransformerEngineBaseModule):
7261
7262
7263
7264
7265
7266
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

7267
        Argument :attr:`attention_mask` in the `forward` call is only used when
7268
        :attr:`attn_mask_type` includes "`padding`" or "`arbitrary`".
7269
7270
7271

    .. warning::

7272
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
7273
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
7274
7275
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
7276

7277
7278
7279
7280
7281
7282
7283
    .. note::

        Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing.
        As the FP8 attention support expands from one backend to multiple backends, the location
        of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).


7284
7285
7286
7287
    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
7288
7289
7290
    kv_channels : Union[int, Tuple[int, int]]
                the head size in key and value tensors. If the same, :attr:`kv_channels` can be
                an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
7291
7292
7293
7294
7295
7296
7297
7298
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
7299
7300
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
7301
    attn_mask_type: str, default = `causal`
7302
                   type of attention mask passed into softmax operation, options are "`no_mask`",
7303
7304
7305
7306
7307
7308
7309
7310
7311
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
                   "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
                   "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
                   are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
                   `forward` method. It is useful for cases involving compilation/tracing, e.g.
                   ONNX export, and the forward arg is useful for dynamically changing mask types,
                   e.g. a different mask for training and inference.
                   1. For "`no_mask`", no attention mask is applied.
                   2. For "`causal`", "`causal_bottom_right`", or the causal mask in
7312
                   "`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
7313
7314
7315
7316
7317
7318
7319
7320
7321
7322
7323
7324
7325
7326
                   calculates and applies an upper triangular mask to the softmax input.
                   No user input is needed. Causal masks without the "`bottom_right`" appendix align
                   the diagonal line to the top left corner of the softmax matrix. With
                   "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
                   often used in inference/KV caching.
                   3. For "`padding`", or the padding mask in "`padding_causal`" and
                   "`padding_causal_bottom_right`", users need to provide the locations of padded
                   tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
                   [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
                   in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
                   cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
                   [batch_size, 1, 1, max_seqlen_kv]).
                   4. For "`arbitrary`", users need to provide a mask that is broadcastable to
                   the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
7327
7328
7329
7330
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
7331
7332
7333
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
7334
                be overridden by :attr:`window_size` in `forward` as well.
7335
7336
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
7337
7338
7339
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
7340
7341
7342
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
7343
               `h` the number of heads, `d` head size, and `t` the total number of tokens
7344
7345
7346
7347
7348
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
7349
               For that, please use `get_qkv_layout` to gain the layout information.
7350
7351
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
7352
                `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
7353
7354
7355
7356
7357
7358
7359
7360
7361

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
7362
    cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None`
7363
              context parallel process group.
7364
7365
7366
              ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
              List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
              and cp_group[1] are for a2a and p2p communications respectively.
7367
7368
7369
7370
7371
7372
7373
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
7374
    cp_comm_type : str, default = `p2p`
7375
                  inter-gpu communication type for context parallelism.
7376
                  Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
7377
7378
7379
7380
7381
7382
                  "p2p": Exchange KV chunks with P2P communications in ring topology.
                         P2P is async and can be overlapped with attention compute.
                  "all_gather": All-gather to get full sequence of KV before attention.
                                The all-gather is not async, and cannot be overlapped.
                  "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                         group, and gather to get full sequence of QKV.
7383
7384
7385
                  "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                  across each CP sub-group (e.g., via NVLink), then exchanging KV with
                  p2p between sub-groups (e.g., via IBLink).
7386
7387
7388
7389
7390
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: Union[int, Tuple[int, int]],
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
    ) -> None:
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.logger.setLevel(_log_level)
        # Attach the shared handler only once so repeated instantiation does not
        # duplicate log lines.
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)
        self.qkv_format = qkv_format

        # Normalize the mask spelling: "padding,causal" -> "padding_causal" and
        # the reversed "causal,padding"/"causal_padding" -> "padding_causal".
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)

        # Tensor parallelism: an explicit group overrides the `tp_size` argument.
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number

        # Context parallelism settings; may later be replaced via
        # `set_context_parallel_group`.
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
        self.cp_comm_type = cp_comm_type

        # Head sizes of K and V; a single int means both are the same.
        self.hidden_size_per_attention_head_k = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[0]
        )
        self.hidden_size_per_attention_head_v = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[1]
        )

        # Without an explicit GQA group count this is plain multi-head attention.
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"

        # Dropout runs under the RNG tracker's fork context when one is supplied
        # (and sequence parallelism is off); otherwise under a no-op context.
        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        # Default scale is 1/sqrt(head size of K).
        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(
                kv_channels if isinstance(kv_channels, int) else kv_channels[0]
            )

        # Deterministic mode is requested either through the environment or
        # through PyTorch's global deterministic-algorithms switch.
        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )

        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = get_cudnn_version()
        if (8, 9, 5) <= cudnn_version < (9, 0, 0):
            if self.deterministic:
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        self.flash_attention = FlashAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        self.fused_attention = FusedAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale,
            attention_type=attention_type,
            **attn_kwargs,
            layer_number=layer_number,
        )

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
            when loading older Transformer Engine checkpoints. Will phase out
            this hook in Transformer Engine 2.0.
            """
            # Filter in place (slice assignment keeps the same list object that
            # the state-dict loader inspects afterwards). Calling `.remove()`
            # while iterating the same list would skip the element after each
            # removal, leaving every other matching key (e.g. one per layer) in
            # `missing_keys` and breaking strict loads.
            incompatible_keys.missing_keys[:] = [
                key
                for key in incompatible_keys.missing_keys
                if "core_attention._extra_state" not in key
            ]

        self.register_load_state_dict_post_hook(remove_extra_states_check)

7534
7535
7536
7537
7538
7539
7540
7541
7542
7543
7544
7545
7546
7547
7548
7549
7550
7551
7552
7553
7554
7555
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load this module's state, redirecting to the legacy FP8 metadata location when needed.

        Transformer Engine 1.6 and 1.7 checkpoints store FP8 attention metadata under the
        `core_attention.fused_attention._extra_state` key instead of the
        `core_attention._extra_state` key. If any key in `state_dict` uses the legacy
        spelling and none uses the current one, the prefix is extended with
        `fused_attention.` before delegating to the parent loader. Please see `FP8 checkpoint
        compatibility
        <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_ for more details.
        """
        has_legacy_key = any(
            "core_attention.fused_attention._extra_state" in name for name in state_dict.keys()
        )
        has_current_key = any("core_attention._extra_state" in name for name in state_dict.keys())

        load_prefix = prefix
        if has_legacy_key and not has_current_key:
            load_prefix = prefix + "fused_attention."

        super()._load_from_state_dict(
            state_dict, load_prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

7556
7557
7558
7559
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Run ``attention_func`` through ``checkpoint`` with this module's RNG tracker and TP group.

        All extra positional and keyword arguments go unchanged to ``attention_func``
        by way of ``checkpoint``. Saved activations are not distributed.
        """

        def _invoke_attention(*call_args, **call_kwargs):
            return attention_func(*call_args, **call_kwargs)

        return checkpoint(
            _invoke_attention,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

7578
7579
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Configure context parallelism for this module; call before the forward pass.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  context parallel process group.
                  A single ProcessGroup is used with cp_comm_type "p2p", "all_gather", or "a2a".
                  A List[ProcessGroup] is used with cp_comm_type "a2a+p2p": cp_group[0]
                  carries the a2a communication and cp_group[1] the p2p communication.
        cp_global_ranks : List[int]
                         global ranks that make up the context parallel group.
        cp_stream : torch.cuda.Stream
                   CUDA stream used for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-GPU communication type for context parallelism:
                      "p2p", "all_gather", "a2a", or "a2a+p2p".
                      "p2p": KV chunks are exchanged with P2P communication in a ring
                             topology; this is async and can overlap attention compute.
                      "all_gather": the full KV sequence is all-gathered before attention;
                                    this is not async and cannot be overlapped.
                      "a2a": DeepSpeed Ulysses style; attention heads are scattered
                             across the CP group and the full QKV sequence is gathered.
                      "a2a+p2p": hierarchical CP; a2a on QKV inside each CP sub-group
                                 (e.g., over NVLink), then p2p KV exchange between
                                 sub-groups (e.g., over IBLink).
        """
        (
            self.cp_group,
            self.cp_global_ranks,
            self.cp_stream,
            self.cp_comm_type,
        ) = (cp_group, cp_global_ranks, cp_stream, cp_comm_type)
7617

7618
    @no_torch_dynamo(recursive=False)
7619
7620
7621
7622
7623
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
7624
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
7625
7626
7627
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
7628
7629
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
7630
7631
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
7632
        attn_mask_type: Optional[str] = None,
7633
        window_size: Optional[Tuple[int, int]] = None,
7634
        checkpoint_core_attention: bool = False,
7635
7636
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
7637
        alibi_slopes: Optional[torch.Tensor] = None,
7638
        fast_zero_fill: bool = True,
7639
        inference_params: Optional[InferenceParams] = None,
7640
        is_first_microbatch: Optional[bool] = None,
7641
7642
7643
7644
7645
7646
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

7647
7648
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes "`padding`" or "`arbitrary`".
7649

7650
7651
        .. note::

7652
7653
7654
7655
7656
7657
7658
7659
7660
7661
7662
7663
7664
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://github.com/Dao-AILab/flash-attention>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
7665
            and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
7666
7667
7668
7669
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
7670
7671
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
7672
            optimizations in FusedAttention. When unset, Transformer Engine determines the code path
7673
7674
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
7675

7676
7677
7678
7679
7680
7681
7682
7683
7684
7685
7686
7687
7688
7689
7690
7691
7692
7693
7694
7695
7696
7697
7698
7699
7700
7701
7702
7703
7704
7705
7706
7707
7708
7709
7710
7711
7712
7713
7714
7715
7716
7717
7718
7719
7720
7721
7722
7723
7724
7725
7726
7727
7728
7729
        .. note::
            .. _cu_seqlens note:

            When training data has variable sequence lengths, users have two options.

            1. Manipulate the data and pad all sequences to the same length. Use
               :attr:`qkv_format` = {"bshd", "sbhd"} and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`
               (which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide
               the real sequence length information. For example, a batch of 3 sequences
               [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative
               sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

            2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`,
               as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed
               without any padding, and the sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

               In certain use cases, a varying number of identifier tokens are inserted between
               sequences. These tokens do not participate in the attention calculation.
               :attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified
               in such cases to correctly identify the start and end of each sequence in a batch.
               For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and
               :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13]
               for self-attention.

        .. note::
            .. _max_seqlen note:

            When :attr:`qkv_format` = {"bshd", "sbhd"}, sequences are of equal length in a batch.
            :attr:`max_seqlen_q` and :attr:`max_seqlen_kv` should be the same as the "s" dimension of
            :attr:`query_layer` and :attr:`key_layer` tensors. When unset, Transformer Engine will
            infer them as such.

            When :attr:`qkv_format` = "thd", sequences have varying lengths. :attr:`max_seqlen_q` and
            :attr:`max_seqlen_kv` should be the maximum query and key/value sequence length in a batch.
            When unset, Transformer Engine deduces them from :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`.
            This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this
            overhead, users are recommended to obtain the maximum sequence lengths from the data loaders
            and pass them in.

            - As the maximum sequence lengths, batch size, and number of tokens change from batch to batch,
              dynamic shapes need to be supported for tensor construction. FlashAttention and
              UnfusedDotProductAttention naturally do so, while FusedAttention requires parameters to be static
              to create graphs before performance heuristics analysis. To reduce the number of graphs created
              per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size,
              :attr:`max_seqlen_q`, :attr:`max_seqlen_kv`}, and for cuDNN >= 9.6, {"t" dimension of
              :attr:`query_layer`, "t" dimension of :attr:`key_layer`}.

7730
7731
7732
7733
7734
7735
7736
7737
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
7738
7739
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
7740
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
7741
7742
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
7743
7744
7745
7746
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
7747
7748
7749
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
7750
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
7751
                   with shape [batch_size + 1] and dtype torch.int32.
7752
                   See :ref:`note<cu_seqlens note>` for more details.
7753
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
7754
7755
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
7756
                   See :ref:`note<cu_seqlens note>` for more details.
7757
7758
7759
7760
7761
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for
                   `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_q_padded = cu_seqlens_q`.
7762
                   See :ref:`note<cu_seqlens note>` for more details.
7763
7764
7765
7766
7767
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_kv_padded = cu_seqlens_kv`.
7768
                   See :ref:`note<cu_seqlens note>` for more details.
7769
7770
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
7771
                      See :ref:`note<max_seqlen note>` for more details.
7772
7773
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
7774
                       See :ref:`note<max_seqlen note>` for more details.
7775
7776
7777
7778
7779
7780
7781
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
                       'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
                       'arbitrary'}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
                       are equivalent. By default, causal masks are aligned to the top left corner
                       of the softmax matrix. When "`bottom_right`" is specified in the mask type,
                       causal masks are aligned to the bottom right corner.
7782
        window_size: Optional[Tuple[int, int]], default = `None`
7783
                    Sliding window size for local attention.
7784
7785
7786
7787
7788
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
7789
        core_attention_bias_type: str, default = `no_bias`
7790
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
7791
        core_attention_bias: Optional[torch.Tensor], default = `None`
7792
7793
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
7794
7795
7796
7797
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
7798
        fast_zero_fill: bool, default = `True`
7799
                    Whether to use the fast path to set output tensors to 0 or not.
7800
7801
7802
7803
7804
7805
7806
7807
7808
7809
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
7810
7811
7812
7813
7814
7815
7816
7817
7818
7819
7820
7821
7822
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
7823
        """
7824
7825
7826
7827
7828
7829
7830
7831
7832
7833
7834
        with self.prepare_forward(
            query_layer,
            is_first_microbatch,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:

            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
7835
                        self.logger.warning(
7836
7837
7838
                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                            """fp8_meta["recipe"].fp8_mha=True"""
                        )
7839
7840
7841
7842
7843
7844
7845
7846
7847
7848
7849

            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
7850

7851
7852
7853
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
7854
7855
7856
            assert (
                query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
7857
7858
7859
            assert (
                key_layer.shape[:-1] == value_layer.shape[:-1]
            ), "Keys and values must have the same batch size, sequence length and number of heads!"
7860
7861
7862
7863
7864
7865
7866
7867
            assert (
                key_layer.shape[-1] == self.hidden_size_per_attention_head_k
            ), f"Keys have head_dim = {key_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_k}!"
            assert (
                value_layer.shape[-1] == self.hidden_size_per_attention_head_v
            ), f"Values have head_dim = {value_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_v}!"
7868

7869
7870
7871
7872
7873
7874
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
7875
            assert (
7876
7877
7878
7879
7880
7881
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
7882

7883
7884
7885
7886
            if window_size is None:
                window_size = self.window_size
            window_size = check_set_window_size(attn_mask_type, window_size)

7887
7888
7889
7890
7891
7892
7893
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
7894

7895
7896
            if qkv_format is None:
                qkv_format = self.qkv_format
7897

7898
7899
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
7900

7901
7902
7903
7904
7905
                # convert causal to causal_bottom_right in inference when KV-caching is in use
                # so users can run with the same attn_mask_type for training and inference
                if attn_mask_type in ["causal", "padding_causal"]:
                    attn_mask_type = attn_mask_type + "_bottom_right"

7906
7907
7908
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
7909

7910
7911
7912
7913
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]
7914

7915
7916
7917
                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)
7918

7919
7920
7921
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)
7922

7923
7924
7925
7926
7927
7928
7929
7930
7931
                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
7932

7933
7934
7935
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
7936

7937
7938
                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()
7939
7940

            assert (
7941
7942
7943
7944
7945
7946
7947
7948
7949
7950
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
7951
                assert all(
7952
7953
7954
7955
7956
7957
7958
7959
7960
7961
7962
7963
7964
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
7965
                batch_size = len(cu_seqlens_q) - 1
7966
                if max_seqlen_q is None:
7967
7968
7969
7970
                    if cu_seqlens_q_padded is not None:
                        seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
                    else:
                        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
7971
                    max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
7972
                if max_seqlen_kv is None:
7973
7974
7975
7976
                    if cu_seqlens_kv_padded is not None:
                        seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
                    else:
                        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
7977
                    max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
7978

7979
7980
7981
7982
7983
7984
            cp_size = 1
            if isinstance(self.cp_group, dist_group_type):
                cp_size = get_distributed_world_size(self.cp_group)
            elif isinstance(self.cp_group, list):
                for group in self.cp_group:
                    cp_size *= get_distributed_world_size(group)
7985
7986
            context_parallel = cp_size > 1

7987
            if qkv_format in ["sbhd", "bshd"]:
7988
                assert all(
7989
7990
7991
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
7992
7993
                    max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv
7994
                    batch_size = query_layer.shape[1]
7995
                else:
7996
7997
                    max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv
7998
                    batch_size = query_layer.shape[0]
7999
8000
                max_seqlen_q *= cp_size
                max_seqlen_kv *= cp_size
8001
8002
8003
8004
8005
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
8006
                        the sequence dimension in 'query_layer'!"""
8007
8008
8009
8010
8011
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
8012
                        the sequence dimension in 'key_layer' and 'value_layer'!"""
8013
8014
8015
8016
8017
                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
8018
                        if self.attention_type == "self":
8019
8020
8021
8022
8023
8024
8025
8026
8027
8028
8029
8030
8031
8032
8033
8034
                            cu_seqlens_q = get_cu_seqlens(attention_mask)
                            cu_seqlens_kv = cu_seqlens_q
                        else:
                            cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                            cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
                    else:
                        cu_seqlens_q = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
                        cu_seqlens_kv = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
8035

8036
8037
8038
8039
8040
            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
8041
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
8042
8043
8044
                    query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                )
            else:
8045
                qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
8046
8047
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )
8048

8049
8050
8051
8052
8053
8054
8055
8056
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
8057
            bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
8058
8059
8060
8061
8062
8063
8064
8065
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
8066
                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
8067
8068
8069
8070
8071
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

8072
8073
            core_attention_bias_shape = None
            if core_attention_bias is not None:
8074
                if (
8075
8076
                    core_attention_bias.shape[0] == batch_size
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
8077
                ):
8078
8079
8080
8081
8082
8083
8084
8085
8086
8087
8088
8089
8090
8091
8092
8093
8094
8095
8096
8097
8098
8099
8100
8101
                    core_attention_bias_shape = "bhss"
                elif (
                    core_attention_bias.shape[0] == 1
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
                ):
                    core_attention_bias_shape = "1hss"
                elif (
                    core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
                ):
                    core_attention_bias_shape = "b1ss"
                elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
                    core_attention_bias_shape = "11ss"
                else:
                    assert (
                        False
                    ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

            pad_between_seqs = (
                cu_seqlens_q_padded is not None
                and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
            ) or (
                cu_seqlens_kv_padded is not None
                and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
            )
8102

8103
            attention_params = AttentionParams(
8104
8105
8106
8107
8108
8109
8110
8111
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
                num_heads=query_layer.shape[-2],
                num_gqa_groups=key_layer.shape[-2],
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
8112
8113
                head_dim_qk=query_layer.shape[-1],
                head_dim_v=value_layer.shape[-1],
8114
8115
8116
8117
8118
8119
8120
8121
8122
8123
8124
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias_shape=core_attention_bias_shape,
                core_attention_bias_requires_grad=(
                    core_attention_bias.requires_grad if core_attention_bias is not None else False
                ),
                pad_between_seqs=pad_between_seqs,
                attention_dropout=self.attention_dropout,
                context_parallel=context_parallel,
8125
8126
                deterministic=self.deterministic,
                is_training=self.training,
8127
8128
8129
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
            )
8130
            global _attention_backends, _use_flash_attn_3
8131
8132
8133
8134
8135
8136
8137
            if (
                _attention_backends["attention_params"] is None
                or attention_params != _attention_backends["attention_params"]
            ):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
8138
                _use_flash_attn_3 = _flash_attn_3_is_installed
8139
8140
8141
8142
8143
8144
8145
8146
                (
                    use_flash_attention,
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
                    _,
                ) = get_attention_backend(attention_params)
                if use_flash_attention:
8147
8148
                    self.logger.info(
                        "Running with FlashAttention backend (version %s)",
8149
                        _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version,
8150
                    )
8151
8152
8153
8154
                elif use_fused_attention:
                    self.logger.info(
                        "Running with FusedAttention backend (sub-backend %s)",
                        int(fused_attention_backend),
8155
                    )
8156
8157
8158
8159
8160
8161
8162
                elif use_unfused_attention:
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]
8163

8164
8165
8166
8167
8168
8169
8170
8171
8172
8173
8174
8175
8176
8177
8178
8179
8180
8181
8182
8183
8184
8185
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
8186
                    cp_comm_type=self.cp_comm_type,
8187
8188
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
8189
8190
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
8191
                )
8192

8193
            if use_fused_attention:
8194
8195
                fu_core_attention_bias_type = core_attention_bias_type
                fu_core_attention_bias = core_attention_bias
8196
8197
8198
                if core_attention_bias_type == "alibi" and (
                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
                ):
8199
8200
8201
8202
8203
8204
8205
                    fu_core_attention_bias_type = "post_scale_bias"
                    _, fu_core_attention_bias = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                        bias_dtype=query_layer.dtype,
8206
                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
8207
                    )
8208
8209
8210
8211
8212
8213
8214
8215
8216
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
8217
8218
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
8219
8220
8221
8222
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
8223
                        window_size=window_size,
8224
8225
8226
8227
8228
8229
8230
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
8231
                        cp_comm_type=self.cp_comm_type,
8232
8233
8234
8235
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
                    )
                return self.fused_attention(
8236
8237
8238
8239
8240
8241
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
8242
8243
                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
8244
8245
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
8246
8247
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
8248
                    window_size=window_size,
8249
                    fused_attention_backend=fused_attention_backend,
8250
8251
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
8252
8253
8254
8255
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
8256
                    cp_comm_type=self.cp_comm_type,
8257
8258
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
8259
                )
8260

8261
            from .cpu_offload import CPUOffloadEnabled
8262

8263
8264
8265
8266
8267
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
8268

8269
            if use_unfused_attention:
8270
8271
8272
8273
8274
8275
                if window_size is not None and (
                    window_size[0] != -1 or window_size[1] not in [-1, 0]
                ):
                    attn_mask_type, attention_mask = get_swa_mask(
                        window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
                    )
8276
8277
8278
8279
8280
8281
8282
8283
8284
8285
8286
8287
8288
8289
8290
8291
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
8292
8293
8294
                    query_layer,
                    key_layer,
                    value_layer,
8295
8296
8297
8298
8299
8300
8301
8302
8303
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )
8304

8305
            raise ValueError("No dot product attention support for the provided inputs!")
8306
8307


8308
8309
8310
8311
8312
8313
8314
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                   'padding_causal_bottom_right', 'arbitrary'},
                   default = `causal`
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
                be overridden by :attr:`window_size` in `forward` as well.
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
            For that, please use `get_qkv_layout` to gain the layout information.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """Build the QKV projection(s), the core attention module and the output projection.

        See the class docstring for a description of every argument.

        Raises
        ------
        AssertionError
            if `attention_type` is not in `AttnTypes`, `layer_number` is not positive,
            `num_attention_heads` is not divisible by the number of GQA groups, or the
            number of GQA groups is not divisible by the tensor-parallel size.
        """
        super().__init__()

        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        # Validate / normalize the window against the mask type once at construction.
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

        # Per-head size; a falsy value (None or 0) falls back to hidden_size / num_heads.
        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Interleaved QKV weights only make sense for a single fused QKV parameter.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # An explicit tp_group takes precedence: its world size overrides the tp_size argument.
        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        # Sequence parallelism is only active with more than one TP rank. Note that the
        # submodules below receive the raw `sequence_parallel` flag.
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
        # GQA: K/V use num_gqa_groups heads; defaults to one group per head (plain MHA).
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)

        # Projection widths: queries use all heads, keys/values one head per GQA group.
        self.hidden_size_per_attention_head = kv_channels
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

        # Arguments shared by every GEMM submodule (and partly by core attention).
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        # NOTE: submodules are registered in the order QKV -> core attention -> proj; this
        # order determines parameter iteration order (optimizer state, checkpoints).
        if self.attention_type == "self":
            parameters_split = None
            if not fuse_qkv_params:
                # Named chunks of the fused output dimension; how they are exposed is
                # decided by LayerNormLinear / Linear.
                parameters_split = collections.OrderedDict(
                    [
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    ]
                )
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=parameters_split,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=parameters_split,
                    **common_gemm_kwargs,
                )
        elif self.attention_type == "cross":
            # Cross attention: queries come from hidden_states, keys/values from a separate
            # projection (fed with encoder output in `forward`).
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query",) if not fuse_qkv_params else None,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    self.hidden_size_q,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key", "value") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            self.hidden_size_per_attention_head,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # Linear output projection (row-parallel when set_parallel_mode is enabled).
        self.proj = Linear(
            self.hidden_size_q,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            ub_name="proj",
            **common_gemm_kwargs,
        )

    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Allocates memory for KV cache.

        Parameters
        ----------
        inference_max_sequence_len : int
            number of sequence positions to reserve.
        batch_size : int
            number of sequences to reserve.
        dtype : torch.dtype
            element type of the cache.

        Returns
        -------
        torch.Tensor
            an uninitialized tensor (``torch.empty``) of shape
            ``[inference_max_sequence_len, batch_size, num_gqa_groups_per_partition,
            hidden_size_per_attention_head]`` on the current CUDA device. The contents are
            garbage until written by the caller.
        """
        # K/V caches hold one head per local GQA group, not one per query head.
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        The group is stored on this module and also forwarded to every submodule that
        exposes its own ``set_tensor_parallel_group`` (the QKV/output projections and the
        core attention module). Those submodules were constructed with the init-time
        ``tp_group``, so they must receive the new group as well.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        # Deep iterate but skip self (index 0 of `self.modules()`) to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        This module stores nothing itself; the arguments are forwarded unchanged to every
        submodule that exposes its own ``set_context_parallel_group``.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  context parallel process group.
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-gpu communication type for context parallelism.
                      Can be "p2p", "all_gather", "a2a", or "a2a+p2p".
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
        """
        # Deep iterate but skip self (index 0 of `self.modules()`) to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)

    def forward(
        self,
        hidden_states: torch.Tensor,
8726
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
8727
        encoder_output: Optional[torch.Tensor] = None,
8728
        attn_mask_type: Optional[str] = None,
8729
        window_size: Optional[Tuple[int, int]] = None,
8730
8731
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
8732
        inference_params: Optional[InferenceParams] = None,
8733
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
8734
8735
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
8736
        alibi_slopes: Optional[torch.Tensor] = None,
8737
8738
8739
8740
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
8741
        fast_zero_fill: bool = True,
8742
    ) -> Tuple[Union[torch.Tensor, None], ...]:
8743
8744
8745
8746
8747
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

8748
8749
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
8750
8751
8752
8753
8754

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
8755
8756
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
8757
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
8758
8759
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
8760
8761
8762
8763
8764
8765
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                       'padding_causal_bottom_right','arbitrary'},
8766
                       default = `None`
8767
8768
8769
8770
                       type of attention mask passed into softmax operation. By default,
                       causal masks are aligned to the top left corner of the softmax matrix.
                       When "`bottom_right`" is specified in the mask type, causal masks are
                       aligned to the bottom right corner.
8771
8772
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
8773
8774
8775
8776
8777
8778
8779
8780
8781
8782
8783
8784
8785
8786
8787
8788
8789
8790
8791
8792
8793
8794
8795
8796
8797
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
8798
                    Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`, `alibi`}
8799
        core_attention_bias: Optional[torch.Tensor], default = `None`
8800
8801
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
8802
8803
8804
8805
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
8806
8807
8808
8809
8810
8811
8812
8813
8814
8815
8816
8817
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
8818
8819
8820
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        """
8821
8822
        # hidden_states: [sq, b, h]

8823
        if attn_mask_type is None:
8824
            attn_mask_type = self.attn_mask_type
8825
8826
        if window_size is None:
            window_size = self.window_size
8827
        window_size = check_set_window_size(attn_mask_type, window_size)
8828

8829
        if "padding" in attn_mask_type and attention_mask is not None:
8830
8831
            for mask in attention_mask:
                assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"
8832

8833
8834
8835
        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
8836

8837
        # =================================================
8838
        # Pre-allocate memory for key-values for inference
8839
8840
8841
        # =================================================

        if inference_params and self.layer_number is not None:
8842
8843
8844
            assert (
                self.qkv_format != "thd"
            ), "qkv_format == thd is not supported for an inference with KV-cache!"
8845
            if self.layer_number not in inference_params.key_value_memory_dict:
8846
                inf_max_seq_len = inference_params.max_sequence_length
8847
8848
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
8849
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
8850
8851
                )
                inference_value_memory = self._allocate_memory(
8852
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
8853
8854
8855
8856
8857
8858
8859
8860
8861
8862
8863
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

8864
        # ======================
8865
        # Query, Key, and Value
8866
        # ======================
8867

8868
8869
8870
8871
8872
        fp8_mha = (
            FP8GlobalStateManager.is_fp8_enabled()
            and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
        )

8873
        layernorm_output = None
cyanguwa's avatar
cyanguwa committed
8874
8875
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
8876
8877
8878
8879
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8880
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8881
8882
8883
8884
8885
8886
8887
8888
8889
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8890
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8891
8892
                )

8893
8894
8895
            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
8896
            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
8897
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
8898
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
cyanguwa's avatar
cyanguwa committed
8899
8900
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
8901
8902
8903
8904
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
cyanguwa's avatar
cyanguwa committed
8905
8906
8907
8908
8909
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
8910
                    self.hidden_size_per_attention_head,
cyanguwa's avatar
cyanguwa committed
8911
8912
8913
                )
                # split along third last dimension
                split_dim = -3
8914
8915
8916

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
8917
8918
8919
8920
8921
8922
8923
8924
8925
            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
8926
                )
8927
            else:
cyanguwa's avatar
cyanguwa committed
8928
                query_layer, key_layer, value_layer = torch.split(
8929
8930
8931
8932
                    mixed_x_layer,
                    (num_queries_per_key_value, 1, 1),
                    dim=split_dim,
                )
cyanguwa's avatar
cyanguwa committed
8933

8934
8935
8936
8937
8938
8939
8940
8941
8942
8943
8944
8945
            if self.qkv_format == "thd":
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
            else:
                # query: -> [sq, b, np, hn]
                # key, value: -> [sq, b, ng, hn]
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
cyanguwa's avatar
cyanguwa committed
8946
8947
        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
8948
            mixed_kv_layer = self.key_value(
cyanguwa's avatar
cyanguwa committed
8949
                encoder_output,
8950
                is_first_microbatch=is_first_microbatch,
8951
                fp8_output=fp8_mha and rotary_pos_emb is None,
8952
8953
8954
            )

            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
8955
                # [sq, b, (ng * 2 * hn)] --> [sq, b, ng, 2 * hn]
8956
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
8957
                    self.num_gqa_groups_per_partition,
8958
8959
8960
8961
8962
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
cyanguwa's avatar
cyanguwa committed
8963
                # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn]
8964
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
8965
                    2 * self.num_gqa_groups_per_partition,
8966
8967
8968
8969
8970
8971
8972
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
8973
8974
8975
            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
8976
8977
8978
                    mixed_kv_layer,
                    split_dim,
                    mixed_kv_layer.shape[split_dim] // 2,
cyanguwa's avatar
cyanguwa committed
8979
                )
8980
            else:
cyanguwa's avatar
cyanguwa committed
8981
                key_layer, value_layer = torch.split(
8982
8983
8984
                    mixed_kv_layer,
                    mixed_kv_layer.shape[split_dim] // 2,
                    dim=split_dim,
cyanguwa's avatar
cyanguwa committed
8985
                )
8986
8987
8988
8989
8990
8991
8992
8993
8994
            key_layer, value_layer = (
                x.reshape(
                    x.size(0),
                    x.size(1),
                    -1,
                    self.hidden_size_per_attention_head,
                )
                for x in (key_layer, value_layer)
            )
8995
8996
8997
8998
8999
9000

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
9001
                    fp8_output=fp8_mha and rotary_pos_emb is None,
9002
9003
9004
9005
9006
9007
9008
9009
9010
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
9011
                    fp8_output=fp8_mha and rotary_pos_emb is None,
9012
9013
9014
9015
9016
9017
9018
9019
9020
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

9021
9022
9023
        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================
9024

9025
        if rotary_pos_emb is not None:
9026
9027
9028
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
9029
            # duplicate the pos_emb for self attention
9030
            if not isinstance(rotary_pos_emb, tuple):
9031
                rotary_pos_emb = (rotary_pos_emb,) * 2
9032
9033

            q_pos_emb, k_pos_emb = rotary_pos_emb
9034
9035
9036
9037
9038
9039
9040

            # adjust key and value for inference
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)
9041
9042
                else:
                    raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.")
9043
9044
9045
9046
9047
9048
9049

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

9050
9051
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
9052

9053
9054
9055
9056
        # ===========================
        # Core attention computation
        # ===========================

9057
9058
9059
9060
        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
9061
            qkv_format=self.qkv_format,
9062
9063
9064
9065
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
9066
9067
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
9068
            window_size=window_size,
9069
9070
9071
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
9072
            alibi_slopes=alibi_slopes,
9073
            fast_zero_fill=fast_zero_fill,
9074
            inference_params=inference_params,
9075
9076
        )

9077
        # ===================
9078
        # Output. [sq, b, h]
9079
        # ===================
9080

9081
        projection_output = self.proj(
9082
9083
            context_layer,
            is_first_microbatch=is_first_microbatch,
9084
9085
        )

9086
9087
9088
9089
9090
9091
9092
9093
        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
9094
        if self.input_layernorm and self.return_layernorm_output:
9095
9096
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]