attention.py 382 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
from importlib.metadata import PackageNotFoundError
10
import math
11
import os
12
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
import warnings
14
import logging
15

16
from dataclasses import dataclass, fields
cyanguwa's avatar
cyanguwa committed
17
import numpy as np
18
from packaging.version import Version as PkgVersion
19
20

import torch
21
import torch.nn.functional as F
22

23
import transformer_engine_torch as tex
24
25
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
26
27
28
29
from transformer_engine.pytorch.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
)
30
31
32
33
34
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
35
36
    fused_attn_fwd,
    fused_attn_bwd,
37
38
39
40
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
41
42
43
44
45
46
47
48
49
50
51
52
53
    META_QKV,
    META_DQKV,
    META_O,
    META_DO,
    META_S,
    META_DP,
    META_O_CP,
    META_DQKV_CP,
)
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    get_fp8_te_dtype,
    get_fp8_torch_dtype,
54
)
55
from transformer_engine.pytorch.float8_tensor import Float8Tensor
56
from transformer_engine.pytorch.module import LayerNormLinear, Linear
57
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
58
59
60
61
62
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
63
    get_default_init_method,
64
65
66
67
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
68
    AttnBiasTypes,
69
    QKVLayouts,
70
    dist_group_type,
71
    TE_DType,
72
73
74
75
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
76
    get_distributed_rank,
77
    checkpoint,
78
79
80
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
81
82
    gather_along_first_dim,
    reduce_scatter_along_first_dim,
83
84
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
85
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
86
87
from transformer_engine.pytorch.graph import is_graph_capturing

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
# The effective verbosity is non-zero only when NVTE_DEBUG is enabled; any product outside
# {0, 1, 2} selects the most verbose level (DEBUG).
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

# Backend enable switches (1 = allowed, 0 = disabled). get_attention_backend() re-reads
# these from the environment on every call.
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))

# flash-attn (v2) version gates. Note: get_pkg_version raises PackageNotFoundError when
# the flash-attn package is not installed, so importing this module requires it.
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.6.3")
_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")

# flash-attn v3 (flashattn-hopper) state. These defaults describe "not installed" and are
# overwritten below when the package is found. _flash_attn_3_0_0_beta is initialized here
# so that it is always bound, even when flashattn-hopper is missing.
_flash_attn_3_plus = False
_flash_attn_3_0_0_beta = False
_use_flash_attn_3 = False

_flash_attn_3_installation_steps = """\
(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""

try:
    _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
    _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.9")
    _flash_attn_3_0_0_beta = _flash_attn_3_plus and _flash_attn_v3_version < PkgVersion("3.0.0")
except PackageNotFoundError:
    # On sm90 with FlashAttention enabled, point the user at the FA3 install steps (debug only).
    if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
        # Use a named logger rather than the root logger, so that importing this module
        # does not change the application's global logging level or install handlers on root.
        fa3_logger = logging.getLogger("DotProductAttention")
        fa3_logger.setLevel(_log_level)
        if not fa3_logger.hasHandlers():
            fa3_logger.addHandler(_stream_handler)
        fa3_logger.debug(
            "To use flash-attn v3, please follow these steps to install the flashattn-hopper "
            "package: \n%s",
            _flash_attn_3_installation_steps,
        )
else:
    from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flashattn_hopper.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )

    _use_flash_attn_3 = True

# flash-attn v2 kernels are only imported when the installed version meets the minimum.
if _flash_attn_version >= _flash_attn_version_required:
    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd

# Module-level record of the most recent backend-selection state; all entries start unset.
_attention_backends = {
    "attention_params": None,
    "use_flash_attention": None,
    "use_fused_attention": None,
    "fused_attention_backend": None,
    "use_unfused_attention": None,
    "backend_selection_requires_update": False,
}


@dataclass(eq=True)
class AttentionParams:
    """
    Attention parameters used to determine which backend to use.

    Instances compare by value (the dataclass is declared with ``eq=True``), so two
    configurations with identical fields are equal.

    Parameters
    ----------
    qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
        Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
    qkv_dtype: torch.dtype, default = `torch.bfloat16`
        Data type of query/key/value tensors.
    qkv_layout: str, default = "sbh3d"
        Query/key/value tensor memory layout.
    batch_size: int, default = 1
        Batch size.
    num_heads: int, default = 16
        Number of attention heads in the query tensor.
    num_gqa_groups: int, default = 16
        Number of attention heads in key and value tensors.
    max_seqlen_q: int, default = 128
        Maximum sequence length of the query tensor.
    max_seqlen_kv: int, default = 128
        Maximum sequence length of the key and value tensors.
    head_dim_qk: int, default = 64
        The size of each attention head in query and key tensors.
    head_dim_v: int, default = 64
        The size of each attention head in the value tensor.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
        `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
    window_size: Optional[Tuple[int, int]], default = `None`
        Sliding window attention size.
    alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
        Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
    core_attention_bias_type: str, default = `no_bias`
        Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
    core_attention_bias_shape: str, default = `1hss`
        Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
    core_attention_bias_requires_grad: bool, default = `True`
        Whether attention bias requires gradient.
    pad_between_seqs: bool, default = `False`
        Whether there is padding between sequences in a batch.
        This only applies to `qkv_format=thd`.
    attention_dropout: float, default = 0.0
        Attention dropout.
    context_parallel: bool, default = `False`
        Whether context parallelism is used or not.
    deterministic: bool, default = `False`
        Whether to run `DotProductAttention` with determinism or not.
    is_training: bool, default = `True`
        Whether in training mode (`True`) or inference mode (`False`).
    fp8: bool, default = `False`
        Whether `DotProductAttention` is in an `fp8_autocast` region.
    fp8_meta: Optional[Dict[str, Any]], default = `None`
        The FP8 metadata dictionary of `DotProductAttention`.
    """

    qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
    qkv_dtype: torch.dtype = torch.bfloat16
    qkv_layout: str = "sbh3d"
    batch_size: int = 1
    num_heads: int = 16
    num_gqa_groups: int = 16
    max_seqlen_q: int = 128
    max_seqlen_kv: int = 128
    head_dim_qk: int = 64
    head_dim_v: int = 64
    attn_mask_type: str = "no_mask"
    window_size: Union[Tuple[int, int], None] = None
    alibi_slopes_shape: Union[torch.Size, List, None] = None
    core_attention_bias_type: str = "no_bias"
    core_attention_bias_shape: str = "1hss"
    core_attention_bias_requires_grad: bool = True
    pad_between_seqs: bool = False
    attention_dropout: float = 0.0
    context_parallel: bool = False
    deterministic: bool = False
    is_training: bool = True
    fp8: bool = False
    fp8_meta: Union[Dict[str, Any], None] = None


# Module-level ALiBi state. Judging by the key names, it holds the slopes/bias tensors and
# the parameters (head count, sequence lengths, alignment) they were built for, plus
# "needs update" flags. All tensor entries start unset.
_alibi_cache = dict(
    _num_heads=None,
    _alibi_slopes=None,
    _max_seqlen_q=None,
    _max_seqlen_kv=None,
    _bottom_right_alignment=True,
    _alibi_bias=None,
    _alibi_slopes_require_update=False,
    _alibi_bias_require_update=False,
)


# Public names exported by `from ... import *`.
__all__ = [
    "DotProductAttention",
    "InferenceParams",
    "MultiheadAttention",
]


def get_attention_backend(
    attention_params: AttentionParams = None,
):
    """
    Select the appropriate attention backend/sub-backend based on user input and runtime environment.

    Parameters
    ----------
    See `AttentionParams`.
263
264
265
266
267
268
269

    Returns
    ----------
    use_flash_attention: bool
        Whether the `FlashAttention` backend has been selected.
    use_fused_attention: bool
        Whether the `FusedAttention` backend has been selected.
270
271
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend
        If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
272
273
274
275
276
277
    use_unfused_attention: bool
        Whether the `UnfusedDotProductAttention` backend has been selected.
    available_backends: List[bool]
        All available backends that could support the provided input. A list of Booleans
        in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
    """
278
279
280
281
282
283
284
285
    qkv_type = attention_params.qkv_type
    qkv_dtype = attention_params.qkv_dtype
    qkv_layout = attention_params.qkv_layout
    batch_size = attention_params.batch_size
    num_heads = attention_params.num_heads
    num_gqa_groups = attention_params.num_gqa_groups
    max_seqlen_q = attention_params.max_seqlen_q
    max_seqlen_kv = attention_params.max_seqlen_kv
286
287
    head_dim_qk = attention_params.head_dim_qk
    head_dim_v = attention_params.head_dim_v
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
    attn_mask_type = attention_params.attn_mask_type
    window_size = attention_params.window_size
    alibi_slopes_shape = attention_params.alibi_slopes_shape
    core_attention_bias_type = attention_params.core_attention_bias_type
    core_attention_bias_shape = attention_params.core_attention_bias_shape
    core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
    pad_between_seqs = attention_params.pad_between_seqs
    attention_dropout = attention_params.attention_dropout
    context_parallel = attention_params.context_parallel
    deterministic = attention_params.deterministic
    is_training = attention_params.is_training
    fp8 = attention_params.fp8
    fp8_meta = attention_params.fp8_meta

    # Run config
303
    logger = logging.getLogger("DotProductAttention")
304
305
306
    logger.setLevel(_log_level)
    if not logger.hasHandlers():
        logger.addHandler(_stream_handler)
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
    device_compute_capability = get_device_compute_capability()
    cudnn_version = get_cudnn_version()
    run_config = {
        "transformer_engine_version": te.__version__,
        "compute_capability": "sm"
        + str(
            (lambda x, y: x * 10 + y)(device_compute_capability[0], device_compute_capability[1])
        ),
        "flash_attn_version": _flash_attn_version,
        "cudnn_version": ".".join([str(i) for i in cudnn_version]),
    }
    attention_params_dict = {
        field.name: getattr(attention_params, field.name) for field in fields(attention_params)
    }
    run_config.update(attention_params_dict)
    if fp8:
        run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
    logger.debug("Running with config=%s", run_config)
325
326

    # Filter: Environment variables
327
328
329
330
331
332
333
    global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN
    _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
    _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
    _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
    use_flash_attention = _NVTE_FLASH_ATTN
    use_fused_attention = _NVTE_FUSED_ATTN
    use_unfused_attention = _NVTE_UNFUSED_ATTN
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
    if not use_flash_attention:
        logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
    if not use_fused_attention:
        logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
    if not use_unfused_attention:
        logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")

    # Filter: ONNX mode
    if is_in_onnx_export_mode():
        if use_flash_attention:
            logger.debug("Disabling FlashAttention due to ONNX mode")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention due to ONNX mode")
        use_fused_attention = False

    # Filter: Compute capability
351
    global _use_flash_attn_3
352
353
354
355
356
357
358
    if device_compute_capability < (8, 0):
        if use_flash_attention:
            logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
            use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
            use_fused_attention = False
359
    if device_compute_capability < (9, 0):
360
        if use_flash_attention and _use_flash_attn_3:
361
362
            logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
            _use_flash_attn_3 = False
363
364

    # Filter: Data type
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
    if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
        torch.Tensor,
        Float8Tensor,
    ]:
        if use_flash_attention:
            logger.debug(
                "Disabling FlashAttention due to unsupported QKV data type. "
                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                "Found: qkv_dtype = %s.",
                qkv_dtype,
            )
            use_flash_attention = False
        if use_fused_attention:
            logger.debug(
                "Disabling FusedAttention due to unsupported QKV data type. "
                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
                "Found: qkv_dtype = %s.",
                qkv_dtype,
            )
            use_fused_attention = False
385
386
387

    # Filter: Execution type
    if fp8 and fp8_meta["recipe"].fp8_dpa:
388
389
390
391
392
393
394
        if use_flash_attention and not _use_flash_attn_3:
            logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
            use_flash_attention = False
        if use_flash_attention and _use_flash_attn_3 and is_training:
            logger.debug(
                "Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
            )
395
396
397
398
399
400
            use_flash_attention = False
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
            use_unfused_attention = False

    # Filter: Head dimension
401
402
403
    if use_flash_attention and head_dim_qk != head_dim_v:
        logger.debug("Disabling FlashAttention as it does not support MLA.")
        use_flash_attention = False
404
    if use_flash_attention and (
405
406
407
        head_dim_qk > 256
        or head_dim_qk % 8 != 0
        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
408
409
    ):
        logger.debug(
410
411
412
413
414
415
            "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
            "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
            "head_dim_qk <= 256 (>192 requires sm80/90). "
            "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
            head_dim_qk,
            head_dim_v,
416
417
418
            ".".join([str(i) for i in device_compute_capability]),
        )
        use_flash_attention = False
419
420
421
422
423
424
425
    qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
    if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
        logger.debug(
            "Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
            qkv_layout,
        )
        use_fused_attention = False
426
427
428
429
430
431
432
433
434
435
436
437
438
439

    # Filter: QKV layout
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
            use_unfused_attention = False
        if use_flash_attention and pad_between_seqs:
            logger.debug(
                "Disabling FlashAttention for qkv_format = thd when there is "
                "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
            )
            use_flash_attention = False

440
    # Filter: Dropout
441
442
443
    if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
        logger.debug("Disabling FlashAttention 3 for dropout")
        _use_flash_attn_3 = False
444

445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
    # Filter: Context parallelism
    # qkv_format | attn_mask_type              | attn_bias_type           | supported backends
    # ----------------------------------------------------------------------------------------------------
    # bshd, sbhd | self-attention:             | no_bias, post_scale_bias | FlashAttention, FusedAttention
    #            |     no_mask, causal         |                          |
    #            | cross-attention:            |                          |
    #            |     no_mask                 |                          |
    # thd        | self-attention:             | no_bias                  | FlashAttention, FusedAttention
    #            |     padding, padding_causal |                          | if no padding between sequences,
    #            | cross-attention:            |                          | FusedAttention
    #            |     padding                 |                          | if there is padding between sequences
    # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v.
    if context_parallel and use_unfused_attention:
        logger.debug(
            "Disabling UnfusedDotProductAttention as it does not support context parallelism"
        )
        use_unfused_attention = False
    if context_parallel and use_flash_attention:
463
        if _use_flash_attn_3:
464
465
466
467
468
469
470
            logger.debug("Disabling FlashAttention 3 for context parallelism")
            _use_flash_attn_3 = False
        if fp8 and fp8_meta["recipe"].fp8_dpa:
            logger.debug(
                "Disabling FlashAttention as it does not support context parallelism with FP8"
            )
            use_flash_attention = False
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
        if "bottom_right" in attn_mask_type:
            logger.debug(
                "Disabling FlashAttention as it does not support context parallelism with"
                " causal_bottom_right masking"
            )
            use_flash_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
            logger.debug(
                "Disabling FlashAttention as it does not support context parallelism with causal"
                " masking for cross-attention"
            )
            use_flash_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
            logger.debug(
                "Disabling FlashAttention as it does not support context parallelism with bias type"
                " of %s",
                core_attention_bias_type,
            )
            use_flash_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
            logger.debug(
                "Disabling FlashAttention as it does not support context parallelism with attention"
                " bias for THD format"
            )
            use_flash_attention = False
496

497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
    if context_parallel and use_fused_attention:
        if "bottom_right" in attn_mask_type:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with"
                " causal_bottom_right masking"
            )
            use_fused_attention = False
        elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with causal"
                " masking for cross-attention"
            )
            use_fused_attention = False
        elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with bias type"
                " of %s",
                core_attention_bias_type,
            )
            use_fused_attention = False
        elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with attention"
                " bias for THD format"
            )
            use_fused_attention = False
        elif head_dim_qk != head_dim_v:
            logger.debug(
                "Disabling FusedAttention as it does not support context parallelism with MLA"
            )
            use_fused_attention = False

529
    # Filter: Attention mask
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
    # attn_mask_type              | attention_mask                       | supported backends
    # ----------------------------------------------------------------------------------------
    # no_mask                     | None                                 | All
    # padding                     |                                      | All
    #     self-attention          | One tensor in shape [b, 1, 1, sq]    |
    #     cross-attention         | Tuple of two tensors in shapes       |
    #                             | [b, 1, 1, sq] and [b, 1, 1, skv]     |
    # causal                      | None                                 |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # padding_causal              | Same as "padding"                    |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FusedAttention, UnfusedDotProductAttention
    # causal_bottom_right         | None                                 | All
    # padding_causal_bottom_right | Same as "padding"                    |
    #     self-attention          |                                      | All
    #     cross-attention         |                                      | FlashAttention, UnfusedDotProductAttention
    # arbitrary                   | One tensor in shape broadcastable to | UnfusedDotProductAttention
    #                             | [b, h, sq, skv]                      |
549
550
551
552
553
554
555
    if attn_mask_type == "arbitrary":
        if use_flash_attention:
            logger.debug("Disabling FlashAttention for arbitrary mask")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention for arbitrary mask")
        use_fused_attention = False
556
557
    if (
        use_flash_attention
558
        and _use_flash_attn_3
559
560
561
562
563
564
565
566
567
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
            "causal mask since flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        _use_flash_attn_3 = False
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
    if (
        use_flash_attention
        and _flash_attn_2_1_plus
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention as it only supports bottom-right-diagonal "
            "causal mask since flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        use_flash_attention = False
    if (
        use_flash_attention
        and not _flash_attn_2_1_plus
        and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention as it only supports top-left-diagonal "
            "causal mask before flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        use_flash_attention = False
592
593
594
595
596
597
598
599
600
    if (
        use_flash_attention
        and _use_flash_attn_3
        and fp8
        and fp8_meta["recipe"].fp8_dpa
        and "padding" in attn_mask_type
    ):
        logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
        _use_flash_attn_3 = False
601
602

    # Filter: Sliding window attention
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
    #    backend                 |      window_size       | diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | (-1, -1) or (>=0, >=0) | bottom right
    # FusedAttention             | (-1,  0) or (>=0, 0)   | top left
    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
    #                            |                        | converts window_size to an 'arbitrary' mask
    if window_size is None:
        window_size = check_set_window_size(attn_mask_type, window_size)
    else:
        if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention"
                    " for FP8"
                )
                use_fused_attention = False
            elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
                logger.debug(
                    "Disabling FusedAttention as it only supports sliding window attention "
                    "with causal mask, no dropout, and qkv_format = bshd/sbhd"
                )
                use_fused_attention = False
            elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
                "no_mask",
                "padding",
                "causal_bottom_right",
                "padding_causal_bottom_right",
            ]:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with attn_mask_type = %s for cross-attention",
                    attn_mask_type,
                )
                use_fused_attention = False
            elif "padding" in attn_mask_type:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with attn_mask_type = %s",
                    attn_mask_type,
                )
                use_fused_attention = False
        if (
            use_flash_attention
            and (window_size[0] != -1 or window_size[1] not in [-1, 0])
647
            and not _flash_attn_2_3_plus
648
        ):
649
            logger.debug(
650
                "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
651
652
653
654
            )
            use_flash_attention = False

    # Filter: Attention bias
655
656
657
658
659
660
661
662
    #    backend                 |      bias types              | ALiBi diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | no_bias, alibi/alibi_slopes  | bottom right
    # FusedAttention             | no_bias, post_scale_bias     |
    #                            | alibi/alibi_slopes           | top left,
    #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
    # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
    #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
663
    if use_flash_attention and core_attention_bias_type == "alibi":
664
        if _use_flash_attn_3:
665
666
            logger.debug("Disabling FlashAttention 3 for ALiBi")
            _use_flash_attn_3 = False
667
668
669
        if not _use_flash_attn_3 and not _flash_attn_2_4_plus:
            logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
            use_flash_attention = False
670

671
672
673
674
675
676
677
678
679
680
681
682
683
    if use_flash_attention and (
        core_attention_bias_type not in ["no_bias", "alibi"]
        or core_attention_bias_shape is not None
    ):
        logger.debug("Disabling FlashAttention for pre/post_scale_bias")
        use_flash_attention = False

    fu_core_attention_bias_type = core_attention_bias_type
    fu_core_attention_bias_shape = core_attention_bias_shape
    fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
    if (
        use_fused_attention
        and core_attention_bias_type == "alibi"
684
        and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
685
686
687
    ):
        fu_core_attention_bias_type = "post_scale_bias"
        fu_core_attention_bias_requires_grad = False
688
689
690
691
692
        if alibi_slopes_shape is None:
            fu_core_attention_bias_shape = "1hss"
        elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
            fu_core_attention_bias_shape = "1hss"
        elif (
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
            len(alibi_slopes_shape) == 2
            and alibi_slopes_shape[0] == batch_size
            and alibi_slopes_shape[1] == num_heads
        ):
            fu_core_attention_bias_shape = "bhss"

    if (
        use_fused_attention
        and fu_core_attention_bias_type == "post_scale_bias"
        and fu_core_attention_bias_shape != "1hss"
    ):
        if fu_core_attention_bias_requires_grad:
            # remove this line when cuDNN adds bwd support for
            # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
            logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
            use_fused_attention = False
        else:
            # max512 backend will only support [1, h, s, s]
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

    # Filter: cuDNN support
    fused_attention_backend = None
    if use_fused_attention:
        q_type = TE_DType[qkv_dtype]
        kv_type = q_type
        if fp8 and fp8_meta["recipe"].fp8_dpa:
            q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            kv_type = q_type
        fused_attention_backend = tex.get_fused_attn_backend(
            q_type,
            kv_type,
            QKVLayout[qkv_layout],
            AttnBiasType[fu_core_attention_bias_type],
            AttnMaskType[attn_mask_type],
            attention_dropout,
            num_heads,
            num_gqa_groups,
            max_seqlen_q,
            max_seqlen_kv,
732
733
            head_dim_qk,
            head_dim_v,
734
735
            window_size[0],
            window_size[1],
736
        )
737
        if fused_attention_backend == FusedAttnBackend["No_Backend"]:
738
739
            logger.debug("Disabling FusedAttention as no backend supports the provided input")
            use_fused_attention = False
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
            fused_attention_backend = None
        if (
            use_fused_attention
            and window_size is not None
            and window_size[0] != -1
            and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
        ):
            logger.debug(
                "Disabling FusedAttention as only sub-backend %s does not support "
                "slidng window attention",
                int(fused_attention_backend),
            )
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
757
758
759
760
761
762
763
764
            and fu_core_attention_bias_type == "post_scale_bias"
            and fu_core_attention_bias_shape != "1hss"
        ):
            logger.debug(
                "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
                " [1, H, S, S] shape"
            )
            use_fused_attention = False
765
            fused_attention_backend = None
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785

    # Filter: Determinism
    # backend                      | deterministic
    # ---------------------------------------------
    # FlashAttention               |
    #     flash-attn >=2.0, <2.4.1 | no
    #     flash-attn >=2.4.1       | yes
    # FusedAttention               |
    #     sub-backend 0            | yes
    #     sub-backend 1            | workspace optimization path and sm90+: yes;
    #                              | otherwise: no
    #     sub-backend 2            | no
    # UnfusedDotProductAttention   | yes
    if use_flash_attention and deterministic and not _flash_attn_2_4_1_plus:
        logger.warning(
            "Disabling FlashAttention as version <2.4.1 does not support deterministic "
            "execution. To use FlashAttention with deterministic behavior, "
            "please install flash-attn >= 2.4.1."
        )
        use_flash_attention = False
786
787
788
789
790
791
792
793
794
795
796
    if use_fused_attention and deterministic:
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
            and (
                device_compute_capability < (9, 0)
                or core_attention_bias_requires_grad
                or cudnn_version < (8, 9, 5)
797
            )
798
799
800
        ):
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
801
802
803

    # All available backends
    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
804
805
806
807
808
809
810
811
812
813
814
815
    logger.debug(
        "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
        " UnfusedDotProductAttention=%s}",
        bool(available_backends[0]),
        bool(available_backends[1]),
        (
            f" (sub-backend {int(fused_attention_backend)})"
            if fused_attention_backend is not None
            else ""
        ),
        bool(available_backends[2]),
    )
816
817
818
819
820
821
822
823
824
825
826
827
828

    # Select FusedAttention for performance
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    ):
        if device_compute_capability == (9, 0):
            logger.debug(
                "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                "for performance reasons"
            )
            use_flash_attention = False
829
830
831
832
833
834
835
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["FP8"]
        and _use_flash_attn_3
    ):
        logger.debug(
836
837
            "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
            "in FP8 execution"
838
839
840
        )
        use_flash_attention = False

841
842
843
844
845
846
    # Selected backend
    if use_flash_attention:
        use_fused_attention = False
        use_unfused_attention = False
    elif use_fused_attention:
        use_unfused_attention = False
847
    selected_backend = "NoBackend"
848
849
850
851
852
853
    if use_flash_attention:
        selected_backend = "FlashAttention"
    elif use_fused_attention:
        selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
    elif use_unfused_attention:
        selected_backend = "UnfusedDotProductAttention"
854
    logger.debug("Selected backend = %s", selected_backend)
855

856
857
858
859
860
861
    global _attention_backends
    _attention_backends["use_flash_attention"] = use_flash_attention
    _attention_backends["use_fused_attention"] = use_fused_attention
    _attention_backends["fused_attention_backend"] = fused_attention_backend
    _attention_backends["use_unfused_attention"] = use_unfused_attention
    _attention_backends["backend_selection_requires_update"] = False
862
863
864
865

    return (
        use_flash_attention,
        use_fused_attention,
866
        fused_attention_backend,
867
868
869
870
871
        use_unfused_attention,
        available_backends,
    )


872
class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Offsets along the sequence and batch dimensions; both start at zero.
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # Keyed by layer number; each value is a (key_memory, value_memory) pair of tensors.
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty.
        AssertionError
            If `len(batch_indices)` differs from dim 1 (the batch dimension) of a
            cached key tensor.
        """
        if not self.key_value_memory_dict:
            raise ValueError("should not swap when dict is empty")

        for layer_number, (key_memory, value_memory) in self.key_value_memory_dict.items():
            # Dim 1 of the cached tensors is the one being reindexed as the batch dimension.
            assert len(batch_indices) == key_memory.shape[1], (
                f"batch_indices has length {len(batch_indices)}, but the KV cache of layer "
                f"{layer_number} has batch size {key_memory.shape[1]}"
            )
            # Replacing values of existing keys does not change the dict size, so this is
            # safe while iterating.
            self.key_value_memory_dict[layer_number] = (
                key_memory[:, batch_indices],
                value_memory[:, batch_indices],
            )
916

917

918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
@torch.no_grad()
def get_swa_mask(
    window_size: Tuple[int, int],
    max_seqlen_q: int,
    max_seqlen_kv: int,
    attn_mask_type: str = "no_mask",
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
    device: Union[str, torch.device] = "cuda",
) -> Tuple[str, torch.Tensor]:
    """
    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
    For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
    and for other mask types, the bottom right corner.

    Parameters
    ----------
    window_size: Tuple[int, int]
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
        `attn_mask_type`.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
        "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
        default = `None`
        Boolean tensor used to mask out attention softmax input (True = masked out).
        It is combined element-wise with the sliding window mask, so it must broadcast
        against [max_seqlen_q, max_seqlen_kv]. NOTE(review): a tuple of masks is not
        combined here and will raise; confirm callers only pass a single tensor.
    device: Union[str, torch.device], default = `"cuda"`
        Device on which the sliding window mask is built.

    Returns
    ----------
    attn_mask_type: str
        Always "`arbitrary`".
    attention_mask: torch.Tensor
        Boolean mask where True marks positions to mask out: the union of the input
        `attention_mask` and the positions outside the sliding window.
        The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None;
        else, the broadcast shape of the input `attention_mask` and that matrix.
    """
    mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device=device)
    if attn_mask_type in ["causal"]:
        # Window diagonal aligned to the top-left corner.
        left = window_size[0] if window_size[0] != -1 else max_seqlen_q
        right = window_size[1] if window_size[1] != -1 else max_seqlen_q
        mask_upper = torch.triu(mask, diagonal=-left)
        mask_lower = torch.tril(mask_upper, diagonal=right)
    else:
        # Window diagonal aligned to the bottom-right corner (shift by s_kv - s_q).
        left = window_size[0] if window_size[0] != -1 else max_seqlen_kv
        right = window_size[1] if window_size[1] != -1 else max_seqlen_kv
        mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left)
        mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right)
    attn_mask_type = "arbitrary"
    # mask_lower is True inside the window; invert it so True marks masked-out positions.
    mask = mask_lower.logical_not()
    if attention_mask is not None:
        # Both masks use True = "masked out": a position stays visible only if neither
        # mask hides it, so the masks are combined with OR (not AND).
        mask = torch.logical_or(attention_mask, mask)
    return attn_mask_type, mask


976
977
978
979
980
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    actual_seqlens_q: Optional[torch.Tensor] = None,
    actual_seqlens_kv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
    bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute ALiBi slopes and bias, memoizing both in the module-level `_alibi_cache`.

    The slopes are (re)computed only when `_alibi_cache["_alibi_slopes_require_update"]`
    is set, and the bias only when `_alibi_cache["_alibi_bias_require_update"]` is set;
    otherwise the cached tensors are returned as-is, regardless of the arguments.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    actual_seqlens_q: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for queries, in shape [batch_size].
    actual_seqlens_kv: Optional[torch.Tensor], default = `None`
        Actual sequence lengths for keys and values, in shape [batch_size].
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the ALiBi bias to the bottom right corner of
        the matrix (`True`) or top left (`False`).

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. Its shape is
        (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
        and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
        (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
        [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
        `actual_seqlens_q` and `actual_seqlens_kv` are not `None`.

    Raises
    ----------
    ValueError
        If the slopes to use are neither 1D nor 2D.
    AssertionError
        If exactly one of `actual_seqlens_q` / `actual_seqlens_kv` is given, or if no
        slopes are available.
    """
    global _alibi_cache
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            # Default slopes: powers 1..n of 2^(-8/n), with n the largest power of two
            # <= num_heads; any remaining heads get odd powers of 2^(-4/n).
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        slopes = _alibi_cache["_alibi_slopes"]
        assert slopes is not None, "ALiBi slopes can not be None!"
        if slopes.dim() == 1:
            # [num_heads] -> [1, num_heads, 1, 1]
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        elif slopes.dim() == 2:
            # [batch_size, num_heads] -> [batch_size, num_heads, 1, 1]
            slopes_shape = torch.Size([*slopes.shape[:], 1, 1])
        else:
            # Without this branch `slopes_shape` would be unbound below.
            raise ValueError(
                "ALiBi slopes must be in shape [num_heads] or [batch_size, num_heads], "
                f"got shape {list(slopes.shape)}."
            )
        # bias[..., i, j] = i - j (query position minus key position)
        bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
            1, 1, max_seqlen_q, 1
        ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
            1, 1, 1, max_seqlen_kv
        )
        if actual_seqlens_q is None and actual_seqlens_kv is None:
            if bottom_right_alignment:
                bias = bias + max_seqlen_kv - max_seqlen_q
        elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
            batch_size = actual_seqlens_q.shape[0]
            bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
            if bottom_right_alignment:
                # Per-sample shift so the diagonal follows each sample's real lengths.
                bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
        else:
            # Raised explicitly (not `assert False`) so the check survives `python -O`.
            raise AssertionError(
                "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
            )

        # -|distance| scaled by the per-head (or per-batch-and-head) slope.
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
1070
1071
1072
1073
1074
1075
1076
1077
1078


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    True entries of `mask` mark padded positions; each sample's length is its number
    of False entries. The result is placed on the same device as `mask`.
    """
    mask = mask.squeeze(1).squeeze(1)
    # Number of unmasked (real) tokens per sample.
    seqlens = mask.logical_not().sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # Leading zero is allocated on mask's device so torch.cat works for any device,
    # not only the current CUDA device.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    return torch.cat((zero, cu_seqlens))

1086

1087
1088
1089
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the flat indices of the valid tokens, padded at the end with the value
    `batch_size * max_seqlen`.

    True entries of `mask` mark padded positions. Both outputs are placed on the same
    device as `mask`.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape
    valid = mask.logical_not()

    cu_seqlens = valid.sum(dim=1).cumsum(dim=0).to(torch.int32)
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # nonzero() over the row-major flattened mask yields ascending flat indices,
    # shape [num_valid, 1]; unsqueeze to [num_valid, 1, 1].
    indices = valid.reshape(-1).nonzero().unsqueeze(-1)

    # Pad to a fixed length of bs * seqlen entries using the out-of-range index bs * seqlen.
    pad_amount = bs * seqlen - indices.shape[0]
    indices = F.pad(
        input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen)
    )

    return cu_seqlens, indices


1115
1116
1117
1118
1119
1120
1121
1122
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch.

    Token j of sample i (for j < its sequence length) maps to flat index
    `i * max_seqlen + j`; the tail is padded with `batch_size * max_seqlen`.
    The result is placed on the same device as `cu_seqlens`.
    """
    bs = len(cu_seqlens) - 1
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

    # valid[i, j] is True iff position j of sample i holds a real token.
    positions = torch.arange(max_seqlen, device=cu_seqlens.device)
    valid = positions.unsqueeze(0) < seqlens.unsqueeze(1)
    # nonzero() walks the row-major [bs, max_seqlen] grid, so it yields
    # i * max_seqlen + j in ascending order directly as int64. This avoids the
    # float32 round trip of torch.Tensor(list), which rounds indices above 2**24,
    # and the per-element Python loop over tensor values.
    indices = valid.reshape(-1).nonzero().unsqueeze(-1)

    pad_amount = bs * max_seqlen - indices.shape[0]
    indices = F.pad(
        input=indices,
        pad=(0, 0, 0, 0, 0, pad_amount),
        mode="constant",
        value=float(bs * max_seqlen),
    )

    return indices

1137

1138
# Module-level memo used by `_get_full_cu_seqlens`, so the cumulative-sequence-length
# tensors it builds are reused across calls instead of being re-allocated each time.
_cu_seqlens_cache = {}
1139
1140


1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length, so the result is the int32
    tensor ``[0, max_seqlen, 2 * max_seqlen, ..., batch_size * max_seqlen]`` on `device`.
    Results are memoized in the module-level `_cu_seqlens_cache`, keyed by
    (batch_size, max_seqlen, device).
    """
    global _cu_seqlens_cache
    # The device is part of the key so a tensor built for one device is never
    # handed back to a caller that asked for another.
    key = (batch_size, max_seqlen, device)
    if key not in _cu_seqlens_cache:
        _cu_seqlens_cache[key] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[key]
1161
1162


1163
@torch.compile
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Packs the given tensor using the `indices`.

    `tensor` is indexed as [dim0, dim1, dim2]. One all-zero row is appended along
    dim 0, so entries of `indices` equal to `tensor.shape[0]` gather zeros. `indices`
    is broadcast via `repeat` over dims 1 and 2. This presumably expects shape [N, 1, 1]
    as produced by `get_indices` (TODO confirm). The result then has shape
    [N, dim1, dim2]. For a Float8Tensor the raw `_data` is gathered and re-wrapped
    with `Float8Tensor.make_like`.
    """
    is_fp8 = isinstance(tensor, Float8Tensor)
    data = tensor._data if is_fp8 else tensor
    # The zero row takes the dtype of the tensor it is concatenated with (`_data` for
    # Float8Tensor), so torch.cat performs no dtype promotion on the raw FP8 bytes.
    padding_indice = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=data.dtype, device=tensor.device
    )
    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    gathered = torch.gather(torch.cat((data, padding_indice), dim=0), 0, indices)
    if is_fp8:
        return Float8Tensor.make_like(tensor, data=gathered)
    return gathered


1186
@torch.compile
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Packs the given 2 tensors using the `indices`.

    Applies `pack_tensor` with the shared `indices` to `t1` and then to `t2`, and
    returns the two packed tensors in that order.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


1200
@torch.compile
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Packs the given 3 tensors using the `indices`.

    Applies `pack_tensor` with the shared `indices` to `t1`, `t2` and `t3` (in that
    order) and returns the three packed tensors.
    """
    return (
        pack_tensor(indices, t1),
        pack_tensor(indices, t2),
        pack_tensor(indices, t3),
    )


1216
@torch.compile
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`.

    Scatters the rows of `tensor` (indexed as [N, dim1, dim2]) into a zero buffer of
    shape [dim0, dim1, dim2] at the positions given by `indices`. Entries of `indices`
    equal to `dim0` land in one extra scratch row, which is dropped. For a
    Float8Tensor the raw `_data` is scattered and re-wrapped with
    `Float8Tensor.make_like`.
    """
    is_fp8 = isinstance(tensor, Float8Tensor)
    data = tensor._data if is_fp8 else tensor
    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    # scatter_ requires the destination and source dtypes to match, so the buffer is
    # allocated with the dtype of the tensor actually scattered (`_data` for FP8).
    unpacked = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=data.dtype, device=tensor.device
    )
    unpacked.scatter_(0, indices, data)
    # Drop the scratch row that absorbed the padded indices.
    unpacked = unpacked[0:-1, :, :]
    if is_fp8:
        return Float8Tensor.make_like(tensor, data=unpacked)
    return unpacked


1238
@torch.compile
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`.

    Applies `unpack_tensor` with the shared `indices` and `dim0` to `t1` and then
    `t2`, returning both unpacked tensors in that order.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


1253
@torch.compile
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`.

    Applies `unpack_tensor` with the shared `indices` and `dim0` to `t1`, `t2` and
    `t3` (in that order) and returns the three unpacked tensors.
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    Forward packs 1 to 3 tensors with a shared `indices` tensor (through `pack_tensor`,
    `pack_2_tensors` or `pack_3_tensors`). Backward unpacks the incoming gradients back
    to the inputs' leading dimension. `indices` receives no gradient.
    """

    @staticmethod
    def forward(
        ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        num_inputs = len(tensors)
        assert 1 <= num_inputs <= 3, f"Packing {len(tensors)} tensors not supported."
        ctx.save_for_backward(indices)
        # Leading dimension of the inputs; backward needs it to size the unpacked grads.
        ctx.dim0 = tensors[0].shape[0]
        if num_inputs == 3:
            return pack_3_tensors(indices, *tensors)
        if num_inputs == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_tensor(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        (indices,) = ctx.saved_tensors
        leading_dim = ctx.dim0
        num_grads = len(grad_outputs)
        if num_grads == 1:
            return None, unpack_tensor(indices, leading_dim, *grad_outputs)
        if num_grads == 2:
            return None, *unpack_2_tensors(indices, leading_dim, *grad_outputs)
        return None, *unpack_3_tensors(indices, leading_dim, *grad_outputs)
1296
1297
1298
1299
1300
1301


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    Forward scatters `tensor` back into a buffer whose leading dimension is `dim0`
    (through `unpack_tensor`). Backward re-packs the gradient with the same `indices`
    (through `pack_tensor`). Neither `indices` nor `dim0` receives a gradient.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        ctx.save_for_backward(indices)
        return unpack_tensor(indices, dim0, tensor)

    @staticmethod
    def backward(ctx, grad_output):
        indices = ctx.saved_tensors[0]
        packed_grad = pack_tensor(indices, grad_output)
        return None, None, packed_grad
1317
1318


1319
1320
1321
def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Posts one non-blocking send of `send_tensor` to `send_dst` and one non-blocking
    receive into `recv_tensor` from `recv_src` within `cp_group`. Even ranks issue the
    send before the receive; odd ranks issue the receive first. With `batch_p2p_comm`
    both ops go through `torch.distributed.batch_isend_irecv`; otherwise individual
    `isend` / `irecv` calls are made. Returns the resulting request objects.
    """
    dist = torch.distributed
    send_first = rank % 2 == 0

    if batch_p2p_comm:
        # P2POp construction only records the op; ordering matters in the list passed
        # to batch_isend_irecv.
        send_op = dist.P2POp(dist.isend, send_tensor, send_dst, cp_group)
        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_src, cp_group)
        ordered_ops = [send_op, recv_op] if send_first else [recv_op, send_op]
        return dist.batch_isend_irecv(ordered_ops)

    # Unbatched path: the calls themselves start communication, so call order matters.
    if send_first:
        send_req = dist.isend(send_tensor, send_dst, cp_group)
        recv_req = dist.irecv(recv_tensor, recv_src, cp_group)
        return [send_req, recv_req]
    recv_req = dist.irecv(recv_tensor, recv_src, cp_group)
    send_req = dist.isend(send_tensor, send_dst, cp_group)
    return [recv_req, send_req]


1361
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Attention with context parallelism

    Adds `out_per_step * exp(softmax_lse_per_step - softmax_lse)` into `out` in place.
    Dim 2 of the LSE difference is first moved to `seq_dim` and a trailing singleton
    dim is appended so it broadcasts against `out_per_step`.
    """
    rescale = torch.exp(softmax_lse_per_step - softmax_lse)
    rescale = rescale.movedim(2, seq_dim).unsqueeze(-1)
    out.add_(out_per_step * rescale)


1370
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism

    Updates `softmax_lse` in place to log(exp(softmax_lse) + exp(softmax_lse_per_step)),
    evaluated stably as max + log1p(exp(min - max)).
    """
    max_scale = torch.max(softmax_lse, softmax_lse_per_step)
    min_scale = torch.min(softmax_lse, softmax_lse_per_step)
    # log1p stays accurate when exp(min - max) is tiny, where log(1 + x) rounds to 0.
    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
    softmax_lse.copy_(new_scale)
1377
1378


1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
@jit_fuser
def get_cu_seqlens_on_cp_rank(
    cu_seqlens, cu_seqlens_padded_on_cp_rank, cp_size, cp_rank, first_half, second_half
):
    """Compute cu_seqlens of a context parallelism rank

    For each sequence, the chunk length is half of its span in
    `cu_seqlens_padded_on_cp_rank`. When `first_half` is set, the number of real tokens
    falling in chunk `cp_rank` is counted. When `second_half` is set, the count for chunk
    `2 * cp_size - cp_rank - 1` is added. Each count is clamped to
    [0, chunk length]. Returns the cumulative sums of the per-sequence totals with a
    leading 0, in the dtype of `cu_seqlens`.
    """
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    chunk_len = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
    floor = torch.zeros_like(seqlens)
    result = torch.zeros_like(cu_seqlens)
    if first_half:
        in_first = (seqlens - cp_rank * chunk_len).clamp(floor, chunk_len)
        result[1:].add_(in_first)
    if second_half:
        in_second = (seqlens - (2 * cp_size - cp_rank - 1) * chunk_len).clamp(floor, chunk_len)
        result[1:].add_(in_second)
    # In-place cumsum keeps the original integer dtype.
    result.cumsum_(dim=0)
    return result


1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
@torch.compile
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
    """
    Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
    To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks
    before or after CP communications (e.g., all-gather, all-to-all). This function is to compute
    sequence chunk ids for reordering.

    Returns an int32 tensor of length ``2 * cp_size`` on ``device``:

    * ``to_contiguous=True``:  ``[0, 2, ..., 2*cp_size-2, 2*cp_size-1, ..., 3, 1]``
      (slot ``r`` <- ``2*r``; slot ``cp_size + r`` <- ``2*cp_size - 2*r - 1``)
    * ``to_contiguous=False``: ``[0, 2*cp_size-1, 1, 2*cp_size-2, ..., cp_size-1, cp_size]``
      (slot ``2*r`` <- ``r``; slot ``2*r + 1`` <- ``2*cp_size - r - 1``)
    """
    ranks = torch.arange(cp_size, dtype=torch.int32, device=device)
    if to_contiguous:
        return torch.cat((2 * ranks, 2 * cp_size - 2 * ranks - 1))
    # pair each rank's two chunk ids side by side, then interleave the pairs
    return torch.stack((ranks, 2 * cp_size - ranks - 1), dim=1).flatten()


@torch.compile
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
    """Reorder sequence chunk for A2A communication.

    ``before_attn=True``: ``x`` carries a leading ``cp`` dim received from the
    all-to-all. That dim is moved next to the sequence dim, the ``cp * s`` tokens are
    re-split into ``2 * cp`` chunks of ``s // 2``, and the chunks are permuted along
    ``seq_dim`` by ``chunk_ids_for_a2a``.

    ``before_attn=False``: the ``2 * cp`` chunk dim at ``seq_dim`` is moved to the
    front, permuted by ``chunk_ids_for_a2a``, and split into ``[cp, 2, ...]`` for the
    outgoing all-to-all.
    """
    if not before_attn:
        # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
        # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
        front = x.movedim(seq_dim, 0).contiguous()
        permuted = torch.index_select(front, dim=0, index=chunk_ids_for_a2a)
        # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn]
        # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
        return permuted.view(cp_size, 2, *permuted.shape[1:])

    # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
    # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
    merged = x.movedim(0, seq_dim).contiguous()
    leading = merged.shape[:seq_dim]
    trailing = merged.shape[(seq_dim + 2) :]
    # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
    # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
    chunked = merged.view(*leading, cp_size * 2, -1, *trailing)
    return torch.index_select(chunked, dim=seq_dim, index=chunk_ids_for_a2a)


def flash_attn_a2a_communicate(
    a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
    chunk_ids_for_a2a: torch.Tensor,
    seq_dim: int,
    cp_size: int,
    cp_group: dist_group_type,
    cp_stream: torch.cuda.Stream,
    before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """A2A communication for context parallelism.

    Runs one asynchronous ``all_to_all_single`` over ``cp_group`` per input tensor and
    pipelines the surrounding reshapes. While the exchange of one tensor is in flight,
    the previous tensor's output is post-processed on ``cp_stream`` and the next
    tensor's input is prepared on the current stream.

    ``before_attn=True``: the head dim of each input is split into ``cp_size`` groups
    and moved to the front before the exchange. After the exchange, the sequence chunks
    are reordered with ``reorder_seq_chunks_for_a2a`` and merged back into a single
    sequence dim.

    ``before_attn=False``: the inverse. The sequence dim is split into ``2 * cp_size``
    chunks and reordered before the exchange. After the exchange, the ``cp`` dim is
    folded back into the head dim and the result is flattened to 3-D.

    Args:
        a2a_inputs: a tensor, or a list of tensors, to exchange. A list passed by the
            caller is not modified.
        chunk_ids_for_a2a: chunk permutation forwarded to ``reorder_seq_chunks_for_a2a``.
        seq_dim: index of the sequence dim (the shape comments below cover
            ``seq_dim=1``, bshd, and ``seq_dim=0``, sbhd).
        cp_size: split factor for heads and sequence chunks; presumably equal to the
            size of ``cp_group`` — confirm at call sites.
        cp_group: process group the all-to-all runs on.
        cp_stream: side stream used for post-processing the received tensors. The
            current stream waits on it before this function returns.
        before_attn: direction of the exchange (see above).

    Returns:
        The single output tensor if exactly one input was given, otherwise a list of
        outputs in input order.
    """
    # Work on a copy. The loops below overwrite ``a2a_inputs[i]`` with reshaped
    # tensors, and that must not leak into a list owned by the caller.
    a2a_inputs = list(a2a_inputs) if isinstance(a2a_inputs, list) else [a2a_inputs]
    a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
    if before_attn:
        # Software pipeline. In iteration i:
        #   1) launch the async all-to-all of input i-1 (reshaped in iteration i-1),
        #   2) on cp_stream, wait for the all-to-all of input i-2 and post-process it,
        #   3) reshape input i on the current stream.
        for i in range(len(a2a_inputs) + 2):
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    # for CUDA collectives, wait() blocks the active stream (cp_stream)
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # reorder the sequence chunks
                    x = reorder_seq_chunks_for_a2a(
                        x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
                    )
                    # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                    # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
                    a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
            if i < len(a2a_inputs):
                x = a2a_inputs[i]
                # [b, s, np, hn] -> [b, s, cp, np//cp, hn]
                # or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
                x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
                # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn]
                # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
                a2a_inputs[i] = x.movedim(-3, 0).contiguous()
    else:
        # Software pipeline. In iteration i:
        #   1) launch the async all-to-all of input i-1 (reordered in iteration i-1),
        #   2) split and reorder the sequence chunks of input i on the current stream,
        #   3) on cp_stream, wait for the all-to-all of input i-2 and post-process it.
        for i in range(len(a2a_inputs) + 2):
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            if i < len(a2a_inputs):
                x = a2a_inputs[i]
                # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
                # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
                x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
                # reorder the sequence chunks
                a2a_inputs[i] = reorder_seq_chunks_for_a2a(
                    x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
                )
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    # for CUDA collectives, wait() blocks the active stream (cp_stream)
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn]
                    # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
                    x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
                    # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn]
                    # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
                    a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
    # Make the caller's stream observe the post-processing issued on cp_stream.
    torch.cuda.current_stream().wait_stream(cp_stream)
    return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs


1512
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
1513
    """
1514
1515
1516
    Attention implementation with context parallelism. Exchange KV between CP ranks
    with P2P in ring topology. Split attention compute into multiple steps, and overlap
    current-step compute with next-step communication.
1517
1518
1519
1520
1521

    This implementation also supports hierarchical CP, which parallelizes attention
    heads in low-level CP groups and parallelizes sequence dimension in high-level CP
    groups. For more details, please refer to `LongVILA <https://arxiv.org/abs/2408.10188>`_
    and `USP <https://arxiv.org/abs/2405.07719>`_.
1522
1523
1524
    """

    @staticmethod
1525
1526
1527
1528
1529
1530
1531
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
1532
        cu_seqlens_kv,
1533
        max_seqlen_q,
1534
        max_seqlen_kv,
1535
1536
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
1537
1538
1539
1540
1541
1542
1543
1544
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
1545
1546
        fp8,
        fp8_meta,
1547
1548
1549
        cp_group,
        cp_global_ranks,
        cp_stream,
1550
    ):
1551
1552
1553
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
        if isinstance(cp_group, list):
            assert (
                qkv_format != "thd"
            ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!"
            assert attn_bias_type == "no_bias", (
                f"{attn_bias_type} bias type is not supported with hierarchical CP implementation"
                " yet!"
            )
            cp_group_a2a = cp_group[0]
            cp_size_a2a = get_distributed_world_size(cp_group_a2a)
            rank_a2a = get_distributed_rank(cp_group_a2a)
            cp_group = cp_group[1]
        else:
            cp_group_a2a = None
            cp_size_a2a = 1
            rank_a2a = 0

1571
1572
        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
1573
1574
        send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
1575
1576
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1577
1578
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
1579

1580
        if qkv_format in ["bshd", "sbhd"]:
1581
            seq_dim = qkv_format.index("s")
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        else:
            qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
        pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
        max_seqlen_q = max_seqlen_q // cp_size
        max_seqlen_kv = max_seqlen_kv // cp_size
        cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
        cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
        cu_seqlens_q_per_step = [None for _ in range(cp_size)]
        cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
1594

1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
        if fp8:
            if use_fused_attention:
                fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                fused_attn_qkv_dtype = fp8_dtype_forward
                fused_attn_backend = FusedAttnBackend["FP8"]
                if fp8_meta["recipe"].fp8_mha:
                    assert (
                        isinstance(q, Float8Tensor)
                        and isinstance(k, Float8Tensor)
                        and isinstance(v, Float8Tensor)
                    ), "q/k/v must be Float8Tensors for FP8 MHA!"
                    fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                else:
                    q_f16, k_f16, v_f16 = q, k, v
                    if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                        q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                    if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                        k, v = [
                            cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                            for x in [k_f16, v_f16]
                        ]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
                fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            q_f16 = q
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True)
            q, k, v = flash_attn_a2a_communicate(
                [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
            )
            if not fp8:
                q_f16 = q
            elif not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                q_f16 = q
                q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)

1648
1649
1650
        assert qkv_format == "thd" or (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"
1651
        if causal:
1652
1653
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
1654
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
1655
1656
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
1657
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
1658
1659
1660
        total_tokens_kv = None if qkv_format != "thd" else k.shape[0]
        # remove padded tokens at the end
        k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]]
1661
        if attn_bias is not None:
1662
            assert len(attn_bias.shape) == 4, (
1663
1664
1665
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
1666
1667
1668
            assert (
                attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0
            ), "Sequence length does not meet divisible requirements!"
1669
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
1670
1671
1672
1673
1674
1675
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
1676
1677
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
1678
1679
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
1680
            )
1681
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
1682
1683
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_3_plus:
1684
            fa_optional_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
1685
1686
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None
1687
1688
        if _flash_attn_2_5_7_plus:
            fa_optional_forward_kwargs["block_table"] = None
1689

1690
1691
1692
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
1693
        attn_bias_inputs = [None, None]
1694
1695
1696
1697
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
1698
        attn_biases = [None for _ in range(cp_size)]
1699
1700
1701
1702
1703
1704
1705

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
1706
1707
1708
1709
        if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
        else:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
1710
1711
        send_recv_reqs = [[], []]

1712
        for i in range(cp_size + 1):
1713
            if i < cp_size:
1714
                with torch.cuda.stream(flash_attn_streams[i % 2]):
1715
                    # wait until KV is received
1716
                    for req in send_recv_reqs[(i + 1) % 2]:
1717
1718
                        req.wait()

1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
                    if (
                        not fp8
                        or fp8_meta["recipe"].fp8_mha
                        or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
                    ):
                        kv_inputs[i % 2] = p2p_comm_buffers[i]
                    else:
                        # KV exchange is in BF16/FP16, cast received KV in each step
                        kv_inputs[i % 2] = cast_to_fp8(
                            p2p_comm_buffers[i],
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                        )
                    if fp8 and use_fused_attention:
1746
1747
1748
1749
                        fp8_meta_kwargs["amax_s"] = amax_per_step
                        fp8_meta_kwargs["amax_s_offset"] = i
                        fp8_meta_kwargs["amax_o"] = amax_per_step
                        fp8_meta_kwargs["amax_o_offset"] = cp_size + i
1750
1751
                    if causal:
                        if i == 0:
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1764
                            if use_fused_attention:
1765
1766
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
1767
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
1768
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
1769
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
1770
                                        k.shape[0], -1, 2, *k.shape[-2:]
1771
                                    )
1772
1773
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
1774
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
1775
1776
1777
1778
                                    # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        -1, k.shape[2], 2, *k.shape[-2:]
                                    )
1779
                                elif qkv_format == "thd":
1780
                                    q_inputs[i % 2] = q
1781
1782
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1783
1784
1785
1786
1787
1788
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
1789
                                    ).contiguous()
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
1818
                                )
1819
1820
1821
1822
1823
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
1824
1825
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
1826
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1827
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
1842
1843
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1844
                                    max_seqlen_q,
1845
                                    max_seqlen_kv,
1846
1847
1848
1849
1850
                                    dropout_p,
                                    softmax_scale,
                                    causal=True,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1851
                                )
1852
                        elif i <= rank:
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    False,
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
1870
                            if use_fused_attention:
1871
1872
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
1873
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
1874
1875
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous()
1876
1877
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
1878
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
1879
1880
                                    # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous()
1881
                                elif qkv_format == "thd":
1882
                                    q_inputs[i % 2] = q
1883
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
1884
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
1885
                                        kv_inputs[i % 2], cu_seqlens_kv_padded, 0
1886
                                    )
1887
1888
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1889
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv // 2,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=(
                                        None
                                        if cu_seqlens_kv_padded is None
                                        else cu_seqlens_kv_padded // 2
                                    ),
                                    **fp8_meta_kwargs,
1922
                                )
1923
1924
1925
1926
1927
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
1928
1929
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
1930
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1931
1932
                                if qkv_format == "thd":
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
1933
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
1934
                                        kv_inputs[i % 2], cu_seqlens_kv_padded, 0
1935
                                    )
1936
1937
                                else:
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
1938
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
1939
                                # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
1940
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
1941
                                if _flash_attn_2_3_plus:
1942
                                    fa_optional_forward_kwargs["window_size"] = (-1, -1)
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
1956
1957
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1958
                                    max_seqlen_q,
1959
                                    max_seqlen_kv // 2,
1960
1961
1962
1963
1964
                                    dropout_p,
                                    softmax_scale,
                                    causal=False,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1965
1966
                                )
                        else:
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    True,
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1984
                            if use_fused_attention:
1985
1986
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
1987
                                    q_inputs[i % 2] = q[:, 1, ...].contiguous()
1988
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
1989
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
1990
                                        k.shape[0], -1, 2, *k.shape[-2:]
1991
                                    )
1992
1993
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
1994
                                    q_inputs[i % 2] = q[1].contiguous()
1995
1996
1997
1998
                                    # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        -1, k.shape[2], 2, *k.shape[-2:]
                                    )
1999
2000
                                elif qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
2001
2002
2003
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(
                                        q, cu_seqlens_q_padded, 1
                                    )
2004
2005
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
2006
2007
2008
2009
2010
2011
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
2012
                                    ).contiguous()
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q // 2,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    q_inputs[i % 2],
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    fused_attn_qkv_dtype,
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=(
                                        None
                                        if cu_seqlens_q_padded is None
                                        else cu_seqlens_q_padded // 2
                                    ),
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
2045
                                )
2046
2047
2048
2049
2050
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
2051
                            else:
2052
2053
                                if qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
2054
2055
2056
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(
                                        q, cu_seqlens_q_padded, 1
                                    )
2057
2058
                                else:
                                    # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
2059
                                    q_inputs[i % 2] = (
2060
                                        q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
2061
                                    )
2062
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
2063
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
2064
                                if _flash_attn_2_3_plus:
2065
                                    fa_optional_forward_kwargs["window_size"] = (-1, -1)
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
2079
2080
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
2081
                                    max_seqlen_q // 2,
2082
                                    max_seqlen_kv,
2083
2084
2085
2086
2087
                                    dropout_p,
                                    softmax_scale,
                                    causal=False,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
2088
2089
                                )
                    else:
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
                        if pad_between_seqs_q:
                            cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                            )
                        else:
                            cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                        if pad_between_seqs_kv:
                            cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_kv,
                                cu_seqlens_kv_padded,
                                cp_size,
                                (rank - i) % cp_size,
                                True,
                                True,
                            )
                        else:
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
2107
                        if use_fused_attention:
2108
2109
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
2110
2111
2112
2113
2114
2115
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
2116
                                ).contiguous()
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
                            out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                is_training,
                                max_seqlen_q,
                                max_seqlen_kv,
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
                                q,
                                (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                ),
                                (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                ),
                                fused_attn_qkv_dtype,
                                fused_attn_backend,
                                attn_scale=softmax_scale,
                                dropout=dropout_p,
                                qkv_layout=qkv_layout,
                                attn_mask_type=attn_mask_type,
                                attn_bias_type=attn_bias_type,
                                attn_bias=attn_bias_inputs[i % 2],
                                cu_seqlens_q_padded=cu_seqlens_q_padded,
                                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                **fp8_meta_kwargs,
2145
                            )
2146
2147
2148
2149
2150
                            if fp8:
                                softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                            else:
                                softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                attn_biases[i] = rest[0] if len(rest) > 0 else None
2151
                        else:
2152
                            # [b, sq, np, hn] -> [b*sq, np, hn]
2153
                            q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
2154
                            # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
                            kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                            (
                                _,
                                _,
                                _,
                                _,
                                out_per_step[i],
                                softmax_lse_per_step[i],
                                _,
                                rng_states[i],
                            ) = _flash_attn_forward(
                                q_inputs[i % 2],
                                kv_inputs[i % 2][0],
                                kv_inputs[i % 2][1],
2169
2170
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
2171
                                max_seqlen_q,
2172
                                max_seqlen_kv,
2173
2174
2175
2176
2177
                                dropout_p,
                                softmax_scale,
                                causal=False,
                                return_softmax=False,
                                **fa_optional_forward_kwargs,
2178
                            )
2179
2180
2181
2182

            if i > 0:
                # wait until fwd restuls correction of last step is done
                if i > 1:
2183
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
2184

2185
2186
                if use_fused_attention:
                    # [b, np, sq, 1] -> [b, np, sq]
2187
                    softmax_lse_per_step[i - 1].squeeze_(-1)
2188

2189
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
2190
2191
2192
2193
2194
2195
2196
2197
                    if fp8:
                        out_per_step[i - 1] = cast_from_fp8(
                            out_per_step[i - 1],
                            fp8_meta["scaling_fwd"],
                            META_O_CP,
                            fp8_dtype_forward,
                            TE_DType[torch.float32],
                        )
2198
                    if i == 1:
2199
                        out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
2200
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
2201
                        if causal and qkv_format != "thd":
2202
2203
                            # [b, np, sq] -> [b, np, 2, sq//2]
                            softmax_lse_ = softmax_lse.view(
2204
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
2205
                            )
2206
2207
2208
2209
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
2210
                    else:
2211
                        if qkv_format == "thd":
2212
                            tex.thd_second_half_lse_correction(
2213
2214
2215
2216
                                softmax_lse,
                                softmax_lse_per_step[i - 1],
                                cu_seqlens_q_padded,
                                max_seqlen_q,
2217
                            )
2218
                        else:
2219
2220
2221
                            flash_attn_fwd_softmax_lse_correction(
                                softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1]
                            )
2222
2223

                if i < cp_size:
2224
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
2225
2226
2227
2228
2229

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        softmax_lse = softmax_lse.to(torch.float)
        for i in range(cp_size):
2230
2231
2232
2233
2234
2235
            if qkv_format == "bshd":
                out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
                out_ = out[:, 1, ...]
            elif qkv_format == "sbhd":
                out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
                out_ = out[1]
2236

2237
            if i <= rank or not causal:
2238
                if qkv_format in ["bshd", "sbhd"]:
2239
2240
2241
2242
2243
2244
2245
                    flash_attn_fwd_out_correction(
                        out.view(*out_per_step[i].shape),
                        out_per_step[i],
                        seq_dim,
                        softmax_lse,
                        softmax_lse_per_step[i],
                    )
2246
                elif qkv_format == "thd":
2247
2248
2249
2250
2251
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2252
                        cu_seqlens_q_padded,
2253
2254
                        False,
                    )
2255
            else:
2256
                if qkv_format in ["bshd", "sbhd"]:
2257
2258
2259
2260
2261
2262
2263
                    flash_attn_fwd_out_correction(
                        out_,
                        out_per_step[i],
                        seq_dim,
                        softmax_lse_[..., 1, :],
                        softmax_lse_per_step[i],
                    )
2264
                elif qkv_format == "thd":
2265
2266
2267
2268
2269
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
2270
                        cu_seqlens_q_padded,
2271
2272
                        True,
                    )
2273
2274

        kv = p2p_comm_buffers[-1]
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
        if qkv_format == "bshd":
            out = out.view(out.shape[0], -1, *out.shape[-2:])
            ctx.batch_size = out.shape[0]
        elif qkv_format == "sbhd":
            out = out.view(-1, *out.shape[-3:])
            ctx.batch_size = out.shape[1]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
            out = flash_attn_a2a_communicate(
                out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
            )
            if use_fused_attention:
                if qkv_format == "bshd":
                    # [b*s, np, hn] -> [b, s, np, hn]
                    out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                elif qkv_format == "sbhd":
                    # [s*b, np, hn] -> [s, b, np, hn]
                    out = out.view(-1, ctx.batch_size, *out.shape[-2:])
        elif not use_fused_attention:
2295
            out = out.view(-1, *out.shape[-2:])
2296

2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
        if fp8 and use_fused_attention:
            amax_cp_fwd = amax_per_step.amax(dim=1)
            fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0]
            fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]

        out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype)
        if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
            out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)

        if fp8 and fp8_meta["recipe"].fp8_mha:
            out_ret = Float8Tensor(
                data=out_fp8,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_O,
                fp8_dtype=fp8_dtype_forward,
                dtype=q_fp8.dtype,
            )
        else:
            out_ret = out_f16

        if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
            q_save, kv_save, out_save = q, kv, out_fp8
            fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
            fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
        elif fp8 and fp8_meta["recipe"].fp8_mha:
2323
2324
2325
2326
2327
2328
2329
2330
            q_fp8 = Float8Tensor(
                data=q,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_QKV,
                fp8_dtype=fp8_dtype_forward,
                dtype=q_fp8.dtype,
            )
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
            kv_fp8 = Float8Tensor(
                data=kv,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=META_QKV,
                fp8_dtype=fp8_dtype_forward,
                dtype=k_fp8.dtype,
            )
            q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None
        else:
2342
            q_f16 = q_f16.view(q.shape)
2343
2344
2345
            q_save, kv_save, out_save = q_f16, kv, out_f16
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None

2346
        ctx.save_for_backward(
2347
2348
2349
            q_save,
            kv_save,
            out_save,
2350
            softmax_lse,
2351
2352
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
2353
2354
            fp8_fwd_scales,
            fp8_fwd_scale_invs,
2355
2356
            *cu_seqlens_q_per_step,
            *cu_seqlens_kv_per_step,
2357
2358
            *rng_states,
            *attn_biases,
2359
        )
2360
2361
2362
        ctx.cp_group_a2a = cp_group_a2a
        ctx.cp_size_a2a = cp_size_a2a
        ctx.rank_a2a = rank_a2a
2363
2364
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
2365
        ctx.cp_stream = cp_stream
2366
        ctx.dropout_p = dropout_p
2367
        ctx.total_tokens_kv = total_tokens_kv
2368
        ctx.max_seqlen_q = max_seqlen_q
2369
        ctx.max_seqlen_kv = max_seqlen_kv
2370
        ctx.softmax_scale = softmax_scale
2371
        ctx.qkv_format = qkv_format
2372
        ctx.attn_mask_type = attn_mask_type
2373
2374
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
2375
        ctx.deterministic = deterministic
2376
        ctx.use_fused_attention = use_fused_attention
2377
2378
2379
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
        return out_ret
2380
2381
2382

    @staticmethod
    def backward(ctx, dout):
2383
2384
2385
        cp_size_a2a = ctx.cp_size_a2a
        rank_a2a = ctx.rank_a2a

2386
2387
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
2388
2389
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
2390
2391
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

2392
        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
2393
2394
2395
2396
2397
        (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
        cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
        cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
        rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
        attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
2398

2399
2400
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
2401
        if ctx.qkv_format in ["bshd", "sbhd"]:
2402
            seq_dim = ctx.qkv_format.index("s")
2403
2404
2405
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
        else:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
2406

2407
        if attn_biases[0] is not None:
2408
2409
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
2410
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
2411
2412
2413
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
2414
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
2415
2416
2417
2418
            )
        else:
            attn_dbias = None

2419
        if causal:
2420
            if ctx.qkv_format == "thd":
2421
2422
2423
                softmax_lse_ = tex.thd_read_second_half_lse(
                    softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q
                )
2424
2425
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
2426
2427
2428
                softmax_lse_ = softmax_lse.view(
                    *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
                )
2429
2430
2431
2432
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
                if ctx.use_fused_attention:
                    # [b, np, sq//2] -> [b, np, sq//2, 1]
                    softmax_lse_.unsqueeze_(-1)
2433
2434
2435
        if ctx.use_fused_attention:
            # [b, np, sq] -> [b, np, sq, 1]
            softmax_lse.unsqueeze_(-1)
2436

2437
        dout_dtype = dout.dtype
2438
2439
        if ctx.fp8:
            if ctx.use_fused_attention:
2440
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
2441
                fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
2442
                fused_attn_qkv_dtype = fp8_dtype_forward
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
                fused_attn_dqkv_dtype = fp8_dtype_backward
                fused_attn_backend = FusedAttnBackend["FP8"]
                dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
                dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
                dkv_fp8_ = torch.empty_like(dkv_fp8)
                if ctx.fp8_meta["recipe"].fp8_mha:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
                    dout = dout._data
                else:
                    dout = cast_to_fp8(
                        dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                    )
                p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
                fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
                fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
                fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
                fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
                fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
                fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
                fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP]
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
2471
2472
2473
2474
2475
2476
2477
                q, kv = [x.from_float8(x.dtype) for x in [q, kv]]
                if cp_size_a2a == 1:
                    dout = dout.from_float8(dout_dtype)
                else:
                    dout_fp8_dtype = dout._fp8_dtype
                    dout_scale_inv = dout._scale_inv
                    dout = dout._data
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
            dq = torch.empty_like(q)
            if ctx.qkv_format == "thd" and causal:
                dq[cu_seqlens_q_padded[-1] :].fill_(0)
            p2p_comm_buffers = [
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            ]
            p2p_comm_buffers[0][0].copy_(kv)
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
2489
                fused_attn_dqkv_dtype = TE_DType[dout_dtype]
2490
2491
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

2492
2493
2494
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
        if cp_size_a2a > 1:
            if not ctx.use_fused_attention:
                out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                dout = dout.view(*out.shape)
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True)
            out, dout = flash_attn_a2a_communicate(
                [out, dout],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                True,
            )
            if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
                dout = cast_from_fp8(
                    dout, None, None, dout_fp8_dtype, TE_DType[dout_dtype], scale_inv=dout_scale_inv
                )

2511
2512
2513
2514
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        send_recv_reqs = []

2515
2516
2517
2518
2519
2520
        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

2521
2522
2523
2524
2525
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

2526
2527
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
            if ctx.fp8:
                if i < cp_size - 1:
                    send_recv_reqs = flash_attn_p2p_communicate(
                        rank,
                        send_tensor[0],
                        send_dst,
                        recv_tensor[0],
                        recv_src,
                        ctx.cp_group,
                        batch_p2p_comm,
                    )
                else:
                    dkv_a2a_req = torch.distributed.all_to_all_single(
                        dkv_fp8,
                        dkv_fp8_,
                        group=ctx.cp_group,
                        async_op=True,
                    )
                    send_recv_reqs = [dkv_a2a_req]
            else:
                if i == 0:
                    send_tensor = send_tensor[0]
                    recv_tensor = recv_tensor[0]
                if i == (cp_size - 1):
                    send_tensor = send_tensor[1]
                    recv_tensor = recv_tensor[1]
                send_recv_reqs = flash_attn_p2p_communicate(
                    rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
                )
2557

2558
            kv = p2p_comm_buffers[i % 2][0]
2559
2560
2561
            if ctx.fp8 and ctx.use_fused_attention:
                fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i]
                fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i]
2562
            # In reversed order of fwd
2563
            if causal:
2564
                if i == (cp_size - 1):
2565
                    if ctx.use_fused_attention:
2566
2567
2568
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
2569
2570
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
2571
2572
2573
2574
2575
2576
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
2577
2578
                            # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                            kv_ = kv.view(-1, *kv.shape[-4:])
2579
2580
2581
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
2582
2583
                        elif ctx.qkv_format == "thd":
                            q_, kv_, out_, dout_ = q, kv, out, dout
2584
2585
2586
2587
2588
2589
2590
2591
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2592
                        if attn_dbias is not None:
2593
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2594
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2595
                            ctx.max_seqlen_q,
2596
2597
2598
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2599
                            q_,
2600
2601
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2602
2603
                            out_,
                            dout_,
2604
2605
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2606
                            aux_ctx_tensors,
2607
                            fused_attn_backend,
2608
2609
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2610
2611
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2612
                            qkv_layout=qkv_layout,
2613
                            attn_mask_type=ctx.attn_mask_type,
2614
                            attn_bias_type=ctx.attn_bias_type,
2615
2616
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2617
2618
2619
2620
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
2621
                        dq_ = torch.zeros_like(q_)
2622
2623
2624
2625
2626
2627
2628
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
2629
                            fa_optional_backward_kwargs["window_size"] = (-1, 0)
2630
                        _flash_attn_backward(
2631
2632
2633
2634
2635
2636
2637
2638
2639
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2640
2641
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2642
                            ctx.max_seqlen_q,
2643
                            ctx.max_seqlen_kv,
2644
2645
2646
2647
2648
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            True,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
2649
                        )
2650
                elif i >= (cp_size - rank - 1):
2651
                    if ctx.use_fused_attention:
2652
2653
2654
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
2655
2656
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                            kv_ = kv[:, 0, ...].contiguous()
2657
2658
2659
2660
2661
2662
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
2663
2664
                            # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                            kv_ = kv[0].contiguous()
2665
2666
2667
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
2668
2669
2670
                        elif ctx.qkv_format == "thd":
                            q_, out_, dout_ = q, out, dout
                            # [2, t, np, hn] -> [2, t/2, np, hn]
2671
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
2672
2673
2674
2675
2676
2677
2678
2679
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2680
                        if attn_dbias is not None:
2681
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2682
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2683
                            ctx.max_seqlen_q,
2684
2685
2686
                            ctx.max_seqlen_kv // 2,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2687
                            q_,
2688
2689
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2690
2691
                            out_,
                            dout_,
2692
2693
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2694
                            aux_ctx_tensors,
2695
                            fused_attn_backend,
2696
2697
2698
2699
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=(
                                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
                            ),
2700
2701
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2702
                            qkv_layout=qkv_layout,
2703
                            attn_mask_type="padding" if padding else "no_mask",
2704
                            attn_bias_type=ctx.attn_bias_type,
2705
2706
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2707
2708
2709
2710
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
2711
                        dq_ = torch.zeros_like(q_)
2712
2713
                        if ctx.qkv_format == "thd":
                            # [2, t, np, hn] -> [2, t/2, np, hn]
2714
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
2715
2716
2717
                        else:
                            # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
2718
2719
2720
2721
2722
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
2723
                            fa_optional_backward_kwargs["window_size"] = (-1, -1)
2724
                        _flash_attn_backward(
2725
2726
2727
2728
2729
2730
2731
2732
2733
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2734
2735
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2736
                            ctx.max_seqlen_q,
2737
                            ctx.max_seqlen_kv // 2,
2738
2739
2740
2741
2742
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            False,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
2743
2744
2745
                        )
                else:
                    if ctx.use_fused_attention:
2746
2747
2748
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous()
2749
2750
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
2751
2752
2753
2754
2755
2756
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous()
                            dout_ = dout[:, 1, ...].contiguous()
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            q_ = q[1].contiguous()
2757
2758
                            # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                            kv_ = kv.view(-1, *kv.shape[-4:])
2759
2760
2761
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            out_ = out[1].contiguous()
                            dout_ = dout[1].contiguous()
2762
2763
                        elif ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
2764
2765
2766
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
2767
                            kv_ = kv
2768
2769
2770
2771
2772
2773
2774
2775
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse_,
                                softmax_lse_,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
2776
                        if attn_dbias is not None:
2777
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2778
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2779
                            ctx.max_seqlen_q // 2,
2780
2781
2782
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2783
                            q_,
2784
2785
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2786
2787
                            out_,
                            dout_,
2788
2789
                            fused_attn_qkv_dtype,
                            fused_attn_dqkv_dtype,
2790
                            aux_ctx_tensors,
2791
                            fused_attn_backend,
2792
2793
2794
2795
                            cu_seqlens_q_padded=(
                                None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
                            ),
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2796
2797
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2798
                            qkv_layout=qkv_layout,
2799
                            attn_mask_type="padding" if padding else "no_mask",
2800
                            attn_bias_type=ctx.attn_bias_type,
2801
2802
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2803
2804
                        )
                    else:
2805
2806
                        if ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
2807
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
2808
2809
2810
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
2811
                        dq_ = torch.zeros_like(q_)
2812
2813
2814
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
2815
                        if ctx.qkv_format == "thd":
2816
2817
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
2818
2819
2820
2821
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                            dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
2822
                        if _flash_attn_2_3_plus:
2823
                            fa_optional_backward_kwargs["window_size"] = (-1, -1)
2824
                        _flash_attn_backward(
2825
2826
2827
2828
2829
2830
2831
2832
2833
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse_,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2834
2835
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2836
                            ctx.max_seqlen_q // 2,
2837
                            ctx.max_seqlen_kv,
2838
2839
2840
2841
2842
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            False,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
2843
2844
2845
                        )
            else:
                if ctx.use_fused_attention:
2846
2847
2848
2849
                    if ctx.fp8:
                        aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]]
                    else:
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2850
                    if attn_dbias is not None:
2851
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2852
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2853
                        ctx.max_seqlen_q,
2854
2855
2856
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
2857
                        q,
2858
2859
                        kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
2860
2861
                        out,
                        dout,
2862
2863
                        fused_attn_qkv_dtype,
                        fused_attn_dqkv_dtype,
2864
                        aux_ctx_tensors,
2865
                        fused_attn_backend,
2866
2867
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2868
2869
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
2870
                        qkv_layout=qkv_layout,
2871
                        attn_mask_type=ctx.attn_mask_type,
2872
                        attn_bias_type=ctx.attn_bias_type,
2873
2874
                        deterministic=ctx.deterministic,
                        **fp8_meta_kwargs,
2875
2876
2877
                    )
                else:
                    # [b, sq, np, hn] -> [b*sq, np, hn]
2878
                    q_ = q.view(-1, *q.shape[-2:])
2879
                    dq_ = torch.zeros_like(q_)
2880
                    # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
2881
2882
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
2883
                    # [b, sq, np, hn] -> [b*sq, np, hn]
2884
2885
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
2886
                    if _flash_attn_2_3_plus:
2887
                        fa_optional_backward_kwargs["window_size"] = (-1, -1)
2888
                    _flash_attn_backward(
2889
2890
2891
2892
2893
2894
2895
2896
2897
                        dout_,
                        q_,
                        kv_[0],
                        kv_[1],
                        out_,
                        softmax_lse,
                        dq_,
                        dkv_[0],
                        dkv_[1],
2898
2899
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
2900
                        ctx.max_seqlen_q,
2901
                        ctx.max_seqlen_kv,
2902
2903
2904
                        ctx.dropout_p,
                        ctx.softmax_scale,
                        False,
2905
                        rng_state=rng_states[cp_size - i - 1],
2906
                        **fa_optional_backward_kwargs,
2907
2908
                    )

2909
2910
            if ctx.fp8:
                dq = dq_fp8[(rank + i + 1) % cp_size]
2911
            if i >= (cp_size - rank - 1) or not causal:
2912
2913
2914
2915
                # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                dq_ = dq_.view(*dq.shape)
            else:
2916
2917
2918
2919
2920
2921
                if ctx.qkv_format == "bshd":
                    # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                    dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
                elif ctx.qkv_format == "sbhd":
                    # [b*sq//2, np, hn] -> [sq//2, b, np, hn]
                    dq_ = dq_.view(-1, *dq.shape[-3:])
2922

2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
            if ctx.fp8:
                if i >= (cp_size - rank - 1) or not causal:
                    dq.copy_(dq_)
                else:
                    if ctx.qkv_format == "bshd":
                        dq[:, 0, ...].fill_(0)
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[0].fill_(0)
                        dq[1].copy_(dq_)
            elif causal:
2934
                if i > (cp_size - rank - 1):
2935
                    dq.add_(dq_)
2936
2937
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
2938
2939
                        dq.copy_(dq_)
                    else:
2940
2941
2942
2943
2944
2945
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
2946
                        elif ctx.qkv_format == "thd":
2947
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
2948
                elif i > 0:
2949
2950
2951
2952
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
2953
                    elif ctx.qkv_format == "thd":
2954
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
2955
                else:
2956
2957
2958
2959
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
2960
                    elif ctx.qkv_format == "thd":
2961
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
2962
2963
2964
2965
2966
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
2967

2968
            if attn_dbias is not None:
2969
                idx = (rank + i + 1) % cp_size
2970
                if i == (cp_size - 1) or not causal:
2971
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
2972
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2973
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
2974
2975
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
2976
2977
2978
2979
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
2980
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2981
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
2982
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
2983

2984
2985
2986
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
2987

2988
2989
2990
2991
2992
2993
2994
            if ctx.fp8:
                if i < cp_size - 1:
                    dkv = dkv_fp8_[(rank + i + 1) % cp_size]
                else:
                    dkv = dkv_fp8[(rank + i + 1) % cp_size]
            else:
                dkv = p2p_comm_buffers[(i + 1) % 2][1]
2995
2996
            if ctx.use_fused_attention:
                dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
2997
2998
2999
3000
                if ctx.qkv_format in ["bshd", "sbhd"]:
                    # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
                    # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
                    dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
3001
            if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
3002
3003
3004
3005
3006
3007
                if ctx.qkv_format == "bshd":
                    # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                    dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
                elif ctx.qkv_format == "sbhd":
                    # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:])
3008
3009
3010
3011
            else:
                # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
                # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                dkv_ = dkv_.view(*dkv.shape)
3012

3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
            if ctx.fp8:
                if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
                    if ctx.qkv_format == "bshd":
                        dkv[:, :, 0, ...].copy_(dkv_)
                        dkv[:, :, 1, ...].fill_(0)
                    elif ctx.qkv_format == "sbhd":
                        dkv[:, 0, ...].copy_(dkv_)
                        dkv[:, 1, ...].fill_(0)
                else:
                    dkv.copy_(dkv_)
            elif causal:
3024
                if i == (cp_size - 1):
3025
                    if rank == 0:
3026
3027
3028
3029
3030
3031
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
3032
                        elif ctx.qkv_format == "thd":
3033
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
3034
3035
                    else:
                        dkv.add_(dkv_)
3036
3037
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
3038
3039
3040
3041
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
3042
                        elif ctx.qkv_format == "thd":
3043
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
3044
                    else:
3045
3046
3047
3048
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
3049
                        elif ctx.qkv_format == "thd":
3050
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
3051
3052
3053
3054
3055
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
3056
3057
3058
3059
3060
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

3061
3062
3063
3064
3065
3066
3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
        if ctx.fp8 and ctx.use_fused_attention:
            amax_cp_bwd = amax_per_step.amax(dim=1)
            ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0]
            ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1]
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
                # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
                dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
            dq, dkv = [
                cast_from_fp8(
                    x,
                    ctx.fp8_meta["scaling_bwd"],
                    META_DQKV_CP,
                    fp8_dtype_backward,
                    TE_DType[torch.float32],
                )
                for x in [dq_fp8, dkv_fp8]
            ]
            dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]

3081
        if causal:
3082
3083
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
3084
                dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
3085
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
3086
                dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
3087
3088
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
3089
                dq = dq.view(-1, *dq.shape[-3:])
3090
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
3091
3092
3093
3094
3095
3096
3097
3098
3099
                dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])

        if ctx.qkv_format == "thd":
            dkv_ = torch.empty(
                2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device
            )
            dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv)
            dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
            dkv = dkv_
3100

3101
3102
3103
3104
3105
        if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
            dq, dkv = [
                cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward)
                for x in [dq, dkv]
            ]
3106
3107
3108
3109
3110
3111
3112
3113
3114
3115
3116
3117
3118
3119
3120
3121
3122
3123
3124
        dk, dv = dkv[0], dkv[1]

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False)
            dq, dk, dv = flash_attn_a2a_communicate(
                [dq, dk, dv],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                False,
            )
            if ctx.qkv_format == "bshd":
                dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
            elif ctx.qkv_format == "sbhd":
                dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

        if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
3125
3126
3127
3128
3129
3130
3131
3132
3133
            dq, dk, dv = [
                Float8Tensor(
                    data=x,
                    fp8_meta=ctx.fp8_meta,
                    fp8_meta_forward=False,
                    fp8_meta_index=META_DQKV,
                    fp8_dtype=fp8_dtype_backward,
                    dtype=dout_dtype,
                )
3134
                for x in [dq, dk, dv]
3135
3136
            ]

3137
3138
3139
3140
        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)

3141
3142
3143
        return (
            None,
            dq,
3144
3145
            dk,
            dv,
3146
3147
3148
3149
3150
3151
3152
3153
3154
3155
3156
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
3157
            attn_dbias,
3158
3159
3160
3161
3162
            None,
            None,
            None,
            None,
            None,
3163
3164
            None,
            None,
3165
        )


def get_kv_seq_info_after_all_gather(
    local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal
):
    """Compute KV sequence index range and update window size after all-gather.

    After K/V are all-gathered across the CP group, the gathered KV sequence spans
    ``2 * cp_size`` chunks of ``max_seqlen_kv`` tokens each. This returns the slice of
    that gathered sequence that query chunk ``local_chunk_id`` needs to attend to, plus
    the attention window re-expressed relative to that slice.

    Args:
        local_chunk_id: index of the query chunk among the ``2 * cp_size`` chunks. The
            query chunk is taken to end where KV chunk ``local_chunk_id`` ends.
        cp_size: context-parallel world size.
        max_seqlen_q: number of query tokens in one chunk.
        max_seqlen_kv: number of KV tokens in one chunk.
        window_size: ``(left, right)`` window; ``-1`` on a side means unbounded on that
            side. ``None`` selects ``(-1, 0)`` when ``causal`` else ``(-1, -1)``.
        causal: only consulted when ``window_size`` is ``None``.

    Returns:
        ``((seq_start_idx, seq_end_idx), (window_size_left, window_size_right))``.
        The first pair is a half-open ``[start, end)`` range into the gathered KV
        sequence. The second pair is the window to use when the query chunk is aligned
        to the END of that range (bottom-right alignment).
    """
    local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv
    full_seq_end_idx = max_seqlen_kv * cp_size * 2

    if window_size is None:
        window_size = (-1, 0) if causal else (-1, -1)

    # Right edge: clamp to the end of the gathered sequence. Whatever part of the right
    # window falls past the new end is kept relative to the end-aligned queries.
    if window_size[1] == -1:
        seq_end_idx = full_seq_end_idx
        window_size_right = -1
    else:
        seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1])
        window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx

    # Left edge: the first query token sits at local_chunk_end_idx - max_seqlen_q.
    # Bottom-right alignment shifts queries by (seq_end_idx - local_chunk_end_idx), so
    # the left window grows by the same amount.
    if window_size[0] == -1:
        seq_start_idx = 0
        window_size_left = -1
    else:
        seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0])
        window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx

    return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right)


class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    """
    Attention implementation with context parallelism. KV all-gather between CP ranks is exposed.
    Refer section 3.3.2 of `The Llama 3 Herd of Models <https://arxiv.org/abs/2407.21783>`_.

    Each CP rank processes two query chunks, with chunk ids ``rank`` and
    ``2 * cp_size - rank - 1`` out of ``2 * cp_size``. K and V are all-gathered across
    ``cp_group`` and reordered with ``get_seq_chunk_ids_for_reordering``. Each local query
    chunk then attends to the slice of gathered KV chosen by
    ``get_kv_seq_info_after_all_gather``. The two per-chunk attention calls are issued
    on two CUDA streams (the current stream and ``cp_stream``).
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        cp_group,
        cp_stream,
    ):
        """Compute attention for both local query chunks against all-gathered K/V.

        ``qkv_format`` must be "bshd" or "sbhd" ("thd" is rejected). Padding masks and
        attention bias are rejected. ``max_seqlen_q``, ``max_seqlen_kv``, ``cu_seqlens_q``
        and ``cu_seqlens_q_padded`` are divided by ``2 * cp_size`` here to obtain
        per-chunk values.

        Returns the attention output. On the fused path it is reshaped back to
        ``qkv_format``. On the flash-attn path it is flattened to ``[-1, np, hn]``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        # Queries are aligned to the end of their KV slice (see
        # get_kv_seq_info_after_all_gather), so the fused kernel needs a bottom-right mask.
        if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
            attn_mask_type = attn_mask_type + "_bottom_right"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        # The flash-attn path below always passes `window_size`, which needs flash-attn >= 2.3.
        assert (
            use_fused_attention or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_5_7_plus:
            fa_optional_forward_kwargs["block_table"] = None

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        # Per-chunk metadata: the sequence is split into 2 * cp_size chunks.
        max_seqlen_q = max_seqlen_q // (2 * cp_size)
        max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
        cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
        cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)

        # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
        q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
        # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn]
        k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]]

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, cp_group)
        v_ag, _ = gather_along_first_dim(v, cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        # cp_stream must see the gathered/reordered K/V before it starts step 1.
        cp_stream.wait_stream(torch.cuda.current_stream())

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
        kv_seq_range_per_step = [None, None]
        window_size_per_step = [None, None]
        cu_seqlens_kv_per_step = [None, None]
        out_per_step = [None, None]
        softmax_lse_per_step = [None, None]
        rng_states = [None, None]
        out = torch.empty_like(q)

        # Iteration i launches attention for chunk i (if any) on stream i. It also copies
        # chunk i-1's result into `out` on the stream that produced it.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    kv_seq_range_per_step[i], window_size_per_step[i] = (
                        get_kv_seq_info_after_all_gather(
                            local_seq_chunk_ids[i],
                            cp_size,
                            max_seqlen_q,
                            max_seqlen_kv,
                            window_size,
                            causal,
                        )
                    )
                    seq_start_idx, seq_end_idx = kv_seq_range_per_step[i]
                    max_seqlen_kv_ = seq_end_idx - seq_start_idx
                    cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens(
                        k.shape[1], max_seqlen_kv_, k.device
                    )
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    if use_fused_attention:
                        out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
                            is_training,
                            max_seqlen_q,
                            max_seqlen_kv_,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            TE_DType[q.dtype],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=softmax_scale,
                            dropout=dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=attn_mask_type,
                            attn_bias_type=attn_bias_type,
                            attn_bias=attn_bias,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            window_size=window_size_per_step[i],
                        )
                    else:
                        # flash-attn varlen API takes packed [tokens, np, hn] inputs.
                        q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
                        _, _, _, _, out_per_step[i], softmax_lse_per_step[i], _, rng_states[i] = (
                            _flash_attn_forward(
                                q_,
                                k_,
                                v_,
                                cu_seqlens_q,
                                cu_seqlens_kv_per_step[i],
                                max_seqlen_q,
                                max_seqlen_kv_,
                                dropout_p,
                                softmax_scale,
                                causal=causal,
                                return_softmax=False,
                                window_size=window_size_per_step[i],
                                **fa_optional_forward_kwargs,
                            )
                        )

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if qkv_format == "bshd":
                        out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape))
                    elif qkv_format == "sbhd":
                        out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape))

        torch.cuda.current_stream().wait_stream(cp_stream)

        if use_fused_attention:
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
        else:
            out = out.view(-1, *out.shape[-2:])

        # Layout of saved tensors: backward() slices them as [:5], [5:7], [7:9],
        # [9:11], [11:13].
        ctx.save_for_backward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_q_padded,
            *cu_seqlens_kv_per_step,
            *out_per_step,
            *softmax_lse_per_step,
            *rng_states,
        )
        ctx.kv_seq_range_per_step = kv_seq_range_per_step
        ctx.window_size_per_step = window_size_per_step
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
        return out

    @staticmethod
    def backward(ctx, dout):
        """Backward pass for both local query chunks.

        Steps:
        - Re-gathers and reorders K/V.
        - Runs per-chunk attention backward on the two streams.
        - Accumulates each chunk's dK/dV into a gathered-length buffer, serializing the
          two streams' additions with an event.
        - Restores chunk order and reduce-scatters dK/dV over ``cp_group``.

        Returns one gradient per ``forward`` input; only ``q``, ``k`` and ``v`` get one.
        """
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)

        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5]
        cu_seqlens_kv_per_step = ctx.saved_tensors[5:7]
        out_per_step = ctx.saved_tensors[7:9]
        softmax_lse_per_step = ctx.saved_tensors[9:11]
        rng_states = ctx.saved_tensors[11:13]
        kv_seq_range_per_step = ctx.kv_seq_range_per_step
        window_size_per_step = ctx.window_size_per_step

        seq_dim = ctx.qkv_format.index("s")
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

        # Saved q is already split into [.., 2, s//2, ..]; match dout to it.
        dout = dout.view(q.shape)
        dq = torch.empty_like(q)
        # dK/dV accumulate over the whole gathered sequence ([cp*s, b, np, hn]).
        dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device)
        dv = torch.zeros_like(dk)
        dq_per_step = [None, None]
        dk_per_step = [None, None]
        dv_per_step = [None, None]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream]
        # synchronize dkv update across steps
        dkv_update_done = torch.cuda.Event()

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
        v_ag, _ = gather_along_first_dim(v, ctx.cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        ctx.cp_stream.wait_stream(torch.cuda.current_stream())

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]

        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    seq_start_idx, seq_end_idx = kv_seq_range_per_step[i]
                    max_seqlen_kv = seq_end_idx - seq_start_idx
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    out_ = out_per_step[i]
                    dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
                    if ctx.use_fused_attention:
                        aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
                            ctx.max_seqlen_q,
                            max_seqlen_kv,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[dout.dtype],
                            aux_ctx_tensors,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=ctx.attn_mask_type,
                            attn_bias_type=ctx.attn_bias_type,
                            window_size=window_size_per_step[i],
                            deterministic=ctx.deterministic,
                        )
                    else:
                        # NOTE(review): this branch reshapes as if qkv_format == "bshd"
                        # (batch taken from k_.shape[0], dq viewed as dq[:, i]); confirm
                        # the flash-attn caller never passes "sbhd" here.
                        batch_size = k_.shape[0]
                        q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
                        _flash_attn_backward(
                            dout_,
                            q_,
                            k_,
                            v_,
                            out_,
                            softmax_lse_per_step[i],
                            dq_per_step[i],
                            dk_per_step[i],
                            dv_per_step[i],
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            ctx.max_seqlen_q,
                            max_seqlen_kv,
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            "causal" in ctx.attn_mask_type,
                            window_size=window_size_per_step[i],
                            rng_state=rng_states[i],
                            **fa_optional_backward_kwargs,
                        )
                        # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                        dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape)
                        # [b*s_range, np, hn] -> [b, s_range, np, hn]
                        dk_per_step[i], dv_per_step[i] = [
                            x.view(batch_size, -1, *x.shape[-2:])
                            for x in [dk_per_step[i], dv_per_step[i]]
                        ]

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if ctx.qkv_format == "bshd":
                        dq[:, i - 1].copy_(dq_per_step[i - 1])
                    elif ctx.qkv_format == "sbhd":
                        dq[i - 1].copy_(dq_per_step[i - 1])
                    # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn]
                    dk_per_step[i - 1], dv_per_step[i - 1] = [
                        x.movedim(seq_dim, 0).contiguous()
                        for x in [dk_per_step[i - 1], dv_per_step[i - 1]]
                    ]
                    # The two steps may add into overlapping ranges of dk/dv from different
                    # streams, so wait until the previous step's dkv update is done.
                    if i > 1:
                        flash_attn_streams[i - 1].wait_event(dkv_update_done)
                    seq_start_idx, seq_end_idx = kv_seq_range_per_step[i - 1]
                    dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1])
                    dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1])
                    if i < len(local_seq_chunk_ids):
                        flash_attn_streams[i - 1].record_event(dkv_update_done)

        torch.cuda.current_stream().wait_stream(ctx.cp_stream)

        # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
        dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
        # Undo the forward reordering so chunks line up with the gathered layout again.
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False)
        dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
        dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        dk = dk.view(-1, *dk.shape[-3:])
        dv = dv.view(-1, *dv.shape[-3:])
        dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
        dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)

        # Merge the [2, s//2] split back into a single sequence dimension.
        dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
        dk = dk.movedim(0, seq_dim).contiguous()
        dv = dv.movedim(0, seq_dim).contiguous()

        # One entry per forward() input (ctx excluded); only q, k, v are differentiable.
        return (
            None,
            dq,
            dk,
            dv,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
    """
    Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO.
    Refer the paper `DeepSpeed Ulysses <https://arxiv.org/abs/2309.14509>`_.

    Q, K and V are exchanged with an all-to-all (A2A) so that each rank holds
    ``cp_size`` times the local sequence for ``num_heads // cp_size`` heads
    (``[b, s, np, hn] -> [b, cp*s, np//cp, hn]``). Attention is then computed locally,
    and the output is exchanged back with a second A2A. The backward pass repeats the
    A2A on O/dO and sends dQ/dK/dV back the same way.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        fp8,
        fp8_meta,
        cp_group,
        cp_stream,
    ):
        """Run A2A on Q/K/V, compute local attention, then A2A the output back.

        Returns the attention output in ``qkv_format`` layout. When ``fp8`` is set and
        ``fp8_meta["recipe"].fp8_mha`` is True the result is a ``Float8Tensor``;
        otherwise it is a regular (high-precision) tensor.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        # NVTE_FP8_DPA_BWD selects whether the backward pass also runs in FP8. It is read
        # once here, and only when FP8 is enabled (the only case where it is consulted).
        fp8_dpa_bwd = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) if fp8 else 0

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            window_size == (-1, 0)
            or window_size == (-1, -1)
            or use_fused_attention
            or _flash_attn_2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

        # Optional FlashAttention arguments, depending on the installed version.
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_3_plus:
            fa_optional_forward_kwargs["window_size"] = window_size
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_5_7_plus:
            fa_optional_forward_kwargs["block_table"] = None

        assert (
            q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
        ), "The number of attention heads needs to be divisible by CP size!"

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        batch_dim = qkv_format.index("b")
        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        if fp8:
            if use_fused_attention:
                fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                fused_attn_qkv_dtype = fp8_dtype_forward
                fused_attn_backend = FusedAttnBackend["FP8"]
                if fp8_meta["recipe"].fp8_mha:
                    assert (
                        isinstance(q, Float8Tensor)
                        and isinstance(k, Float8Tensor)
                        and isinstance(v, Float8Tensor)
                    ), "q/k/v must be Float8Tensors for FP8 MHA!"
                    fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                elif fp8_dpa_bwd:
                    # FP8 backward requested: cast before the A2A, so the exchange
                    # operates on FP8 data.
                    q_f16, k_f16, v_f16 = q, k, v
                    q, k, v = [
                        cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                        for x in [q_f16, k_f16, v_f16]
                    ]
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
                fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
                fp8_meta_kwargs["d_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_s_offset"] = META_S
                fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
                fp8_meta_kwargs["q_scale_o_offset"] = META_O
                fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history
                fp8_meta_kwargs["amax_s_offset"] = META_S
                fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history
                fp8_meta_kwargs["amax_o_offset"] = META_O
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        elif use_fused_attention:
            fp8_meta_kwargs = {}
            fused_attn_qkv_dtype = TE_DType[q.dtype]
            fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        # A2A #1: scatter heads, gather sequence.
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True)
        q, k, v = flash_attn_a2a_communicate(
            [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
        )

        if fp8 and not fp8_meta["recipe"].fp8_mha and not fp8_dpa_bwd:
            # FP8 forward only: the A2A above ran on high-precision data; cast now.
            q_f16, k_f16, v_f16 = q, k, v
            q, k, v = [
                cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
                for x in [q_f16, k_f16, v_f16]
            ]

        batch_size = q.shape[batch_dim]
        if use_fused_attention:
            out, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                fused_attn_qkv_dtype,
                fused_attn_backend,
                attn_scale=softmax_scale,
                dropout=dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                attn_bias=attn_bias,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                window_size=window_size,
                **fp8_meta_kwargs,
            )
        else:
            # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn]
            q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]]
            (
                _,
                _,
                _,
                _,
                out,
                softmax_lse,
                _,
                rng_state,
            ) = _flash_attn_forward(
                q,
                k,
                v,
                cu_seqlens_q,
                cu_seqlens_kv,
                max_seqlen_q,
                max_seqlen_kv,
                dropout_p,
                softmax_scale,
                causal=causal,
                return_softmax=False,
                **fa_optional_forward_kwargs,
            )
            aux_ctx_tensors = [softmax_lse, rng_state]
            # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn]
            out = out.view(batch_size, -1, *out.shape[-2:])

        # A2A #2: gather heads, scatter sequence.
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False)
        out = flash_attn_a2a_communicate(
            out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
        )

        if use_fused_attention:
            if qkv_format == "bshd":
                # [b*s, np, hn] -> [b, s, np, hn]
                out = out.view(batch_size, -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                # [s*b, np, hn] -> [s, b, np, hn]
                out = out.view(-1, batch_size, *out.shape[-2:])

        if fp8:
            if fp8_meta["recipe"].fp8_mha:
                out_fp8 = Float8Tensor(
                    data=out,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q_fp8.dtype,
                )
                out = out_fp8._data
                out_ret = out_fp8
            else:
                out_f16 = cast_from_fp8(
                    out,
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    TE_DType[q_f16.dtype],
                )
                out_ret = out_f16
        else:
            out_ret = out

        # Choose what to keep for backward: raw FP8 data when the backward runs in FP8,
        # otherwise Float8Tensors (FP8 MHA) or the high-precision copies.
        if fp8:
            if fp8_dpa_bwd:
                q_save, k_save, v_save, out_save = q, k, v, out
            elif fp8_meta["recipe"].fp8_mha:
                q_fp8, k_fp8, v_fp8 = [
                    Float8Tensor(
                        data=x,
                        fp8_meta=fp8_meta,
                        fp8_meta_forward=True,
                        fp8_meta_index=META_QKV,
                        fp8_dtype=fp8_dtype_forward,
                        dtype=out_fp8.dtype,
                    )
                    for x in [q, k, v]
                ]
                q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8
            else:
                q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16
        else:
            q_save, k_save, v_save, out_save = q, k, v, out

        if fp8 and fp8_dpa_bwd:
            # Snapshot the forward scales; fp8_meta may be updated before backward runs.
            fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
            fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
        else:
            fp8_fwd_scales, fp8_fwd_scale_invs = None, None

        ctx.save_for_backward(
            q_save,
            k_save,
            v_save,
            out_save,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            fp8_fwd_scales,
            fp8_fwd_scale_invs,
            *aux_ctx_tensors,
        )
        ctx.batch_size = batch_size
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_mask_type = attn_mask_type
        ctx.attn_bias_type = attn_bias_type
        ctx.deterministic = deterministic
        ctx.window_size = window_size
        ctx.use_fused_attention = use_fused_attention
        ctx.fp8 = fp8 and fp8_dpa_bwd
        ctx.fp8_meta = fp8_meta
        return out_ret

    @staticmethod
    def backward(ctx, dout):
        """Reverse of the forward pass: A2A O/dO, compute local dQ/dK/dV, A2A them back.

        Returns exactly one gradient slot per forward input (23, excluding ``ctx``).
        Only ``q``, ``k`` and ``v`` receive gradients.
        """
        cp_size = get_distributed_world_size(ctx.cp_group)

        # Saved layout: q, k, v, out, four cu_seqlens tensors, FP8 forward
        # scales / scale-invs, then the auxiliary attention tensors.
        q, k, v, out = ctx.saved_tensors[:4]
        cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[
            4:8
        ]
        fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10]
        aux_ctx_tensors = ctx.saved_tensors[10:]

        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
        causal = "causal" in ctx.attn_mask_type
        seq_dim = ctx.qkv_format.index("s")

        if ctx.fp8:
            if ctx.use_fused_attention:
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
                fused_attn_qkv_dtype = fp8_dtype_forward
                fused_attn_dqkv_dtype = fp8_dtype_backward
                fused_attn_backend = FusedAttnBackend["FP8"]
                if ctx.fp8_meta["recipe"].fp8_mha:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
                    dout_fp8 = dout
                    dout = dout_fp8._data
                else:
                    dout_f16 = dout
                    dout = cast_to_fp8(
                        dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                    )
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
                fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
                fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
                fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
                fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
                fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
                fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
                fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV]
                fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP]
                fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][
                    META_DQKV
                ]
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
                # FP8 MHA with a high-precision backward: dequantize the saved tensors
                # and the incoming gradient.
                assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]]
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_qkv_dtype = TE_DType[q.dtype]
                fused_attn_dqkv_dtype = TE_DType[dout.dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if not ctx.use_fused_attention:
            out = out.view(ctx.batch_size, -1, *out.shape[-2:])
        dout = dout.view(*out.shape)

        # Redo A2A #1 on O and dO so they match the head-scattered Q/K/V layout.
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True)
        out, dout = flash_attn_a2a_communicate(
            [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
        )

        # Optional FlashAttention arguments, depending on the installed version.
        fa_optional_backward_kwargs = {}
        if _flash_attn_2_3_plus:
            fa_optional_backward_kwargs["window_size"] = ctx.window_size
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

        if ctx.use_fused_attention:
            dq, dk, dv, _ = fused_attn_bwd(
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                out,
                dout,
                fused_attn_qkv_dtype,
                fused_attn_dqkv_dtype,
                aux_ctx_tensors,
                fused_attn_backend,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                attn_scale=ctx.softmax_scale,
                dropout=ctx.dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=ctx.attn_mask_type,
                attn_bias_type=ctx.attn_bias_type,
                window_size=ctx.window_size,
                deterministic=ctx.deterministic,
                **fp8_meta_kwargs,
            )
        else:
            softmax_lse, rng_state = aux_ctx_tensors
            out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]]
            dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
            _flash_attn_backward(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.softmax_scale,
                causal,
                rng_state=rng_state,
                **fa_optional_backward_kwargs,
            )
            dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]

        # Reverse A2A: gather heads, scatter sequence for the gradients.
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False)
        dq, dk, dv = flash_attn_a2a_communicate(
            [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
        )

        if ctx.qkv_format == "bshd":
            dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
        elif ctx.qkv_format == "sbhd":
            dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

        if ctx.fp8:
            if ctx.fp8_meta["recipe"].fp8_mha:
                dq, dk, dv = [
                    Float8Tensor(
                        data=x,
                        fp8_meta=ctx.fp8_meta,
                        fp8_meta_forward=False,
                        fp8_meta_index=META_DQKV,
                        fp8_dtype=fp8_dtype_backward,
                        dtype=dout_fp8.dtype,
                    )
                    for x in [dq, dk, dv]
                ]
            else:
                dq, dk, dv = [
                    cast_from_fp8(
                        x,
                        ctx.fp8_meta["scaling_bwd"],
                        META_DQKV,
                        fp8_dtype_backward,
                        TE_DType[dout_f16.dtype],
                    )
                    for x in [dq, dk, dv]
                ]

        # One slot per forward input: is_training, (q, k, v), then 19 non-tensor or
        # non-differentiable inputs (cu_seqlens_q ... cp_stream).
        return (None, dq, dk, dv) + (None,) * 19


4055
def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_kv,
    max_seqlen_q,
    max_seqlen_kv,
    cu_seqlens_q_padded,
    cu_seqlens_kv_padded,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    cp_comm_type,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
    window_size=None,
    fp8=False,
    fp8_meta=None,
) -> torch.Tensor:
    """
    Attention implementation with context parallelism.

    Validates the configuration, then dispatches on ``cp_comm_type``:

    * ``"p2p"`` / ``"a2a+p2p"`` -> ``AttnFuncWithCPAndKVP2P``
    * ``"all_gather"`` -> ``AttnFuncWithCPAndKVAllGather``
    * ``"a2a"`` -> ``AttnFuncWithCPAndQKVOA2A``

    For ``"a2a+p2p"``, ``cp_group`` must be a list of exactly two process groups. If
    the first level has world size 1 the call degrades to ``"p2p"`` over the second
    group. If the second level has world size 1 it degrades to ``"a2a"`` over the
    first group. For every other communication type ``cp_group`` must be a single
    process group.

    Sliding window attention (``window_size`` other than ``None``, ``(-1, 0)`` or
    ``(-1, -1)``) is only accepted with ``"a2a"``, or with ``"all_gather"`` when
    FlashAttention is used.

    Raises
    ------
    ValueError
        If ``cp_comm_type`` is not one of the supported values.
    """

    if cp_comm_type == "a2a+p2p":
        assert isinstance(
            cp_group, list
        ), "Hierarchical CP implementation needs multi-level CP groups!"
        assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
        # Collapse a trivial level so the single-level implementation is used instead.
        if get_distributed_world_size(cp_group[0]) == 1:
            cp_group = cp_group[1]
            cp_comm_type = "p2p"
        elif get_distributed_world_size(cp_group[1]) == 1:
            cp_group = cp_group[0]
            cp_comm_type = "a2a"
    else:
        assert isinstance(
            cp_group, dist_group_type
        ), f"Unsupported process group for CP communication type {cp_comm_type}!"

    assert qkv_format in [
        "bshd",
        "sbhd",
        "thd",
    ], f"QKV format of {qkv_format} is not supported with context parallelism!"
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "FlashAttention does not support sbhd format!"
    assert (
        qkv_format != "thd"
        or not use_fused_attention
        or attn_mask_type in ["padding", "padding_causal"]
    ), (
        f"Context parallelism is not supported for {attn_mask_type} mask type and "
        f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!"
    )
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        """Attention bias is only supported with FusedAttention and "causal" """
        """or "no_mask" mask types!"""
    )
    assert (
        cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
    ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"

    sliding_window_attn = (
        window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
    )
    assert (
        not sliding_window_attn
        or cp_comm_type == "a2a"
        or (cp_comm_type == "all_gather" and not use_fused_attention)
    ), "The context parallel running configs cannot support sliding window attention!"

    # Positional arguments shared by every CP implementation; each branch appends
    # (or removes) the implementation-specific ones.
    args = [
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    ]

    if cp_comm_type in ["p2p", "a2a+p2p"]:
        args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream]
        out = AttnFuncWithCPAndKVP2P.apply(*args)
    elif cp_comm_type == "all_gather":
        # Drop cu_seqlens_kv (index 5), then cu_seqlens_kv_padded (index 8 once the
        # first element has been removed).
        args.pop(5)
        args.pop(8)
        args += [window_size, cp_group, cp_stream]
        out = AttnFuncWithCPAndKVAllGather.apply(*args)
    elif cp_comm_type == "a2a":
        args += [window_size, fp8, fp8_meta, cp_group, cp_stream]
        out = AttnFuncWithCPAndQKVOA2A.apply(*args)
    else:
        raise ValueError(f"Unsupported communication type: {cp_comm_type}!")

    return out


4173
4174
4175
4176
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
            Values below 1.0 shrink the effective dimension to ``int(dim * rotary_percent)``.
        seq_len_interpolation_factor: int, default = None
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int, default = None
            pre-trained max_position_embeddings before position interpolation
        rotary_base: float, default = 10000.0
            Base of the inverse-frequency progression:
            ``inv_freq[i] = 1 / rotary_base ** (2 * i / dim)``.
        """
        super().__init__()

        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
        # inv_freq is created on the current CUDA device.
        inv_freq = 1.0 / (
            self.rotary_base
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            Tensor of shape ``[max_seq_len, 1, 1, 2 * len(inv_freq)]``, with the
            per-position frequencies repeated twice along the last dimension.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        # Position interpolation only applies when both settings are provided.
        if (
            self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None
        ):
            if (
                max_seq_len
                > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
            ):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        # first part even vector components, second part odd vector components,
        #  2 * dim in dimension size
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))

4251
4252
4253
4254
4255
4256
4257
4258
4259
4260
4261
4262
4263
4264
4265
4266
4267
4268

class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Apply the fused RoPE kernel to ``t``.

        Raises
        ------
        ValueError
            If ``tensor_format`` is not one of ``sbhd``, ``bshd`` or ``thd``.
        """
        # freqs are always handed to the kernels as float32.
        if freqs.dtype != torch.float32:
            freqs = freqs.float()

        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            # Run the sbhd kernel on a transposed view instead of materializing a copy.
            output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Propagate ``grad_output`` through the fused RoPE kernel.

        Returns one gradient per forward input (``t``, ``freqs``, ``tensor_format``,
        ``cu_seqlens``); only ``t`` receives a gradient.
        """
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        return grad_input, None, None, None


4301
4302
4303
4304
4305
4306
4307
4308
4309
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


4310
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.

    Only the first `d2` channels of `t` are rotated; the rest pass through unchanged.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    is_batch_first = tensor_format == "bshd"
    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1 if is_batch_first else 0]

    # Only the leading cur_seq_len positions of the table are used.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if is_batch_first:
        # [seq, 1, 1, dim] -> [1, seq, 1, dim]
        freqs = freqs.transpose(0, 1)

    # Take cos/sin before casting to t's dtype to keep precision.
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    t_rot = t[..., :rot_dim]
    t_pass = t[..., rot_dim:]
    # Cosine term plus sign-flipped, half-rotated sine term.
    t_rot = (t_rot * cos_) + (_rotate_half(t_rot) * sin_)
    return torch.cat((t_rot, t_pass), dim=-1)


cyanguwa's avatar
cyanguwa committed
4373
class _SplitAlongDim(torch.autograd.Function):
    """Split a tensor along one dimension, with a copy-free backward when possible.

    The forward is ``torch.split``; ``Float8Tensor`` inputs are split on their raw
    ``_data`` and each chunk is re-wrapped with ``Float8Tensor.make_like``.

    In the backward, if the incoming gradients are views of one storage with
    identical strides, matching per-chunk shapes and the storage offsets the
    forward split would have produced, the full gradient is rebuilt as a single
    view over that storage (``Tensor.set_``) instead of being copied with
    ``torch.cat``.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
    ) -> Tuple[torch.Tensor, ...]:
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        if isinstance(mixed_x_layer, Float8Tensor):
            # Split the raw FP8 payload and re-wrap each chunk so it carries the
            # same FP8 metadata as the input.
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)

    @staticmethod
    def _try_noop_concat(grad_outputs, data_of, split_dim, split_sizes):
        """Return a view over the shared storage of ``grad_outputs``, or ``None``.

        ``data_of`` maps a gradient to the tensor that owns its storage: the
        identity for plain tensors, ``t._data`` for ``Float8Tensor``. ``None`` is
        returned whenever the gradients are not laid out exactly as adjacent
        chunks of one buffer, so the caller must fall back to ``torch.cat``.
        """
        first = grad_outputs[0]
        strides = first.stride()
        data_ptr = data_of(first).untyped_storage().data_ptr()
        shape = list(first.shape)
        # Number of elements per step along split_dim, assuming the trailing
        # dims are densely packed (mismatches simply fail the check below).
        inner_numel = np.prod(shape[split_dim + 1 :])
        for i, tensor in enumerate(grad_outputs):
            # Copy instead of aliasing `shape`, so the reference shape is never mutated.
            expected_shape = list(shape)
            expected_shape[split_dim] = split_sizes[i]
            expected_offset = sum(split_sizes[:i]) * inner_numel
            if (
                tensor.stride() != strides
                or list(tensor.shape) != expected_shape
                or data_of(tensor).untyped_storage().data_ptr() != data_ptr
                or tensor.storage_offset() != expected_offset
            ):
                return None

        first_data = data_of(first)
        ret = torch.Tensor().to(device=first.device, dtype=first_data.dtype)
        new_shape = list(shape)
        new_shape[split_dim] = sum(split_sizes)
        ret.set_(
            first_data.untyped_storage(),
            first_data.storage_offset(),
            new_shape,
            strides,
        )
        return ret

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Concatenate the chunk gradients along ``split_dim``.

        Returns the input gradient followed by ``None`` for the two non-tensor
        forward arguments (``split_dim`` and ``split_size_or_sections``).
        """
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        # Normalize a possibly negative split_dim.
        split_dim = (ctx.split_dim + dims) % dims

        if isinstance(grad_outputs[0], Float8Tensor):
            ret = _SplitAlongDim._try_noop_concat(
                grad_outputs, lambda t: t._data, split_dim, split_sizes
            )
            if ret is None:
                ret = torch.cat([x._data for x in grad_outputs], dim=split_dim)
            return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None

        ret = _SplitAlongDim._try_noop_concat(grad_outputs, lambda t: t, split_dim, split_sizes)
        if ret is None:
            ret = torch.cat(grad_outputs, dim=split_dim)
        return ret, None, None
4481
4482
4483
4484
4485
4486
4487
4488
4489


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2

    Non-fused attention built from batched matmuls, ``FusedScaleMaskSoftmax``
    and ``torch.nn.Dropout``.

    Parameters
    ----------
    softmax_scale: float
        Scale applied to ``Q K^T`` before the softmax.
    attention_type: str, default = "self"
        ``"self"`` expects a single padding mask in ``forward``; any other value
        expects a ``(q_mask, kv_mask)`` tuple.
    attention_dropout: float, default = 0.0
        Dropout probability applied to the attention probabilities.
    attention_dropout_ctx: Callable, default = nullcontext
        Context-manager factory entered around the dropout call.
    layer_number: int, optional
        Layer index; enables QK layer scaling (FP16 only) when the environment
        variable ``NVTE_APPLY_QK_LAYER_SCALING`` is ``1``.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_type: str = "self",
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_type = attention_type
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            4D tensors in the ``sbhd`` or ``bshd`` format implied by
            ``qkv_layout``. Keys/values may have fewer heads than queries (GQA);
            they are repeated to match.
        qkv_layout: str, default = "sbh3d"
            Must be in ``QKVLayouts`` and map to the ``sbhd`` or ``bshd`` format.
        cu_seqlens_q, cu_seqlens_kv: torch.Tensor, optional
            Unused here; accepted for interface parity with other backends.
        attn_mask_type: str, default = "causal"
            Mask type; ``padding*`` types require ``attention_mask``.
        attention_mask: torch.Tensor or tuple of two torch.Tensor, optional
            Boolean mask(s) with ``True`` marking masked positions:
            ``[b, 1, 1, sq]`` for self attention, or a tuple of
            ``[b, 1, 1, sq]`` and ``[b, 1, 1, skv]`` otherwise.
        core_attention_bias_type: str, default = "no_bias"
            One of ``no_bias``, ``pre_scale_bias``, ``post_scale_bias``, ``alibi``.
        core_attention_bias: torch.Tensor, optional
            Bias added to the attention scores (required for the ``*_bias`` types).
        alibi_slopes: torch.Tensor, optional
            Slopes forwarded to ``get_alibi`` for the ``alibi`` bias type.

        Returns
        -------
        torch.Tensor
            ``[sq, b, h * d]`` for ``sbhd`` or ``[b, sq, h * d]`` for ``bshd``.

        Raises
        ------
        ValueError
            If ``core_attention_bias_type`` is not one of the supported types.
        """
        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
        # Only the 4D formats are implemented below; fail clearly on e.g. 'thd'.
        assert qkv_format in (
            "sbhd",
            "bshd",
        ), f"UnfusedDotProductAttention does not support qkv_format = {qkv_format}!"
        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]

        # Allocate helper tensors on the inputs' device rather than the implicit
        # "current" CUDA device, which can differ on multi-GPU hosts.
        device = query_layer.device
        batch_size, max_seqlen_q, max_seqlen_kv = (
            query_layer.shape[1],
            query_layer.shape[0],
            key_layer.shape[0],
        )
        if "padding" in attn_mask_type:
            if self.attention_type == "self":
                assert attention_mask.shape == (
                    batch_size,
                    1,
                    1,
                    max_seqlen_q,
                ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!"
                # [b, 1, sq, 1] | [b, 1, 1, sq] -> [b, 1, sq, sq]: a (q, k) pair is
                # masked if either token is masked.
                attention_mask = torch.logical_or(
                    attention_mask.squeeze(1).unsqueeze(3), attention_mask
                )
            else:
                assert (
                    len(attention_mask) == 2
                    and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q)
                    and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv)
                ), (
                    "attention_mask should be a tuple of two tensors with shapes "
                    "[b, 1, 1, sq] and [b, 1, 1, skv]!"
                )
                # [b, 1, sq, 1] | [b, 1, 1, skv] -> [b, 1, sq, skv]
                attention_mask = torch.logical_or(
                    attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
                )
            # Count unmasked tokens per batch element from the first column/row of
            # the pairwise mask. NOTE(review): this assumes token 0 of each
            # sequence is never padding (right padding) — confirm with callers.
            valid = attention_mask.squeeze(1).logical_not()
            actual_seqlens_q = valid[:, :, 0].sum(dim=1)
            actual_seqlens_kv = valid[:, 0, :].sum(dim=1)
            # relative_pos[0, 0, i, j] = i - j
            relative_pos = torch.arange(max_seqlen_q, dtype=torch.int32, device=device).view(
                1, 1, max_seqlen_q, 1
            ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device=device).view(
                1, 1, 1, max_seqlen_kv
            )
            if attn_mask_type == "padding_causal":
                # Top-left aligned causal: mask keys after the query position.
                attention_mask = torch.logical_or(
                    torch.where(relative_pos.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0),
                    attention_mask,
                )
            if attn_mask_type == "padding_causal_bottom_right":
                # Bottom-right aligned causal: shift by the per-sample kv/q length gap.
                attention_mask = torch.logical_or(
                    torch.where(
                        relative_pos.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
                        + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
                        < 0,
                        1,
                        0,
                    ),
                    attention_mask,
                )

        seqlen = max_seqlen_q
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                int(query_layer.shape[2] / key_layer.shape[2]), dim=2
            )
            value_layer = value_layer.repeat_interleave(
                int(query_layer.shape[2] / value_layer.shape[2]), dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            device=device,
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            ).view(*output_size)

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            matmul_result = matmul_result.view(*output_size) + core_attention_bias
            matmul_result *= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = get_alibi(
                    output_size[1],
                    output_size[2],
                    output_size[3],
                    actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
                    actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
                    alibi_slopes=alibi_slopes,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
                dtype=query_layer.dtype
            )

        else:
            # Without this, the uninitialized torch.empty buffer would be softmaxed.
            raise ValueError(
                f"Unsupported core_attention_bias_type = {core_attention_bias_type}!"
            )

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            matmul_result, attention_mask, attn_mask_type, softmax_scale
        )

        # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
        # the columns (pad tokens from k) are already zeroed out during softmax
        if "padding" in attn_mask_type:
            attention_probs = attention_probs.masked_fill(attention_mask, 0)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)

        if qkv_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """Convert interleaved (s, b, ...) QKV into separate contiguous q, k, v in (b, s, ...).

    The layout change itself is performed by the ``tex.fa_prepare_fwd`` and
    ``tex.fa_prepare_bwd`` extension calls; this wrapper only splits the packed
    results they return into three tensors.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # The three inputs are non-contiguous views into one QKV allocation, so
        # passing only the query view gives the extension the whole region;
        # key_layer and value_layer are not read here.
        packed_qkv = tex.fa_prepare_fwd(query_layer)
        q_chunk, k_chunk, v_chunk = split_tensor_along_dim(packed_qkv, 0, 3)
        # Drop the leading split dimension from each chunk.
        return tuple(torch.squeeze(chunk, 0) for chunk in (q_chunk, k_chunk, v_chunk))

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # Re-pack the three gradients with the extension, then split the packed
        # result into three along its last dimension.
        packed_grad = tex.fa_prepare_bwd(dq, dk, dv)
        grad_q, grad_k, grad_v = split_tensor_along_dim(packed_grad, -1, 3)
        return grad_q, grad_k, grad_v

4766

4767
def get_qkv_layout(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qkv_format: str = "sbhd",
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of tokens in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    ----------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
        The query tensor, replaced by a contiguous copy if its original memory
        layout was not recognized.
    k: torch.Tensor
        The key tensor, with the same treatment as `q`.
    v: torch.Tensor
        The value tensor, with the same treatment as `q`.

    Raises
    ----------
    ValueError
        If no supported layout is found even after making `q`, `k` and `v` contiguous.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def _offsets_are_multiples(tensors, step):
        """True if the i-th tensor starts exactly ``i * step`` elements into its storage."""
        return all(i * step == x.storage_offset() for i, x in enumerate(tensors))

    def _detect_layout(q, k, v):
        """Return the layout string for q, k, v as given, or "not_supported"."""
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        # k and v may have different head sizes, so compare strides normalized by them.
        check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
            sv / v.shape[-1] for sv in v.stride()[:-1]
        )

        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        check_shapes_kv = k.shape[:-1] == v.shape[:-1]

        # Interleaved in the last dim (e.g. sbh3d): tensors start one head size apart.
        check_last_dim_offsets_qkv = _offsets_are_multiples([q, k, v], q.shape[-1])
        check_last_dim_offsets_kv = _offsets_are_multiples([k, v], k.shape[-1])
        # Interleaved before the head dim (e.g. sb3hd): tensors start one h*d block apart.
        check_last_two_dims_offsets_qkv = _offsets_are_multiples(
            [q, k, v], q.shape[-1] * q.shape[-2]
        )
        check_last_two_dims_offsets_kv = _offsets_are_multiples([k, v], k.shape[-1] * k.shape[-2])

        qkv_shared = check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
        kv_shared = check_ptrs_kv and check_strides_kv and check_shapes_kv

        if qkv_shared and check_last_two_dims_offsets_qkv and not check_last_dim_offsets_qkv:
            # sb3hd, bs3hd, t3hd
            return qkv_format[:-2] + "3" + qkv_format[-2:]
        if qkv_shared and check_last_dim_offsets_qkv:
            # sbh3d, bsh3d, th3d
            return qkv_format[:-1] + "3" + qkv_format[-1:]
        if kv_shared and check_last_two_dims_offsets_kv and not check_last_dim_offsets_kv:
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            return qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        if kv_shared and check_last_dim_offsets_kv:
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            return qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
        if check_strides_kv and check_shapes_kv:
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            return "_".join([qkv_format] * 3)
        return "not_supported"

    qkv_layout = _detect_layout(q, k, v)
    if qkv_layout == "not_supported":
        # force q,k,v to be contiguous and run the detection again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = _detect_layout(q, k, v)
    if qkv_layout == "not_supported":
        raise ValueError("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v
4885

4886

4887
def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Check if sliding window size is compliant with attention mask type.
    If not, set it to the appropriate size.

         attn_mask_type                              |   window_size
    -------------------------------------------------------------------------
    no_mask, padding, arbitrary                      | (-1, -1) or (>=0, >=0)
    causal, padding_causal                           | (-1,  0) or (>=0, 0)
    causal_bottom_right, padding_causal_bottom_right | (-1,  0) or (>=0, 0)

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type; any type containing "causal" uses the causal rules.
    window_size: tuple of two ints, optional
        Requested (left, right) window. Any two-element sequence is accepted and
        normalized to a tuple. `None` selects the default for the mask type.

    Returns
    ----------
    window_size: Tuple[int, int]
        The validated, possibly corrected, window size. Correctable values
        (e.g. (-1, -1) for a causal mask) are fixed and a warning is emitted.

    Raises
    ----------
    AssertionError
        If the window size cannot be reconciled with the mask type, or the mask
        type is unknown. Raised explicitly so validation survives `python -O`.
    """
    # Normalize list-like input so that e.g. [-1, -1] and (-1, -1) are treated
    # identically by the tuple comparisons below.
    orig_window_size = None if window_size is None else tuple(window_size)
    window_size = orig_window_size
    if "causal" in attn_mask_type:
        if orig_window_size is None:
            window_size = (-1, 0)
        elif orig_window_size == (-1, -1) or (
            orig_window_size[0] >= 0 and orig_window_size[1] != 0
        ):
            # Causal masking forbids attending to the right: clamp it to 0.
            window_size = (orig_window_size[0], 0)
            warnings.warn(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
        elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
            raise AssertionError(
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
    elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
        if orig_window_size is None:
            window_size = (-1, -1)
        elif orig_window_size == (-1, 0):
            # (-1, 0) is the causal default; widen it to the unrestricted window.
            window_size = (-1, -1)
            warnings.warn(
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
        elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
            raise AssertionError(
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
    else:
        raise AssertionError("Invalid attn_mask_type: " + attn_mask_type)
    return window_size
4930

4931

4932
class FlashAttention(torch.nn.Module):
    """Dot product attention, using the flash-attn package:
    https://github.com/Dao-AILab/flash-attention

    Dispatches to flash-attn v2 (``flash_attn_func`` / ``flash_attn_varlen_func``),
    flash-attn v3 (``*_v3`` variants, optionally with FP8 inputs), or the
    context-parallel path (``attn_forward_func_with_cp``).
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        softmax_scale: scale passed as ``softmax_scale`` to the flash-attn kernels.
        attention_dropout: dropout probability, applied only while ``self.training``.
        attention_dropout_ctx: zero-argument callable returning a context manager
            that wraps every kernel call (defaults to ``nullcontext``).
        attention_type: ``"self"`` or cross attention; with a padding mask,
            ``"self"`` reuses Q's sequence lengths/indices for K and V.
        layer_number: stored on the module; ``None`` is stored as 1.
        deterministic: forwarded to flash-attn (when the installed version supports it)
            and to the context-parallel path.
        """
        super().__init__()

        # The installed flash-attn must lie within the supported version window.
        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
        assert (
            _flash_attn_version <= _flash_attn_max_version
        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        self.logger = logging.getLogger("FlashAttention")
        self.logger.setLevel(_log_level)
        # Avoid attaching a duplicate handler when several instances are created.
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """flash-attn fprop.

        The steps are:

        1. Validate the inputs: FP16/BF16 or Float8Tensor, CUDA, and a qkv_layout
           in ``QKVLayouts``.
        2. Move ``sbhd`` inputs to batch-first.
        3. With a padding mask, pack away padded tokens.
        4. Run either the context-parallel path or flash-attn v2/v3.

        For ``sbhd``/``bshd``, ``max_seqlen_q``/``max_seqlen_kv`` are recomputed
        from the tensor shapes, multiplied by the context-parallel world size.
        For ``thd``, missing values are derived from ``cu_seqlens_*``.

        Returns
        -------
        A contiguous tensor shaped ``[s, b, h*d]`` for ``sbhd``, ``[b, s, h*d]`` for
        ``bshd`` (``s`` being the per-rank sequence length), or ``[t, h*d]`` for
        ``thd``. It is a Float8Tensor when flash-attn v3 runs with FP8 MHA.
        """
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        # Context-parallel size is the product of the world sizes of all CP groups.
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        # e.g. "sbh3d" -> "sbhd", "bshd_bs2hd" -> "bshd"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
            if qkv_format == "sbhd":
                # For now just 128, will make it more general in the future
                if (
                    query_layer.shape[-1] == 128
                    and query_layer.shape[0] * query_layer.shape[1] >= 512
                    and qkv_layout == "sbh3d"
                ):
                    query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                        query_layer, key_layer, value_layer
                    )
                else:
                    query_layer, key_layer, value_layer = [
                        x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
                    ]
            if context_parallel:
                query_layer, key_layer, value_layer = [
                    x.contiguous() for x in (query_layer, key_layer, value_layer)
                ]
        else:
            # FP8 inputs: transpose/compact the raw uint8 payload, then rewrap.
            if qkv_format == "sbhd":
                query_layer._data, key_layer._data, value_layer._data = [
                    x.transpose(0, 1)
                    for x in (query_layer._data, key_layer._data, value_layer._data)
                ]
                query_layer, key_layer, value_layer = [
                    Float8Tensor.make_like(x, data=x._data)
                    for x in (query_layer, key_layer, value_layer)
                ]
            if context_parallel:
                query_layer._data, key_layer._data, value_layer._data = [
                    x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
                ]

        batch_size = query_layer.shape[0]

        if qkv_format in ["sbhd", "bshd"]:
            # Inputs are batch-first here; CP shards the sequence dimension.
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size

            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    # Self attention: K/V share Q's token layout.
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer)
            else:
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            # Derive missing max lengths from the cumulative offsets (host sync).
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel and all(
            not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
        ):
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic,
                    window_size=window_size,
                )
        else:

            from .cpu_offload import CPUOffloadEnabled

            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            with self.attention_dropout_ctx():
                # Keyword arguments only understood by newer flash-attn 2.x releases.
                fa_optional_forward_kwargs = {}
                if _flash_attn_2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic

                # Fixed-length kernels for unpadded batch-first data; varlen kernels
                # (which take cu_seqlens/max_seqlen positionally) otherwise.
                fa_optional_forward_args_thd = []
                if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
                    func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
                else:
                    if _flash_attn_2_5_7_plus:
                        fa_optional_forward_kwargs["block_table"] = None
                    func = (
                        flash_attn_varlen_func
                        if not _use_flash_attn_3
                        else flash_attn_varlen_func_v3
                    )
                    fa_optional_forward_args_thd.append(cu_seqlens_q)
                    fa_optional_forward_args_thd.append(cu_seqlens_kv)
                    fa_optional_forward_args_thd.append(max_seqlen_q)
                    fa_optional_forward_args_thd.append(max_seqlen_kv)

                if _use_flash_attn_3:
                    fa_3_optional_forward_kwargs = {}
                    fa_3_optional_forward_kwargs["window_size"] = window_size
                    fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
                    if fp8:
                        fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                        activation_dtype = query_layer.dtype
                        torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)

                        def convert_to_torch_float8(tensor, dtype):
                            # Zero-copy view of a Float8Tensor's payload as a native
                            # torch float8 tensor (same storage, offset, shape, stride).
                            out = torch.Tensor().to(device=tensor.device, dtype=dtype)
                            out.set_(
                                tensor._data.untyped_storage(),
                                tensor._data.storage_offset(),
                                tensor._data.shape,
                                tensor._data.stride(),
                            )
                            return out

                        if fp8_meta["recipe"].fp8_mha:
                            assert all(
                                isinstance(x, Float8Tensor)
                                for x in [query_layer, key_layer, value_layer]
                            ), "q/k/v must be Float8Tensors for FP8 MHA."
                            fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
                        else:
                            query_layer, key_layer, value_layer = (
                                Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward)
                                for x in [query_layer, key_layer, value_layer]
                            )
                        fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv
                        fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv
                        fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv
                        query_layer, key_layer, value_layer = (
                            convert_to_torch_float8(x, torch_dtype)
                            for x in [query_layer, key_layer, value_layer]
                        )
                    try:
                        output, _ = func(
                            query_layer,
                            key_layer,
                            value_layer,
                            *fa_optional_forward_args_thd,
                            softmax_scale=self.softmax_scale,
                            causal="causal" in attn_mask_type,
                            **fa_3_optional_forward_kwargs,
                        )
                    except TypeError as e:
                        # The FA3 beta API is still moving; point users at an upgrade.
                        if _flash_attn_3_0_0_beta:
                            e.args = (
                                e.args[0]
                                + ". Please update your FlashAttention 3 (beta) installation as it "
                                + "may have added more supported arguments to its API. \n"
                                + _flash_attn_3_installation_steps,
                            ) + e.args[1:]
                        raise

                    if fp8 and fp8_meta["recipe"].fp8_mha:
                        output = cast_to_fp8(
                            output,
                            fp8_meta["scaling_fwd"],
                            META_O,
                            fp8_dtype_forward,
                        )
                        output = Float8Tensor(
                            data=output,
                            fp8_meta=fp8_meta,
                            fp8_meta_forward=True,
                            fp8_meta_index=META_O,
                            fp8_dtype=fp8_dtype_forward,
                            dtype=activation_dtype,
                        )
                else:
                    output = func(
                        query_layer,
                        key_layer,
                        value_layer,
                        *fa_optional_forward_args_thd,
                        self.attention_dropout if self.training else 0.0,
                        softmax_scale=self.softmax_scale,
                        causal="causal" in attn_mask_type,
                        **fa_optional_forward_kwargs,
                    )

        if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
            # Scatter packed tokens back into the padded [b * s, ...] layout.
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            # Decide by the produced type rather than config flags, so a plain
            # tensor is never treated as a Float8Tensor.
            if isinstance(output, Float8Tensor):
                output = Float8Tensor.make_like(
                    output,
                    data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
                    .transpose(0, 1)
                    .contiguous(),
                )
            else:
                output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
        elif qkv_format == "bshd":
            # (bs)hd -> bs(hd)
            output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
        elif qkv_format == "thd":
            # thd -> t(hd)
            output = output.reshape(output.shape[0], -1)

        return output.contiguous()
5279

5280

5281
def _combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a particular dimension.

    Builds a view (no copy) over the storage of ``tensors[0]``. The view has a new
    axis of size ``len(tensors)`` inserted at ``dim``. The stride of the new axis is
    ``tensors[0].stride()[dim - 1] // len(tensors)``.

    This is only meaningful when the inputs are equally shaped slices interleaved
    in one buffer, e.g. q/k/v split out of a packed tensor (assumption; TODO
    confirm at call sites).

    For Float8Tensor inputs, the view is built over ``_data`` and rewrapped with
    ``Float8Tensor.make_like(tensors[0], ...)``.

    Args:
        tensors: tensors to combine; only ``tensors[0]``'s metadata and storage
            are read.
        dim: position at which the new axis is inserted.

    Returns:
        The combined view, a Float8Tensor if ``tensors[0]`` is one.
    """
    num_tensors = len(tensors)
    new_shape = list(tensors[0].shape)
    new_shape.insert(dim, num_tensors)
    new_stride = list(tensors[0].stride())
    # Strides are integers: use floor division instead of a float round-trip.
    new_stride.insert(dim, new_stride[dim - 1] // num_tensors)

    is_fp8 = isinstance(tensors[0], Float8Tensor)
    # For FP8, operate on the raw payload tensor; otherwise on the tensor itself.
    base = tensors[0]._data if is_fp8 else tensors[0]
    combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=base.dtype)
    combined_tensor.set_(base.untyped_storage(), base.storage_offset(), new_shape, new_stride)
    if is_fp8:
        combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor)

    return combined_tensor
5308

5309

5310
5311
5312
5313
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input"""

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen,
        cu_seqlens,
        cu_seqlens_padded,
        qkv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Fused-attention forward on a packed QKV tensor, optionally in FP8.

        When ``fp8`` is set, ``qkv`` is used as-is if it is a Float8Tensor, or cast
        to FP8 otherwise. The FP8 backend is then forced. The output is returned
        as a Float8Tensor when ``fp8_meta["recipe"].fp8_mha`` is set; otherwise it
        is cast back to ``qkv_dtype``.

        Setting the environment variable ``NVTE_FP8_DPA_BWD=0`` makes backward run
        in high precision. In that case the high-precision q/k/v and output are
        saved instead of the FP8 ones.
        """
        is_input_fp8 = False
        # FP8 MHA output only exists on the FP8 path; otherwise backward would
        # wrongly require a Float8Tensor gradient for a plain-tensor output.
        is_output_fp8 = bool(fp8 and fp8_meta["recipe"].fp8_mha)
        # Whether backward runs in FP8 (decided on the FP8 path below).
        use_fp8_bwd = False
        if fp8:
            is_input_fp8 = isinstance(qkv, Float8Tensor)
            if is_input_fp8:
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.split("_"))
            assert (
                qkv_group == 1
            ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}."
            if is_input_fp8:
                qkv_fp8 = qkv._data
            else:
                # Flatten the trailing 3 dims to 2-D for the cast, then restore the shape.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv_fp8 = cast_to_fp8(
                    qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(qkv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
                META_QKV,  # d_scale_qkv_offset
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
                META_S,  # d_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_s
                META_S,  # q_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_o
                META_O,  # q_scale_o_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_s
                META_S,  # amax_s_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_o
                META_O,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if is_output_fp8:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=qkv.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret

            use_fp8_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
            if not use_fp8_bwd:
                # High-precision backward: make sure q/k/v and out are saved
                # in high precision rather than as FP8 payloads.
                if is_input_fp8:
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv = cast_from_fp8(
                        qkv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[qkv.dtype],
                    ).view(qkv.shape)
                if is_output_fp8:
                    out_save = cast_from_fp8(
                        out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                        fp8_meta["scaling_fwd"],
                        META_O,
                        fp8_dtype_forward,
                        qkv_dtype,
                    ).view(out_fp8.shape)
            # Snapshot scales so backward uses the values from this forward pass.
            fp8_tensors = (
                qkv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                None,  # d_scale_qkv
                0,  # d_scale_qkv_offset
                None,  # d_scale_s
                0,  # d_scale_s_offset
                None,  # q_scale_s
                0,  # q_scale_s_offset
                None,  # q_scale_o
                0,  # q_scale_o_offset
                None,  # amax_s
                0,  # amax_s_offset
                None,  # amax_o
                0,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            fp8_tensors = (None, None, None, None)
            out_save = out_ret

        ctx.fp8 = use_fp8_bwd
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
        # The FP8 backward only needs the FP8 copies held in fp8_tensors.
        qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
        ctx.save_for_backward(
            *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        # A high-precision backward always uses the arbitrary-seqlen F16 backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute dqkv (and dbias when a bias type other than no_bias/alibi is used).

        Returns one gradient per forward() input (20 of them). Only the positions
        of ``qkv`` (index 4) and ``attn_bias`` (index 6) can be non-None.
        """
        if ctx.is_output_fp8:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            qkv,
            out,
            cu_seqlens,
            cu_seqlens_padded,
            qkv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dqkv = torch.empty_like(qkv)

            def maybe_contiguous(x):
                # flash-attn requires unit stride on the last dimension.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [
                maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out)
            ]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dqkv[:, 0],
                dqkv[:, 1],
                dqkv[:, 2],
                cu_seqlens,
                cu_seqlens,
                ctx.max_seqlen,
                ctx.max_seqlen,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            dqkv = dqkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.is_output_fp8:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.is_input_fp8:
                        dqkv = Float8Tensor(
                            data=dqkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        dqkv_c_fp8 = dqkv_fp8.view(
                            -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                        )
                        dqkv = cast_from_fp8(
                            dqkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dqkv_fp8.shape)
                else:
                    # High-precision backward with an FP8 output gradient:
                    # dequantize d_out first.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(qkv.dtype)
                    dqkv, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # if no_bias or alibi, there is no bias gradient; else return (dqkv, dbias).
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        # One entry per forward() input: qkv is input #4, attn_bias is input #6,
        # followed by 13 non-differentiable inputs (20 in total).
        return (None, None, None, None, dqkv, None, dbias) + (None,) * 13
5692

5693

5694
5695
5696
5697
class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input.

    ``forward`` runs fused attention on a query tensor ``q`` and a packed
    key/value tensor ``kv``, either in FP8 (``fp8=True``) or in ``qkv_dtype``.
    ``backward`` returns gradients for ``q``, ``kv`` and, for bias types other
    than ``"no_bias"``/``"alibi"``, ``attn_bias``; every other ``forward``
    input receives ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        q,
        kv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run the fused-attention forward pass with packed KV.

        When ``fp8`` is set, the FP8 backend is forced; ``Float8Tensor`` inputs
        are used directly, other inputs are cast to FP8 (``qkv_layout`` must
        then have two groups, e.g. ``sbhd_sb2hd``). The output is a
        ``Float8Tensor`` when ``fp8_meta["recipe"].fp8_mha`` is set, otherwise
        it is cast back to ``qkv_dtype``. With ``NVTE_FP8_DPA_BWD=0`` the
        forward still runs in FP8 but high-precision q/kv/out are saved so the
        backward pass runs without FP8.

        Returns:
            The attention output.
        """
        is_input_fp8 = False
        is_output_fp8 = fp8_meta["recipe"].fp8_mha
        # Read once so the forward branch below and ctx.fp8 always agree.
        fp8_dpa_bwd = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # Nominal dtype of q (the same value this function passes as the
        # Float8Tensor output dtype); backward uses it when no FP8 output
        # gradient carries a dtype.
        nominal_dtype = q.dtype

        if fp8:
            assert isinstance(kv, q.__class__), "q and kv must have the same type."
            is_input_fp8 = isinstance(q, Float8Tensor)
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if is_input_fp8:
                # Inputs are already FP8: reuse their raw data and scale_inv.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, kv_fp8 = q._data, kv._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                assert qkv_group == 2, (
                    "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
                    f"but found {qkv_layout}."
                )
                q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view(
                    q.shape
                )
                # Flatten the last three dims into one row before casting.
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv_fp8 = cast_to_fp8(
                    kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(kv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                kv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
                META_QKV,  # d_scale_qkv_offset
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
                META_S,  # d_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_s
                META_S,  # q_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_o
                META_O,  # q_scale_o_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_s
                META_S,  # amax_s_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_o
                META_O,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if is_output_fp8:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret

            if not fp8_dpa_bwd:
                # Backward will run in high precision: dequantize whatever
                # is FP8 so the saved q/kv/out are regular tensors.
                if is_input_fp8:
                    q = cast_from_fp8(
                        q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[q.dtype],
                    ).view(q.shape)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv = cast_from_fp8(
                        kv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[kv.dtype],
                    ).view(kv.shape)
                if is_output_fp8:
                    out_save = cast_from_fp8(
                        out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                        fp8_meta["scaling_fwd"],
                        META_O,
                        fp8_dtype_forward,
                        qkv_dtype,
                    ).view(out_fp8.shape)
            # Snapshot the forward scales so backward sees the values used here.
            fp8_tensors = (
                q_fp8,
                kv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                kv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                None,  # d_scale_qkv
                0,  # d_scale_qkv_offset
                None,  # d_scale_s
                0,  # d_scale_s_offset
                None,  # q_scale_s
                0,  # q_scale_s_offset
                None,  # q_scale_o
                0,  # q_scale_o_offset
                None,  # amax_s
                0,  # amax_s_offset
                None,  # amax_o
                0,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            out_save = out_ret
            # Placeholders so backward can unpack a fixed number of tensors.
            fp8_tensors = (None, None, None, None, None)

        ctx.fp8 = fp8 and fp8_dpa_bwd
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
        ctx.nominal_dtype = nominal_dtype
        # FP8 backward only needs the FP8 copies; otherwise save q/kv/out.
        qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        # Non-FP8 backward always uses the F16 arbitrary-seqlen backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients of the packed-KV fused attention.

        ``d_out`` must be a ``Float8Tensor`` when the forward output was FP8.
        The work is done by ``flash_attn_cuda_bwd`` when ``ctx.use_FAv2_bwd``
        is set, otherwise by ``fused_attn_bwd_kvpacked`` (FP8 or high
        precision according to ``ctx.fp8``).

        Returns:
            One gradient per ``forward`` input (24 entries). Only ``q``, ``kv``
            and, for trainable bias types, ``attn_bias`` are non-None.
        """
        # Only bound to a tensor when the forward output was FP8.
        d_out_f8tensor = None
        if ctx.is_output_fp8:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            q,
            kv,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            kv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)

            def maybe_contiguous(x):
                # Make the last dim unit-stride only when it is not already.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dkv[:, 0],
                dkv[:, 1],
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Trim any padding on the last dim back to d_out's head size.
            dq = dq[..., : d_out.shape[-1]]
            dkv = dkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.is_output_fp8:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        kv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.is_input_fp8:
                        # d_out_f8tensor exists only for FP8 outputs; otherwise
                        # fall back to the nominal dtype recorded in forward.
                        grad_dtype = (
                            d_out_f8tensor.dtype if d_out_f8tensor is not None else ctx.nominal_dtype
                        )
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=grad_dtype,
                        )
                        dkv = Float8Tensor(
                            data=dkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=grad_dtype,
                        )
                    else:
                        dq = cast_from_fp8(
                            dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dq_fp8.shape)
                        dkv_c_fp8 = dkv_fp8.view(
                            -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                        )
                        dkv = cast_from_fp8(
                            dkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dkv_fp8.shape)
                else:
                    # A uint8 d_out is raw FP8 data; dequantize for the F16 kernel.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dkv, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        kv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # Only trainable bias types produce a bias gradient (rest[0]).
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        # Exactly one gradient slot per forward() input, in signature order.
        return (
            None,  # is_training
            None,  # max_seqlen_q
            None,  # max_seqlen_kv
            None,  # cu_seqlens_q
            None,  # cu_seqlens_kv
            None,  # cu_seqlens_q_padded
            None,  # cu_seqlens_kv_padded
            dq,  # q
            dkv,  # kv
            None,  # qkv_dtype
            dbias,  # attn_bias
            None,  # attn_scale
            None,  # dropout_p
            None,  # fast_zero_fill
            None,  # qkv_layout
            None,  # attn_bias_type
            None,  # attn_mask_type
            None,  # window_size
            None,  # rng_gen
            None,  # fused_attention_backend
            None,  # use_FAv2_bwd
            None,  # fp8
            None,  # fp8_meta
            None,  # deterministic
        )

6144

6145
6146
6147
6148
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        q,
        k,
        v,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run the fused-attention forward pass with separate Q, K, V tensors.

        When ``fp8`` is set, the FP8 backend is forced; ``Float8Tensor`` inputs
        are used directly, other inputs are cast to FP8 according to the number
        of ``_``-separated groups in ``qkv_layout`` (1: qkv packed, 2: kv
        packed, 3: separate). The output is a ``Float8Tensor`` when
        ``fp8_meta["recipe"].fp8_mha`` is set, otherwise it is cast back to
        ``qkv_dtype``. With ``NVTE_FP8_DPA_BWD=0`` high-precision q/k/v/out are
        saved so the backward pass runs without FP8.

        Returns:
            The attention output.
        """
        is_input_fp8 = False
        is_output_fp8 = fp8_meta["recipe"].fp8_mha
        # Read once so the forward branch below and ctx.fp8 always agree.
        fp8_dpa_bwd = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

        if fp8:
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            assert isinstance(k, q.__class__) and isinstance(
                v, q.__class__
            ), "q, k, and v must have the same type."
            is_input_fp8 = isinstance(q, Float8Tensor)
            if is_input_fp8:
                # Inputs are already FP8: reuse their raw data and scale_inv.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                if qkv_group == 1:
                    # Cast the combined qkv buffer in one pass, then split it.
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv_fp8 = cast_to_fp8(
                        qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(qkv.shape)
                    q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1])
                    q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
                elif qkv_group == 2:
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    dim = qkv_layout.split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_fp8 = cast_to_fp8(
                        kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(kv.shape)
                    k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1])
                    k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
                elif qkv_group == 3:
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    k_fp8 = cast_to_fp8(
                        k, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(k.shape)
                    v_fp8 = cast_to_fp8(
                        v, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(v.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
                META_QKV,  # d_scale_qkv_offset
                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
                META_S,  # d_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_s
                META_S,  # q_scale_s_offset
                fp8_meta["scaling_fwd"].scale,  # q_scale_o
                META_O,  # q_scale_o_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_s
                META_S,  # amax_s_offset
                fp8_meta["scaling_fwd"].amax_history,  # amax_o
                META_O,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if is_output_fp8:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret

            if not fp8_dpa_bwd:
                # Backward will run in high precision: dequantize whatever is
                # FP8 so the saved q/k/v/out are regular tensors.
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                if is_input_fp8:
                    qkv_group = len(qkv_layout.split("_"))
                    if qkv_group == 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                        qkv_no_fp8 = cast_from_fp8(
                            qkv_c._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[qkv.dtype],
                        ).view(qkv.shape)
                        q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1])
                        q, k, v = [x.squeeze(dim) for x in [q, k, v]]
                    elif qkv_group == 2:
                        q = cast_from_fp8(
                            q._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[q.dtype],
                        ).view(q.shape)
                        dim = qkv_layout.split("_")[1].find("2")
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                        kv_no_fp8 = cast_from_fp8(
                            kv_c._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[kv.dtype],
                        ).view(kv.shape)
                        k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1])
                        k, v = [x.squeeze(dim) for x in [k, v]]
                    elif qkv_group == 3:
                        q = cast_from_fp8(
                            q._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[q.dtype],
                        ).view(q.shape)
                        k = cast_from_fp8(
                            k._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[k.dtype],
                        ).view(k.shape)
                        v = cast_from_fp8(
                            v._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[v.dtype],
                        ).view(v.shape)
                if is_output_fp8:
                    out_save = cast_from_fp8(
                        out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                        fp8_meta["scaling_fwd"],
                        META_O,
                        fp8_dtype_forward,
                        qkv_dtype,
                    ).view(out_fp8.shape)

            # Snapshot the forward scales so backward sees the values used here.
            fp8_tensors = (
                q_fp8,
                k_fp8,
                v_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            out_ret, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                None,  # d_scale_qkv
                0,  # d_scale_qkv_offset
                None,  # d_scale_s
                0,  # d_scale_s_offset
                None,  # q_scale_s
                0,  # q_scale_s_offset
                None,  # q_scale_o
                0,  # q_scale_o_offset
                None,  # amax_s
                0,  # amax_s_offset
                None,  # amax_o
                0,  # amax_o_offset
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            out_save = out_ret
            # Placeholders so backward can unpack a fixed number of tensors.
            fp8_tensors = (None, None, None, None, None, None)

        ctx.fp8 = fp8 and fp8_dpa_bwd

        from .cpu_offload import CPUOffloadEnabled

        if CPUOffloadEnabled:
            # Build a fresh list: fp8_tensors is a tuple (tuples have no
            # .extend), and the tuple itself is still needed for saving below.
            if ctx.fp8:
                tensor_list = list(fp8_tensors)
            else:
                tensor_list = [q, k, v, out_save]

            tensor_list.extend(aux_ctx_tensors)

            # NOTE(review): the layout recorded for backward is overridden when
            # offloading is active; the reason is not visible here — confirm.
            qkv_layout = "sbhd_sbhd_sbhd"
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
        # FP8 backward only needs the FP8 copies; otherwise save q/k/v/out.
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        # Non-FP8 backward always uses the F16 arbitrary-seqlen backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients of the fused attention w.r.t. q, k, v (and the bias, if any).

        The path is chosen from state stored on ``ctx`` by ``forward``:

        * ``ctx.use_FAv2_bwd``: call ``flash_attn_cuda_bwd`` with the softmax
          statistics (``softmax_lse``) and RNG state saved by ``forward``.
        * ``ctx.fp8``: FP8 ``fused_attn_bwd``. dq/dk/dv are returned either as
          ``Float8Tensor`` (when the input was FP8) or dequantized to
          ``ctx.qkv_dtype``.
        * otherwise: ``fused_attn_bwd`` in ``ctx.qkv_dtype``.

        Parameters
        ----------
        ctx : torch.autograd.function.FunctionCtx
            Autograd context populated by ``forward``.
        d_out : torch.Tensor or Float8Tensor
            Gradient of the attention output. It must be a ``Float8Tensor`` when
            the forward output was FP8 (``ctx.is_output_fp8``).

        Returns
        -------
        tuple
            One gradient per ``forward`` input (0-based positions). dq, dk, dv
            are at positions 7-9. The bias gradient is at position 11 (``None``
            for "no_bias"/"alibi"). ``None`` is used everywhere else.
        """
        if ctx.is_output_fp8:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            # Nominal (high-precision) dtype carried by the FP8 gradient.
            grad_dtype = d_out_f8tensor.dtype
            d_out = d_out._data
        else:
            # Plain gradient tensor: take its dtype as the nominal dtype.
            # Capturing it here means the Float8Tensor wrapping below no longer
            # depends on d_out_f8tensor, which only exists for FP8 outputs.
            # NOTE(review): assumes the output gradient's dtype matches the
            # nominal dtype of q/k/v.
            grad_dtype = d_out.dtype
        d_out = d_out.contiguous()

        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        # Make the first auxiliary tensor contiguous before passing it to the
        # backward kernels (it is softmax_lse in the FAv2 path below).
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()

        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            # The flash-attn kernel requires a unit stride in the last dimension.
            d_out, q, k, v, out = [
                x if x.stride(-1) == 1 else x.contiguous() for x in (d_out, q, k, v, out)
            ]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Trim any head-dim padding back to the gradient's head dim.
            dq = dq[..., : d_out.shape[-1]]
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.is_output_fp8:
                        # d_out already holds FP8 data. Publish its scale_inv so that
                        # the d_scale_do argument below matches it.
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

                    if ctx.is_input_fp8:
                        # Keep the gradients in FP8, wrapped as Float8Tensors that
                        # share the backward META_DQKV scaling slot.
                        dq, dk, dv = [
                            Float8Tensor(
                                data=grad_fp8,
                                fp8_meta=ctx.fp8_meta,
                                fp8_meta_forward=False,
                                fp8_meta_index=META_DQKV,
                                fp8_dtype=fp8_dtype_backward,
                                dtype=grad_dtype,
                            )
                            for grad_fp8 in (dq_fp8, dk_fp8, dv_fp8)
                        ]
                    else:
                        # Dequantize to ctx.qkv_dtype. For packed layouts (one or two
                        # "_"-separated groups), the packed gradients are recombined,
                        # cast once, then split back along the packing dim.
                        qkv_group = len(ctx.qkv_layout.split("_"))
                        if qkv_group == 1:
                            dim = ctx.qkv_layout.find("3")
                            dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim)
                            dqkv_c_fp8 = dqkv_fp8.view(
                                -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                            )
                            dqkv = cast_from_fp8(
                                dqkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dqkv_fp8.shape)
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1])
                            dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
                        elif qkv_group == 2:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dq_fp8.shape)
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(
                                -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                            )
                            dkv = cast_from_fp8(
                                dkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dkv_fp8.shape)
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1])
                            dk, dv = [x.squeeze(dim) for x in [dk, dv]]
                        elif qkv_group == 3:
                            dq, dk, dv = [
                                cast_from_fp8(
                                    grad_fp8.view(-1, grad_fp8.shape[-2] * grad_fp8.shape[-1]),
                                    ctx.fp8_meta["scaling_bwd"],
                                    META_DQKV,
                                    fp8_dtype_backward,
                                    ctx.qkv_dtype,
                                ).view(grad_fp8.shape)
                                for grad_fp8 in (dq_fp8, dk_fp8, dv_fp8)
                            ]
                else:
                    if d_out.dtype == torch.uint8:
                        # FP8 output gradient with a non-FP8 backward: dequantize it.
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # With "no_bias"/"alibi" there is no bias gradient. Otherwise rest[0] is
        # dbias from fused_attn_bwd (the FAv2 path only runs with "no_bias").
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            d_bias = None
        else:
            d_bias = rest[0]
        # Positional gradients: q/k/v at 7-9, attn_bias at 11, None elsewhere.
        return (None,) * 7 + (dq, dk, dv, None, d_bias) + (None,) * 15
6739

6740

6741
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    FP8 attention (``fp8=True`` in ``forward``) additionally requires the
    FusedAttnBackend["FP8"] sub-backend.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # Opt-in (NVTE_FUSED_ATTN_USE_FAv2_BWD=1) and only on compute capability 9.0.
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading TransformerEngine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in TransformerEngine 2.0.
            """
            # Iterate over snapshots: removing from a list while iterating it
            # directly skips the element that follows each removed key.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """fused attention fprop

        Validate inputs and derive any missing sequence metadata. Then dispatch
        either to ``attn_forward_func_with_cp`` (context parallelism, when the
        product of the ``cp_group`` world sizes is > 1) or to ``FusedAttnFunc``.

        For "sbhd"/"bshd" formats, ``max_seqlen_q``/``max_seqlen_kv`` are taken
        from the tensor shapes (times the CP size). Missing ``cu_seqlens_*`` are
        derived from ``attention_mask`` for padding masks, or filled in as full
        sequences otherwise. For "thd", ``max_seqlen_*`` and ``cu_seqlens_*``
        must be supplied. A missing ``cu_seqlens_*_padded`` falls back to the
        matching ``cu_seqlens_*``.

        Returns the attention output with the last two dims (heads, head_dim)
        merged into one.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        # Total context-parallel size: product of the world sizes of all CP groups.
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        # The format is the alphabetic part of the query's layout, e.g. "sbh3d" -> "sbhd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format in ["sbhd", "bshd"]:
            if qkv_format == "sbhd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1],
                    query_layer.shape[0],
                    key_layer.shape[0],
                )
            if qkv_format == "bshd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0],
                    query_layer.shape[1],
                    key_layer.shape[1],
                )
            # Local tensors hold a 1/cp_size slice of the sequence.
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size
            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        # Cross attention: attention_mask is a (q_mask, kv_mask) pair.
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        if qkv_format == "thd":
            assert (
                max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
            ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"

        # Default each padded offset independently, so a caller-supplied value
        # for one side is never discarded because the other side is missing.
        if cu_seqlens_q_padded is None:
            cu_seqlens_q_padded = cu_seqlens_q
        if cu_seqlens_kv_padded is None:
            cu_seqlens_kv_padded = cu_seqlens_kv

        qkv_dtype = TE_DType[query_layer.dtype]

        # The FAv2 backward is only used without bias and with the
        # F16_arbitrary_seqlen sub-backend.
        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if fp8:
            assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                " is required for FP8 attention!"
            )
            assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
            assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
                "Amax reduction across TP+CP group is necessary when using context parallelism with"
                " FP8!"
            )

        if context_parallel:
            assert (
                fp8
                or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    deterministic=self.deterministic,
                    use_fused_attention=True,
                    window_size=window_size,
                    fp8=fp8,
                    fp8_meta=fp8_meta,
                )
        else:
            with self.attention_dropout_ctx():
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_dtype,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                    self.deterministic,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
7007
7008


7009
class DotProductAttention(TransformerEngineBaseModule):
7010
7011
7012
7013
7014
7015
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

7016
        Argument :attr:`attention_mask` in the `forward` call is only used when
7017
        :attr:`attn_mask_type` includes "`padding`" or "`arbitrary`".
7018
7019
7020

    .. warning::

7021
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
7022
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
7023
7024
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
7025
7026
7027
7028
7029

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
7030
7031
7032
    kv_channels : Union[int, Tuple[int, int]]
                the head size in key and value tensors. If the same, :attr:`kv_channels` can be
                an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
7033
7034
7035
7036
7037
7038
7039
7040
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
7041
7042
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
7043
    attn_mask_type: str, default = `causal`
7044
                   type of attention mask passed into softmax operation, options are "`no_mask`",
7045
7046
7047
7048
7049
7050
7051
7052
7053
7054
7055
7056
7057
7058
7059
7060
7061
7062
7063
7064
7065
7066
7067
7068
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
                   "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
                   "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
                   are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
                   `forward` method. It is useful for cases involving compilation/tracing, e.g.
                   ONNX export, and the forward arg is useful for dynamically changing mask types,
                   e.g. a different mask for training and inference.
                   1. For "`no_mask`", no attention mask is applied.
                   2. For "`causal`", "`causal_bottom_right`", or the causal mask in
                   "`padding_causal`" and "`padding_causal_bottom_right`", TransformerEngine
                   calculates and applies an upper triangular mask to the softmax input.
                   No user input is needed. Causal masks without the "`bottom_right`" appendix align
                   the diagonal line to the top left corner of the softmax matrix. With
                   "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
                   often used in inference/KV caching.
                   3. For "`padding`", or the padding mask in "`padding_causal`" and
                   "`padding_causal_bottom_right`", users need to provide the locations of padded
                   tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
                   [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
                   in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
                   cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
                   [batch_size, 1, 1, max_seqlen_kv]).
                   4. For "`arbitrary`", users need to provide a mask that is broadcastable to
                   the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
7069
7070
7071
7072
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
7073
7074
7075
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
7076
                be overridden by :attr:`window_size` in `forward` as well.
7077
7078
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
7079
7080
7081
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
7082
7083
7084
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
7085
               `h` the number of heads, `d` head size, and `t` the total number of tokens
7086
7087
7088
7089
7090
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
7091
               For that, please use `get_qkv_layout` to gain the layout information.
7092
7093
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
7094
                `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
7095
7096
7097
7098
7099
7100
7101
7102
7103

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
7104
    cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None`
7105
              context parallel process group.
7106
7107
7108
              ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
              List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
              and cp_group[1] are for a2a and p2p communications respectively.
7109
7110
7111
7112
7113
7114
7115
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
7116
    cp_comm_type : str, default = `p2p`
7117
                  inter-gpu communication type for context parallelism.
7118
                  Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
7119
7120
7121
7122
7123
7124
                  "p2p": Exchange KV chunks with P2P communications in ring topology.
                         P2P is async and can be overlapped with attention compute.
                  "all_gather": All-gather to get full sequence of KV before attention.
                                The all-gather is not async, and cannot be overlapped.
                  "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                         group, and gather to get full sequence of QKV.
7125
7126
7127
                  "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                  across each CP sub-group (e.g., via NVLink), then exchanging KV with
                  p2p between sub-groups (e.g., via IBLink).
7128
7129
7130
7131
7132
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: Union[int, Tuple[int, int]],
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
    ) -> None:
        """Build the module and instantiate all three attention backends.

        The constructor arguments are documented on the class docstring.
        Besides setting attributes, this constructor may write the
        ``NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`` and
        ``CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT`` environment variables
        (process-wide) when cuDNN is in [8.9.5, 9.0.0).

        Raises
        ------
        AssertionError
            If ``num_attention_heads`` is not divisible by the number of GQA
            groups, or if ``attention_type`` is not in ``AttnTypes``.
        """
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.logger.setLevel(_log_level)
        if not self.logger.hasHandlers():
            self.logger.addHandler(_stream_handler)
        self.qkv_format = qkv_format
        # Normalize the accepted spellings: "padding,causal" -> "padding_causal",
        # and "causal,padding" / "causal_padding" -> "padding_causal".
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)

        # Without an explicit group, trust the caller-provided tp_size; only a
        # trivial (size-1) setup registers the (None) group. With a group, the
        # world size is derived from it.
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)

        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
        self.cp_comm_type = cp_comm_type

        # kv_channels is either one head size shared by K and V, or a
        # (K head size, V head size) pair.
        self.hidden_size_per_attention_head_k = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[0]
        )
        self.hidden_size_per_attention_head_v = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[1]
        )

        # No GQA groups given means plain MHA: one KV group per attention head.
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"

        # Dropout runs under the RNG tracker's fork() context when a tracker
        # is supplied and sequence parallelism is off; otherwise no RNG context.
        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        # Default softmax scale is 1/sqrt(head size of K).
        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(
                kv_channels if isinstance(kv_channels, int) else kv_channels[0]
            )

        # Deterministic if the env var disallows non-deterministic algorithms
        # or PyTorch's global deterministic mode is on.
        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )
        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = get_cudnn_version()
        if (8, 9, 5) <= cudnn_version < (9, 0, 0):
            if self.deterministic:
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        self.flash_attention = FlashAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.fused_attention = FusedAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale,
            attention_type=attention_type,
            **attn_kwargs,
            layer_number=layer_number,
        )

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
            when loading older TransformerEngine checkpoints. Will phase out
            this hook in TransformerEngine 2.0.
            """
            # Filter in place: the hook must mutate the caller's list, and
            # calling list.remove() while iterating the same list would skip
            # the element following each removed key.
            incompatible_keys.missing_keys[:] = [
                key
                for key in incompatible_keys.missing_keys
                if "core_attention._extra_state" not in key
            ]

        self.register_load_state_dict_post_hook(remove_extra_states_check)

7276
7277
7278
7279
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Any,
        **forward_kwargs: Any,
    ) -> torch.Tensor:
        """Forward method with activation checkpointing.

        Calls this file's ``checkpoint`` helper on a thin wrapper around
        ``attention_func``. The helper receives ``distribute_saved_activations=False``,
        this module's RNG state tracker getter, and its tensor-parallel group.

        Parameters
        ----------
        attention_func : Callable
            The attention backend to run under checkpointing.
        *forward_args : Any
            Positional arguments forwarded to ``attention_func``. These are not
            all tensors; callers may pass scalars or strings too.
        **forward_kwargs : Any
            Keyword arguments forwarded through ``checkpoint``.

        Returns
        -------
        torch.Tensor
            Whatever ``checkpoint`` returns for the wrapped call.
        """

        def custom_forward(*input_args, **input_kwargs):
            return attention_func(*input_args, **input_kwargs)

        # Unpack positional args before the keyword options (flake8-bugbear
        # B026). The positional order seen by `checkpoint` is unchanged:
        # (custom_forward, *forward_args).
        hidden_states = checkpoint(
            custom_forward,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

        return hidden_states

7298
7299
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Configure context parallelism (CP) for this module; call it before
        running the forward pass. The values are stored as-is on the module
        (nothing is validated here).

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  process group(s) used for context parallelism.
                  Pass a single ProcessGroup for cp_comm_type "p2p", "all_gather"
                  or "a2a". Pass a List[ProcessGroup] for "a2a+p2p": cp_group[0]
                  carries the a2a traffic and cp_group[1] the p2p traffic.
        cp_global_ranks : List[int]
                         global ranks of the GPUs in the context-parallel group.
        cp_stream : torch.cuda.Stream
                   CUDA stream used for context-parallel execution.
        cp_comm_type : str, default = `p2p`
                      how GPUs communicate under context parallelism; one of
                      "p2p", "all_gather", "a2a" or "a2a+p2p".
                      "p2p": KV chunks travel around a ring of P2P exchanges; these
                             are asynchronous and can overlap attention compute.
                      "all_gather": KV is all-gathered to the full sequence before
                                    attention; this is blocking and cannot overlap.
                      "a2a": DeepSpeed-Ulysses style; attention heads are scattered
                             over the CP group and QKV gathered to the full sequence.
                      "a2a+p2p": hierarchical CP; a2a on QKV inside each CP sub-group
                      (e.g., over NVLink), then p2p KV exchange between sub-groups
                      (e.g., over InfiniBand).
        """
        # Assigned left to right, in the same order as the parameters.
        self.cp_group, self.cp_global_ranks, self.cp_stream, self.cp_comm_type = (
            cp_group,
            cp_global_ranks,
            cp_stream,
            cp_comm_type,
        )
7337

7338
    @no_torch_dynamo(recursive=False)
7339
7340
7341
7342
7343
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
7344
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
7345
7346
7347
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
7348
7349
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
7350
7351
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
7352
        attn_mask_type: Optional[str] = None,
7353
        window_size: Optional[Tuple[int, int]] = None,
7354
        checkpoint_core_attention: bool = False,
7355
7356
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
7357
        alibi_slopes: Optional[torch.Tensor] = None,
7358
        fast_zero_fill: bool = True,
7359
        inference_params: Optional[InferenceParams] = None,
7360
        is_first_microbatch: Optional[bool] = None,
7361
7362
7363
7364
7365
7366
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

7367
7368
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes "`padding`" or "`arbitrary`".
7369

7370
7371
        .. note::

7372
7373
7374
7375
7376
7377
7378
7379
7380
7381
7382
7383
7384
7385
7386
7387
7388
7389
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://github.com/Dao-AILab/flash-attention>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
            and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
7390
7391
7392
7393
7394
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
            optimizations in FusedAttention. When unset, TransformerEngine determines the code path
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
7395

7396
7397
7398
7399
7400
7401
7402
7403
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
7404
7405
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
7406
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
7407
7408
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
7409
7410
7411
7412
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
7413
7414
7415
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
7416
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
7417
7418
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
7419
7420
7421
7422
7423
7424
7425
7426
7427
7428
7429
7430
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for
                   `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_q_padded = cu_seqlens_q`.
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_kv_padded = cu_seqlens_kv`.
7431
7432
7433
7434
7435
7436
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
7437
7438
7439
7440
7441
7442
7443
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
                       'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
                       'arbitrary'}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
                       are equivalent. By default, causal masks are aligned to the top left corner
                       of the softmax matrix. When "`bottom_right`" is specified in the mask type,
                       causal masks are aligned to the bottom right corner.
7444
        window_size: Optional[Tuple[int, int]], default = `None`
7445
                    Sliding window size for local attention.
7446
7447
7448
7449
7450
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
7451
        core_attention_bias_type: str, default = `no_bias`
7452
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
7453
        core_attention_bias: Optional[torch.Tensor], default = `None`
7454
7455
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
7456
7457
7458
7459
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
7460
        fast_zero_fill: bool, default = `True`
7461
                    Whether to use the fast path to set output tensors to 0 or not.
7462
7463
7464
7465
7466
7467
7468
7469
7470
7471
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
7472
7473
7474
7475
7476
7477
7478
7479
7480
7481
7482
7483
7484
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
7485
        """
7486
7487
7488
7489
7490
7491
7492
7493
7494
7495
7496
        with self.prepare_forward(
            query_layer,
            is_first_microbatch,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:

            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
7497
                        self.logger.warning(
7498
7499
7500
                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                            """fp8_meta["recipe"].fp8_mha=True"""
                        )
7501
7502
7503
7504
7505
7506
7507
7508
7509
7510
7511

            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
7512

7513
7514
7515
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
7516
7517
7518
            assert (
                query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
7519
7520
7521
            assert (
                key_layer.shape[:-1] == value_layer.shape[:-1]
            ), "Keys and values must have the same batch size, sequence length and number of heads!"
7522
7523
7524
7525
7526
7527
7528
7529
            assert (
                key_layer.shape[-1] == self.hidden_size_per_attention_head_k
            ), f"Keys have head_dim = {key_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_k}!"
            assert (
                value_layer.shape[-1] == self.hidden_size_per_attention_head_v
            ), f"Values have head_dim = {value_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_v}!"
7530

7531
7532
7533
7534
7535
7536
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
7537
            assert (
7538
7539
7540
7541
7542
7543
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
7544

7545
7546
7547
7548
            if window_size is None:
                window_size = self.window_size
            window_size = check_set_window_size(attn_mask_type, window_size)

7549
7550
7551
7552
7553
7554
7555
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
7556

7557
7558
            if qkv_format is None:
                qkv_format = self.qkv_format
7559

7560
7561
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
7562

7563
7564
7565
7566
7567
                # convert causal to causal_bottom_right in inference when KV-caching is in use
                # so users can run with the same attn_mask_type for training and inference
                if attn_mask_type in ["causal", "padding_causal"]:
                    attn_mask_type = attn_mask_type + "_bottom_right"

7568
7569
7570
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
7571

7572
7573
7574
7575
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]
7576

7577
7578
7579
                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)
7580

7581
7582
7583
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)
7584

7585
7586
7587
7588
7589
7590
7591
7592
7593
                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
7594

7595
7596
7597
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
7598

7599
7600
                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()
7601
7602

            assert (
7603
7604
7605
7606
7607
7608
7609
7610
7611
7612
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
7613
                assert all(
7614
7615
7616
7617
7618
7619
7620
7621
7622
7623
7624
7625
7626
7627
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
                if max_seqlen_q is None:
7628
7629
7630
7631
                    if cu_seqlens_q_padded is not None:
                        seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
                    else:
                        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
7632
                    max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
7633
                if max_seqlen_kv is None:
7634
7635
7636
7637
                    if cu_seqlens_kv_padded is not None:
                        seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
                    else:
                        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
7638
                    max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
7639
                batch_size = len(cu_seqlens_q) - 1
7640

7641
7642
7643
7644
7645
7646
            cp_size = 1
            if isinstance(self.cp_group, dist_group_type):
                cp_size = get_distributed_world_size(self.cp_group)
            elif isinstance(self.cp_group, list):
                for group in self.cp_group:
                    cp_size *= get_distributed_world_size(group)
7647
7648
            context_parallel = cp_size > 1

7649
            if qkv_format in ["sbhd", "bshd"]:
7650
                assert all(
7651
7652
7653
7654
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
7655
                    batch_size = query_layer.shape[1]
7656
7657
                if qkv_format == "bshd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
7658
                    batch_size = query_layer.shape[0]
7659
7660
                max_seqlen_q *= cp_size
                max_seqlen_kv *= cp_size
7661
7662
7663
7664
7665
7666
7667
7668
7669
7670
7671
7672
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                        the sequence dimention in 'query_layer'!"""
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                        the sequence dimention in 'key_layer' and 'value_layer'!"""
7673
7674
7675
7676
7677
                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
7678
                        if self.attention_type == "self":
7679
7680
7681
7682
7683
7684
7685
7686
7687
7688
7689
7690
7691
7692
7693
7694
                            cu_seqlens_q = get_cu_seqlens(attention_mask)
                            cu_seqlens_kv = cu_seqlens_q
                        else:
                            cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                            cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
                    else:
                        cu_seqlens_q = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
                        cu_seqlens_kv = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
7695

7696
7697
7698
7699
7700
            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
7701
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
7702
7703
7704
                    query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                )
            else:
7705
                qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
7706
7707
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )
7708

7709
7710
7711
7712
7713
7714
7715
7716
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
7717
            bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
7718
7719
7720
7721
7722
7723
7724
7725
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
7726
                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
7727
7728
7729
7730
7731
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

7732
7733
            core_attention_bias_shape = None
            if core_attention_bias is not None:
7734
                if (
7735
7736
                    core_attention_bias.shape[0] == batch_size
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
7737
                ):
7738
7739
7740
7741
7742
7743
7744
7745
7746
7747
7748
7749
7750
7751
7752
7753
7754
7755
7756
7757
7758
7759
7760
7761
                    core_attention_bias_shape = "bhss"
                elif (
                    core_attention_bias.shape[0] == 1
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
                ):
                    core_attention_bias_shape = "1hss"
                elif (
                    core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
                ):
                    core_attention_bias_shape = "b1ss"
                elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
                    core_attention_bias_shape = "11ss"
                else:
                    assert (
                        False
                    ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

            pad_between_seqs = (
                cu_seqlens_q_padded is not None
                and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
            ) or (
                cu_seqlens_kv_padded is not None
                and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
            )
7762

7763
            attention_params = AttentionParams(
7764
7765
7766
7767
7768
7769
7770
7771
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
                num_heads=query_layer.shape[-2],
                num_gqa_groups=key_layer.shape[-2],
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
7772
7773
                head_dim_qk=query_layer.shape[-1],
                head_dim_v=value_layer.shape[-1],
7774
7775
7776
7777
7778
7779
7780
7781
7782
7783
7784
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias_shape=core_attention_bias_shape,
                core_attention_bias_requires_grad=(
                    core_attention_bias.requires_grad if core_attention_bias is not None else False
                ),
                pad_between_seqs=pad_between_seqs,
                attention_dropout=self.attention_dropout,
                context_parallel=context_parallel,
7785
7786
                deterministic=self.deterministic,
                is_training=self.training,
7787
7788
7789
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
            )
7790
            global _attention_backends, _flash_attn_3_plus, _use_flash_attn_3
7791
7792
7793
7794
7795
7796
7797
            if (
                _attention_backends["attention_params"] is None
                or attention_params != _attention_backends["attention_params"]
            ):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
7798
                _use_flash_attn_3 = _flash_attn_3_plus
7799
7800
7801
7802
7803
7804
7805
7806
                (
                    use_flash_attention,
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
                    _,
                ) = get_attention_backend(attention_params)
                if use_flash_attention:
7807
7808
7809
7810
                    self.logger.info(
                        "Running with FlashAttention backend (version %s)",
                        _flash_attn_version if not _use_flash_attn_3 else _flash_attn_v3_version,
                    )
7811
7812
7813
7814
                elif use_fused_attention:
                    self.logger.info(
                        "Running with FusedAttention backend (sub-backend %s)",
                        int(fused_attention_backend),
7815
                    )
7816
7817
7818
7819
7820
7821
7822
                elif use_unfused_attention:
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]
7823

7824
7825
7826
7827
7828
7829
7830
7831
7832
7833
7834
7835
7836
7837
7838
7839
7840
7841
7842
7843
7844
7845
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
7846
                    cp_comm_type=self.cp_comm_type,
7847
7848
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
7849
7850
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
7851
                )
7852

7853
            if use_fused_attention:
7854
7855
                fu_core_attention_bias_type = core_attention_bias_type
                fu_core_attention_bias = core_attention_bias
7856
7857
7858
                if core_attention_bias_type == "alibi" and (
                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
                ):
7859
7860
7861
7862
7863
7864
7865
                    fu_core_attention_bias_type = "post_scale_bias"
                    _, fu_core_attention_bias = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                        bias_dtype=query_layer.dtype,
7866
                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
7867
                    )
7868
7869
7870
7871
7872
7873
7874
7875
7876
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
7877
7878
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
7879
7880
7881
7882
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
7883
                        window_size=window_size,
7884
7885
7886
7887
7888
7889
7890
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
7891
                        cp_comm_type=self.cp_comm_type,
7892
7893
7894
7895
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
                    )
                return self.fused_attention(
7896
7897
7898
7899
7900
7901
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
7902
7903
                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
7904
7905
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
7906
7907
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
7908
                    window_size=window_size,
7909
                    fused_attention_backend=fused_attention_backend,
7910
7911
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
7912
7913
7914
7915
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
7916
                    cp_comm_type=self.cp_comm_type,
7917
7918
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
7919
                )
7920

7921
            from .cpu_offload import CPUOffloadEnabled
7922

7923
7924
7925
7926
7927
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
7928

7929
            if use_unfused_attention:
7930
7931
7932
7933
7934
7935
                if window_size is not None and (
                    window_size[0] != -1 or window_size[1] not in [-1, 0]
                ):
                    attn_mask_type, attention_mask = get_swa_mask(
                        window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
                    )
7936
7937
7938
7939
7940
7941
7942
7943
7944
7945
7946
7947
7948
7949
7950
7951
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
7952
7953
7954
                    query_layer,
                    key_layer,
                    value_layer,
7955
7956
7957
7958
7959
7960
7961
7962
7963
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )
7964

7965
            raise Exception("No dot product attention support for the provided inputs!")
7966
7967


7968
7969
7970
7971
7972
7973
7974
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

7975
7976
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
7977

7978
7979
7980
7981
7982
7983
7984
7985
7986
7987
7988
7989
7990
7991
7992
7993
7994
7995
7996
7997
7998
7999
8000
8001
8002
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing the QKV projection weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing the weights of PROJ in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
8003
8004
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                   'padding_causal_bottom_right','arbitrary'},
8005
                   default = `causal`
8006
8007
8008
8009
8010
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
8011
8012
8013
8014
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
8015
8016
8017
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
8018
                be overridden by :attr:`window_size` in `forward` as well.
8019
8020
8021
8022
8023
8024
8025
8026
8027
8028
8029
8030
8031
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
8032
8033
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
8034
8035
8036
8037
8038
8039
8040
8041
8042
8043
8044
8045
8046
8047
8048
8049
8050
8051
8052
8053
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the QKV and output projection layers will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
8054
          The device on which the parameters of the model will be allocated. It is the user's
8055
8056
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
8057
8058
8059
8060
8061
8062
8063
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
8064
            For that, please use `get_qkv_layout` to gain the layout information.
8065
8066
8067
8068
8069
8070
8071
8072
8073
8074
8075
8076
8077
8078
8079
8080
8081
8082
8083
8084
8085
8086
8087
8088
8089
8090
8091
8092
8093
8094
8095
8096
8097
8098
8099
8100
8101
8102
8103
8104

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV layers are used as Column Parallel
                      whereas the PROJ layer is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, this module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
8105
8106
8107
8108
8109
8110
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """Build the QKV projection(s), the core attention and the output projection.

        All parameters are described in the class docstring. Submodules are
        created in a fixed order: QKV projection(s), then `core_attention`, then `proj`.
        """
        super().__init__()

        # Settings that `forward` reads back from the module.
        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        if params_dtype is None:
            self.params_dtype = torch.get_default_dtype()
        else:
            self.params_dtype = params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

        # A falsy kv_channels (None or 0) falls back to an even split of hidden_size.
        head_dim = kv_channels or (hidden_size // num_attention_heads)

        qkv_init = init_method if init_method is not None else get_default_init_method()
        proj_init = (
            output_layer_init_method
            if output_layer_init_method is not None
            else get_default_init_method()
        )

        # Weight interleaving is only meaningful for a single fused QKV parameter.
        self.qkv_weight_interleaved = qkv_weight_interleaved if fuse_qkv_params else False

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # An explicit process group takes precedence over the user-supplied tp_size.
        tp_world = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_world
        self.sequence_parallel = sequence_parallel if tp_world > 1 else False

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_world)
        if num_gqa_groups is None:
            self.num_gqa_groups = num_attention_heads
        else:
            self.num_gqa_groups = num_gqa_groups
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_world == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_world)

        self.hidden_size_per_attention_head = head_dim
        self.hidden_size_q = head_dim * num_attention_heads
        self.hidden_size_kv = head_dim * self.num_gqa_groups

        # Keyword arguments shared by every GEMM-backed submodule.
        gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_world,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }
        # Options common to every projection that produces Q, K or V.
        qkv_proj_kwargs = {
            "init_method": qkv_init,
            "bias": bias,
            "return_bias": False,
            "parallel_mode": "column" if set_parallel_mode else None,
        }
        # Additional options for the LayerNormLinear variants (input_layernorm=True).
        ln_kwargs = {
            "eps": layernorm_epsilon,
            "return_layernorm_output": return_layernorm_output,
            "zero_centered_gamma": zero_centered_gamma,
            "ub_bulk_wgrad": ub_bulk_wgrad,
            "ub_bulk_dgrad": ub_bulk_dgrad,
            "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
            "ub_overlap_ag": ub_overlap_ag,
            "normalization": normalization,
            "ub_name": "qkv",
        }

        if self.attention_type == "self":
            # Named sub-parameter sizes for the QKV projection when QKV params
            # are not fused (see `fuse_qkv_params`).
            qkv_split = None
            if not fuse_qkv_params:
                qkv_split = collections.OrderedDict(
                    [
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    ]
                )
            qkv_width = self.hidden_size_q + 2 * self.hidden_size_kv
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    qkv_width,
                    parameters_split=qkv_split,
                    **qkv_proj_kwargs,
                    **ln_kwargs,
                    **gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    qkv_width,
                    parameters_split=qkv_split,
                    **qkv_proj_kwargs,
                    **gemm_kwargs,
                )
        elif self.attention_type == "cross":
            # Queries come from hidden_states; keys/values get their own projection.
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q,
                    parameters_split=None if fuse_qkv_params else ("query",),
                    **qkv_proj_kwargs,
                    **ln_kwargs,
                    **gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    self.hidden_size_q,
                    **qkv_proj_kwargs,
                    **gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                parameters_split=None if fuse_qkv_params else ("key", "value"),
                **qkv_proj_kwargs,
                **gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            head_dim,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_world,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # Output projection (row parallel when set_parallel_mode is enabled).
        self.proj = Linear(
            self.hidden_size_q,
            hidden_size,
            init_method=proj_init,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            ub_name="proj",
            **gemm_kwargs,
        )

    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Allocate one uninitialized inference KV-cache buffer for this layer.

        `forward` calls this once for keys and once for values. The buffer is
        created on the current CUDA device with shape
        ``[inference_max_sequence_len, batch_size, num_gqa_groups_per_partition,
        hidden_size_per_attention_head]``. Its contents are not initialized
        (``torch.empty``).
        """
        cache_shape = (
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
        )
        target_device = torch.cuda.current_device()
        return torch.empty(cache_shape, dtype=dtype, device=target_device)

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        The group is stored on this module and is also forwarded to every
        descendant submodule that defines ``set_tensor_parallel_group`` (for
        example the QKV and output projection layers). This matches how
        ``set_context_parallel_group`` propagates its settings. Without the
        propagation, the child layers would keep the ``tp_group`` they were
        constructed with, which is ``None`` when the group is supplied only
        after initialization.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        # Deep iterate but skip self (yielded first by `modules()`) to avoid
        # infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

8340
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Configure context parallelism for this module before the forward pass.

        The arguments are passed unchanged to every descendant submodule that
        defines ``set_context_parallel_group``. This module itself stores nothing.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  context parallel process group.
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-gpu communication type for context parallelism.
                      Can be "p2p", "all_gather", "a2a", or "a2a+p2p".
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
        """
        cp_args = (cp_group, cp_global_ranks, cp_stream, cp_comm_type)
        for submodule in self.modules():
            # `modules()` yields this module first; skip it so the call does not recurse forever.
            if submodule is self:
                continue
            if hasattr(submodule, "set_context_parallel_group"):
                submodule.set_context_parallel_group(*cp_args)
8381

8382
8383
8384
    def forward(
        self,
        hidden_states: torch.Tensor,
8385
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
8386
        encoder_output: Optional[torch.Tensor] = None,
8387
        attn_mask_type: Optional[str] = None,
8388
        window_size: Optional[Tuple[int, int]] = None,
8389
8390
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
8391
        inference_params: Optional[InferenceParams] = None,
8392
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
8393
8394
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
8395
        alibi_slopes: Optional[torch.Tensor] = None,
8396
8397
8398
8399
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
8400
        fast_zero_fill: bool = True,
8401
    ) -> Tuple[Union[torch.Tensor, None], ...]:
8402
8403
8404
8405
8406
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

8407
8408
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
8409
8410
8411
8412
8413

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
8414
8415
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
8416
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
8417
8418
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
8419
8420
8421
8422
8423
8424
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                       'padding_causal_bottom_right','arbitrary'},
8425
                       default = `None`
8426
8427
8428
8429
                       type of attention mask passed into softmax operation. By default,
                       causal masks are aligned to the top left corner of the softmax matrix.
                       When "`bottom_right`" is specified in the mask type, causal masks are
                       aligned to the bottom right corner.
8430
8431
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
8432
8433
8434
8435
8436
8437
8438
8439
8440
8441
8442
8443
8444
8445
8446
8447
8448
8449
8450
8451
8452
8453
8454
8455
8456
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
8457
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
8458
        core_attention_bias: Optional[torch.Tensor], default = `None`
8459
8460
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
8461
8462
8463
8464
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
8465
8466
8467
8468
8469
8470
8471
8472
8473
8474
8475
8476
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
8477
8478
8479
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        """
8480
8481
        # hidden_states: [sq, b, h]

8482
        if attn_mask_type is None:
8483
            attn_mask_type = self.attn_mask_type
8484
8485
        if window_size is None:
            window_size = self.window_size
8486
        window_size = check_set_window_size(attn_mask_type, window_size)
8487

8488
        if "padding" in attn_mask_type and attention_mask is not None:
8489
            for i, _ in enumerate(attention_mask):
8490
8491
8492
                assert (
                    attention_mask[i].dtype == torch.bool
                ), "Attention mask must be in boolean type!"
8493

8494
8495
8496
        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
8497

8498
        # =================================================
8499
        # Pre-allocate memory for key-values for inference
8500
8501
8502
        # =================================================

        if inference_params and self.layer_number is not None:
8503
8504
8505
            assert (
                self.qkv_format != "thd"
            ), "qkv_format == thd is not supported for an inference with KV-cache!"
8506
            if self.layer_number not in inference_params.key_value_memory_dict:
8507
                inf_max_seq_len = inference_params.max_sequence_length
8508
8509
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
8510
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
8511
8512
                )
                inference_value_memory = self._allocate_memory(
8513
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
8514
8515
8516
8517
8518
8519
8520
8521
8522
8523
8524
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

8525
        # ======================
8526
        # Query, Key, and Value
8527
        # ======================
8528

8529
8530
8531
8532
8533
        fp8_mha = (
            FP8GlobalStateManager.is_fp8_enabled()
            and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
        )

cyanguwa's avatar
cyanguwa committed
8534
8535
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
8536
8537
8538
8539
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8540
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8541
8542
8543
8544
8545
8546
8547
8548
8549
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8550
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8551
8552
                )

8553
8554
8555
            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
8556
            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
8557
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
8558
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
cyanguwa's avatar
cyanguwa committed
8559
8560
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
8561
8562
8563
8564
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
cyanguwa's avatar
cyanguwa committed
8565
8566
8567
8568
8569
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
8570
                    self.hidden_size_per_attention_head,
cyanguwa's avatar
cyanguwa committed
8571
8572
8573
                )
                # split along third last dimension
                split_dim = -3
8574
8575
8576

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
8577
8578
8579
8580
8581
8582
8583
8584
8585
            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
8586
                )
8587
            else:
cyanguwa's avatar
cyanguwa committed
8588
                query_layer, key_layer, value_layer = torch.split(
8589
8590
8591
8592
                    mixed_x_layer,
                    (num_queries_per_key_value, 1, 1),
                    dim=split_dim,
                )
cyanguwa's avatar
cyanguwa committed
8593

8594
8595
8596
8597
8598
8599
8600
8601
8602
8603
8604
8605
            if self.qkv_format == "thd":
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
            else:
                # query: -> [sq, b, np, hn]
                # key, value: -> [sq, b, ng, hn]
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
cyanguwa's avatar
cyanguwa committed
8606
8607
        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
8608
            mixed_kv_layer = self.key_value(
cyanguwa's avatar
cyanguwa committed
8609
                encoder_output,
8610
                is_first_microbatch=is_first_microbatch,
8611
                fp8_output=fp8_mha and rotary_pos_emb is None,
8612
8613
8614
            )

            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
8615
                # [sq, b, (ng * 2 * hn)] --> [sq, b, ng, 2 * hn]
8616
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
8617
                    self.num_gqa_groups_per_partition,
8618
8619
8620
8621
8622
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
cyanguwa's avatar
cyanguwa committed
8623
                # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn]
8624
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
8625
                    2 * self.num_gqa_groups_per_partition,
8626
8627
8628
8629
8630
8631
8632
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
8633
8634
8635
            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
8636
8637
8638
                    mixed_kv_layer,
                    split_dim,
                    mixed_kv_layer.shape[split_dim] // 2,
cyanguwa's avatar
cyanguwa committed
8639
                )
8640
            else:
cyanguwa's avatar
cyanguwa committed
8641
                key_layer, value_layer = torch.split(
8642
8643
8644
                    mixed_kv_layer,
                    mixed_kv_layer.shape[split_dim] // 2,
                    dim=split_dim,
cyanguwa's avatar
cyanguwa committed
8645
                )
8646
8647
8648
8649
8650
8651
8652
8653
8654
            key_layer, value_layer = (
                x.reshape(
                    x.size(0),
                    x.size(1),
                    -1,
                    self.hidden_size_per_attention_head,
                )
                for x in (key_layer, value_layer)
            )
8655
8656
8657
8658
8659
8660

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8661
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8662
8663
8664
8665
8666
8667
8668
8669
8670
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
8671
                    fp8_output=fp8_mha and rotary_pos_emb is None,
8672
8673
8674
8675
8676
8677
8678
8679
8680
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

8681
8682
8683
        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================
8684

8685
        if rotary_pos_emb is not None:
8686
8687
8688
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
8689
            # duplicate the pos_emb for self attention
8690
            if not isinstance(rotary_pos_emb, tuple):
8691
                rotary_pos_emb = (rotary_pos_emb,) * 2
8692
8693

            q_pos_emb, k_pos_emb = rotary_pos_emb
8694
8695
8696
8697
8698
8699
8700
8701
8702
8703
8704
8705
8706
8707

            # adjust key and value for inference
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

8708
8709
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
8710

8711
8712
8713
8714
        # ===========================
        # Core attention computation
        # ===========================

8715
8716
8717
8718
        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
8719
            qkv_format=self.qkv_format,
8720
8721
8722
8723
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
8724
8725
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
8726
            window_size=window_size,
8727
8728
8729
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
8730
            alibi_slopes=alibi_slopes,
8731
            fast_zero_fill=fast_zero_fill,
8732
            inference_params=inference_params,
8733
8734
        )

8735
        # ===================
8736
        # Output. [sq, b, h]
8737
        # ===================
8738

8739
        projection_output = self.proj(
8740
8741
            context_layer,
            is_first_microbatch=is_first_microbatch,
8742
8743
        )

8744
8745
8746
8747
8748
8749
8750
8751
        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
8752
        if self.input_layernorm and self.return_layernorm_output:
8753
8754
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]