# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Attention."""
import collections
from contextlib import nullcontext
from dataclasses import dataclass, fields
from importlib.metadata import version as get_pkg_version
import logging
import math
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings

import numpy as np
from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F

import transformer_engine_torch as tex
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
    fused_attn_fwd,
    fused_attn_bwd,
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
    get_default_init_method,
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
    AttnBiasTypes,
    QKVLayouts,
    dist_group_type,
    TE_DType,
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
    get_distributed_rank,
    checkpoint,
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing


# Installed flash-attn version and the feature gates derived from it.
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.5.8")
_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")

# The flash-attn entry points below are only imported when the installed version meets
# the minimum requirement; otherwise these names are left undefined.
if _flash_attn_version >= _flash_attn_version_required:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd

# Aliases for FP8 metadata slots (tex.FP8FwdTensors / tex.FP8BwdTensors enum members)
# used for the attention tensors.
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
# The effective level is 0 unless debug mode is enabled; any value outside {0, 1, 2}
# falls back to the most verbose level (DEBUG).
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
# NOTE(review): basicConfig configures the root logger at import time; this only takes
# effect if no handler has been installed on the root logger yet.
logging.basicConfig(
    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)

# Per-backend enable switches (1 = allowed, 0 = disabled). get_attention_backend re-reads
# these environment variables on every call.
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))

# Result of the most recent backend selection, written by get_attention_backend.
_attention_backends = {
    "attention_params": None,
    "use_flash_attention": None,
    "use_fused_attention": None,
    "fused_attention_backend": None,
    "use_unfused_attention": None,
    "backend_selection_requires_update": False,
}


@dataclass(eq=True)
class AttentionParams:
    """
    Attention parameters used to determine which backend to be used.

    Instances compare equal field-by-field (`eq=True`), so two parameter sets can be
    compared to decide whether a previous backend selection is still valid.

    Parameters
    ----------
    qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
        Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
    qkv_dtype: torch.dtype, default = `torch.bfloat16`
        Data type of query/key/value tensors.
    qkv_layout: str, default = "sbh3d"
        Query/key/value tensor memory layout.
    batch_size: int, default = 1
        Batch size.
    num_heads: int, default = 16
        Number of attention heads in the query tensor.
    num_gqa_groups: int, default = 16
        Number of attention heads in key and value tensors.
    max_seqlen_q: int, default = 128
        Maximum sequence length of the query tensor.
    max_seqlen_kv: int, default = 128
        Maximum sequence length of the key and value tensors.
    head_dim_qk: int, default = 64
        The size of each attention head in query and key tensors.
    head_dim_v: int, default = 64
        The size of each attention head in the value tensor.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
        `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
    window_size: Tuple[int, int], default = None
        Sliding window attention size.
    alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
        Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
    core_attention_bias_type: str, default = `no_bias`
        Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
    core_attention_bias_shape: str, default = `1hss`
        Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
    core_attention_bias_requires_grad: bool, default = `True`
        Whether attention bias requires gradient.
    pad_between_seqs: bool, default = `False`
        Whether there is padding between sequences in a batch.
        This only applies to `qkv_format=thd`.
    attention_dropout: float, default = 0.0
        Attention dropout.
    context_parallel: bool, default = `False`
        Whether context parallelism is used or not.
    deterministic: bool, default = `False`
        Whether to run `DotProductAttention` with determinism or not.
    is_training: bool, default = `True`
        Whether in training mode (`True`) or inference mode (`False`)
    fp8: bool, default = `False`
        Whether `DotProductAttention` is in an `fp8_autocast` region.
    fp8_meta: Optional[Dict[str, Any]], default = `None`
        The FP8 metadata tensor of `DotProductAttention`.
    """

    # NOTE: field order defines the generated __init__ signature; do not reorder.
    qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
    qkv_dtype: torch.dtype = torch.bfloat16
    qkv_layout: str = "sbh3d"
    batch_size: int = 1
    num_heads: int = 16
    num_gqa_groups: int = 16
    max_seqlen_q: int = 128
    max_seqlen_kv: int = 128
    head_dim_qk: int = 64
    head_dim_v: int = 64
    attn_mask_type: str = "no_mask"
    window_size: Union[Tuple[int, int], None] = None
    alibi_slopes_shape: Union[torch.Size, List, None] = None
    core_attention_bias_type: str = "no_bias"
    core_attention_bias_shape: str = "1hss"
    core_attention_bias_requires_grad: bool = True
    pad_between_seqs: bool = False
    attention_dropout: float = 0.0
    context_parallel: bool = False
    deterministic: bool = False
    is_training: bool = True
    fp8: bool = False
    fp8_meta: Union[Dict[str, Any], None] = None


# Module-level ALiBi state. The key names suggest cached slopes/bias tensors, the
# parameters they were built for, and staleness flags — confirm against the ALiBi helpers.
_alibi_cache = dict(
    _num_heads=None,
    _alibi_slopes=None,
    _max_seqlen_q=None,
    _max_seqlen_kv=None,
    _bottom_right_alignment=True,
    _alibi_bias=None,
    _alibi_slopes_require_update=False,
    _alibi_bias_require_update=False,
)


# Public API of this module.
__all__ = [
    "DotProductAttention",
    "InferenceParams",
    "MultiheadAttention",
]


def get_attention_backend(
    attention_params: AttentionParams = None,
):
    """
    Select the appropriate attention backend/sub-backend based on user input and runtime environment.

    Each backend starts enabled according to the `NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN` and
    `NVTE_UNFUSED_ATTN` environment variables and is then disabled by a sequence of
    filters (flash-attn version, ONNX export, compute capability, context parallelism,
    data type, FP8, head dimension, QKV layout, mask type, sliding window, bias, cuDNN
    support, determinism). Among the survivors, FlashAttention is preferred, then
    FusedAttention, then UnfusedDotProductAttention. The exception is that FusedAttention
    sub-backend `F16_arbitrary_seqlen` is preferred over FlashAttention on sm90.

    Side effects:
      - re-reads the three environment variables above into module globals;
      - writes the selection into the module-level `_attention_backends` dict;
      - sets `os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"` when FusedAttention is asked
        for a non-[1, H, S, S] `post_scale_bias` that does not require gradients.

    Parameters
    ----------
    See `AttentionParams`. If `attention_params` is `None`, a default-constructed
    `AttentionParams()` is used.

    Returns
    ----------
    use_flash_attention: bool
        Whether the `FlashAttention` backend has been selected.
    use_fused_attention: bool
        Whether the `FusedAttention` backend has been selected.
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend
        If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
    use_unfused_attention: bool
        Whether the `UnfusedDotProductAttention` backend has been selected.
    available_backends: List[bool]
        All available backends that could support the provided input. A list of Booleans
        in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
    """
    # The signature allows None; previously that crashed on the first attribute access.
    if attention_params is None:
        attention_params = AttentionParams()

    qkv_type = attention_params.qkv_type
    qkv_dtype = attention_params.qkv_dtype
    qkv_layout = attention_params.qkv_layout
    batch_size = attention_params.batch_size
    num_heads = attention_params.num_heads
    num_gqa_groups = attention_params.num_gqa_groups
    max_seqlen_q = attention_params.max_seqlen_q
    max_seqlen_kv = attention_params.max_seqlen_kv
    head_dim_qk = attention_params.head_dim_qk
    head_dim_v = attention_params.head_dim_v
    attn_mask_type = attention_params.attn_mask_type
    window_size = attention_params.window_size
    alibi_slopes_shape = attention_params.alibi_slopes_shape
    core_attention_bias_type = attention_params.core_attention_bias_type
    core_attention_bias_shape = attention_params.core_attention_bias_shape
    core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad
    pad_between_seqs = attention_params.pad_between_seqs
    attention_dropout = attention_params.attention_dropout
    context_parallel = attention_params.context_parallel
    deterministic = attention_params.deterministic
    is_training = attention_params.is_training
    fp8 = attention_params.fp8
    fp8_meta = attention_params.fp8_meta

    # Run config
    logger = logging.getLogger("DotProductAttention")
    device_compute_capability = get_device_compute_capability()
    cudnn_version = get_cudnn_version()
    run_config = {
        "transformer_engine_version": te.__version__,
        "compute_capability": "sm"
        + str(device_compute_capability[0] * 10 + device_compute_capability[1]),
        "flash_attn_version": _flash_attn_version,
        "cudnn_version": ".".join([str(i) for i in cudnn_version]),
    }
    attention_params_dict = {
        field.name: getattr(attention_params, field.name) for field in fields(attention_params)
    }
    run_config.update(attention_params_dict)
    if fp8:
        run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
    logger.debug("Running with config=%s", run_config)

    # Filter: Environment variables
    # Re-read on every call so that changes to the environment after import take effect.
    global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN
    _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
    _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
    _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
    use_flash_attention = _NVTE_FLASH_ATTN
    use_fused_attention = _NVTE_FUSED_ATTN
    use_unfused_attention = _NVTE_UNFUSED_ATTN
    if not use_flash_attention:
        logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
    if not use_fused_attention:
        logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
    if not use_unfused_attention:
        logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")

    # Filter: flash-attn version
    # The flash-attn entry points are only imported at module load when the installed
    # version is at least `_flash_attn_version_required`, so never select FlashAttention
    # for an older installation.
    if use_flash_attention and _flash_attn_version < _flash_attn_version_required:
        logger.debug(
            "Disabling FlashAttention as flash-attn %s is older than the required %s",
            _flash_attn_version,
            _flash_attn_version_required,
        )
        use_flash_attention = False

    # Filter: ONNX mode
    if is_in_onnx_export_mode():
        if use_flash_attention:
            logger.debug("Disabling FlashAttention due to ONNX mode")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention due to ONNX mode")
        use_fused_attention = False

    # Filter: Compute capability
    if device_compute_capability < (8, 0):
        if use_flash_attention:
            logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
            use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
            use_fused_attention = False

    # Filter: Context parallelism
    if context_parallel and use_unfused_attention:
        logger.debug(
            "Disabling UnfusedDotProductAttention as it does not support context parallelism"
        )
        use_unfused_attention = False

    # Filter: Data type
    if use_flash_attention and (
        qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor
    ):
        logger.debug(
            "Disabling FlashAttention due to unsupported QKV data type. "
            "Supported: qkv_type = torch.Tensor, qkv_dtype = {torch.bfloat16, torch.float16}. "
            "Found: qkv_type = %s, qkv_dtype = %s.",
            qkv_type,
            qkv_dtype,
        )
        use_flash_attention = False
    if use_fused_attention and (qkv_dtype not in [torch.bfloat16, torch.float16]):
        logger.debug(
            "Disabling FusedAttention due to unsupported QKV data type. "
            "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
            "Found: qkv_dtype = %s.",
            qkv_dtype,
        )
        use_fused_attention = False

    # Filter: Execution type
    if fp8 and fp8_meta["recipe"].fp8_dpa:
        if use_flash_attention:
            logger.debug("Disabling FlashAttention as it does not support FP8")
            use_flash_attention = False
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
            use_unfused_attention = False

    # Filter: Head dimension
    if use_flash_attention and head_dim_qk != head_dim_v:
        logger.debug("Disabling FlashAttention as it does not support MLA.")
        use_flash_attention = False
    if use_flash_attention and (
        head_dim_qk > 256
        or head_dim_qk % 8 != 0
        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
    ):
        logger.debug(
            "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
            "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
            "head_dim_qk <= 256 (>192 requires sm80/90). "
            "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
            head_dim_qk,
            head_dim_v,
            ".".join([str(i) for i in device_compute_capability]),
        )
        use_flash_attention = False
    # Strip the b/s/t dimension letters so only the head/dim grouping remains; MLA
    # (head_dim_qk != head_dim_v) in FusedAttention requires separate Q, K and V tensors.
    qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
    if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
        logger.debug(
            "Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
            qkv_layout,
        )
        use_fused_attention = False

    # Filter: QKV layout
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        if use_unfused_attention:
            logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
            use_unfused_attention = False
        if use_flash_attention and pad_between_seqs:
            logger.debug(
                "Disabling FlashAttention for qkv_format = thd when there is "
                "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
            )
            use_flash_attention = False

    # Filter: Attention mask
    # attn_mask_type               |     supported backends
    # -------------------------------------------------------------------
    # no_mask                      |     All
    # padding                      |     FlashAttention, FusedAttention
    # causal                       |
    #     self-attention           |     All
    #     cross-attention          |     FusedAttention
    # padding_causal               |
    #     self-attention           |     FlashAttention, FusedAttention
    #     cross-attention          |     FusedAttention
    # causal_bottom_right          |     All
    # padding_causal_bottom_right  |     FlashAttention, FusedAttention
    # arbitrary                    |     UnfusedDotProductAttention
    if attn_mask_type == "arbitrary":
        if use_flash_attention:
            logger.debug("Disabling FlashAttention for arbitrary mask")
        use_flash_attention = False
        if use_fused_attention:
            logger.debug("Disabling FusedAttention for arbitrary mask")
        use_fused_attention = False
    if use_unfused_attention and "padding" in attn_mask_type:
        logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type)
        use_unfused_attention = False
    if (
        use_flash_attention
        and _flash_attn_2_1_plus
        and attn_mask_type in ["causal", "padding_causal"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention as it only supports bottom-right-diagonal "
            "causal mask since flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        use_flash_attention = False
    if (
        use_flash_attention
        and not _flash_attn_2_1_plus
        and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
        and max_seqlen_q != max_seqlen_kv
    ):
        logger.warning(
            "Disabling FlashAttention as it only supports top-left-diagonal "
            "causal mask before flash-attn 2.1. See "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        use_flash_attention = False

    # Filter: Sliding window attention
    #    backend                 |      window_size       | diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | (-1, -1) or (>=0, >=0) | bottom right
    # FusedAttention             | (-1,  0) or (>=0, 0)   | top left
    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
    #                            |                        | converts window_size to an 'arbitrary' mask
    if window_size is None:
        window_size = check_set_window_size(attn_mask_type, window_size)
    else:
        if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention"
                    " for FP8"
                )
                use_fused_attention = False
            elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
                logger.debug(
                    "Disabling FusedAttention as it only supports sliding window attention "
                    "with causal mask, no dropout, and qkv_format = bshd/sbhd"
                )
                use_fused_attention = False
            elif context_parallel:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with context parallelism"
                )
                use_fused_attention = False
            elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
                "no_mask",
                "padding",
                "causal_bottom_right",
                "padding_causal_bottom_right",
            ]:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with attn_mask_type = %s for cross-attention",
                    attn_mask_type,
                )
                use_fused_attention = False
            elif "padding" in attn_mask_type:
                logger.debug(
                    "Disabling FusedAttention as it does not support sliding window attention "
                    "with attn_mask_type = %s",
                    attn_mask_type,
                )
                use_fused_attention = False
        if (
            use_flash_attention
            and (window_size[0] != -1 or window_size[1] not in [-1, 0])
            and (not _flash_attn_2_3_plus or context_parallel)
        ):
            logger.debug(
                "Disabling FlashAttention as sliding window attention requires "
                "flash-attn 2.3+ and no context parallelism"
            )
            use_flash_attention = False

    # Filter: Attention bias
    #    backend                 |      bias types              | ALiBi diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | no_bias, alibi/alibi_slopes  | bottom right
    # FusedAttention             | no_bias, post_scale_bias     |
    #                            | alibi/alibi_slopes           | top left,
    #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
    # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
    #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
    # NOTE(review): FlashAttention is only kept when core_attention_bias_shape is None, so
    # callers are expected to pass None when no bias tensor is given (the dataclass
    # default is "1hss") — confirm against the caller.
    if use_flash_attention and (
        core_attention_bias_type not in ["no_bias", "alibi"]
        or core_attention_bias_shape is not None
    ):
        logger.debug("Disabling FlashAttention for pre/post_scale_bias")
        use_flash_attention = False

    # FusedAttention may realize ALiBi as an explicit post-scale bias; the "fu_" variables
    # hold the bias description as FusedAttention will actually see it.
    fu_core_attention_bias_type = core_attention_bias_type
    fu_core_attention_bias_shape = core_attention_bias_shape
    fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad
    if (
        use_fused_attention
        and core_attention_bias_type == "alibi"
        and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
    ):
        fu_core_attention_bias_type = "post_scale_bias"
        fu_core_attention_bias_requires_grad = False
        if alibi_slopes_shape is None:
            fu_core_attention_bias_shape = "1hss"
        elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
            fu_core_attention_bias_shape = "1hss"
        elif (
            len(alibi_slopes_shape) == 2
            and alibi_slopes_shape[0] == batch_size
            and alibi_slopes_shape[1] == num_heads
        ):
            fu_core_attention_bias_shape = "bhss"

    if (
        use_fused_attention
        and fu_core_attention_bias_type == "post_scale_bias"
        and fu_core_attention_bias_shape != "1hss"
    ):
        if fu_core_attention_bias_requires_grad:
            # remove this line when cuDNN adds bwd support for
            # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
            logger.debug(
                "Disabling FusedAttention as dBias is only supported in [1, H, S, S] shape. "
                "Found: core_attention_bias_shape = %s",
                fu_core_attention_bias_shape,
            )
            use_fused_attention = False
        else:
            # max512 backend will only support [1, h, s, s]
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

    # Filter: cuDNN support
    fused_attention_backend = None
    if use_fused_attention:
        q_type = TE_DType[qkv_dtype]
        kv_type = q_type
        if fp8 and fp8_meta["recipe"].fp8_dpa:
            q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            kv_type = q_type
        fused_attention_backend = tex.get_fused_attn_backend(
            q_type,
            kv_type,
            QKVLayout[qkv_layout],
            AttnBiasType[fu_core_attention_bias_type],
            AttnMaskType[attn_mask_type],
            attention_dropout,
            num_heads,
            num_gqa_groups,
            max_seqlen_q,
            max_seqlen_kv,
            head_dim_qk,
            head_dim_v,
            window_size[0],
            window_size[1],
        )
        if fused_attention_backend == FusedAttnBackend["No_Backend"]:
            logger.debug("Disabling FusedAttention as no backend supports the provided input")
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and context_parallel
            and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
        ):
            logger.debug(
                "Disabling FusedAttention as sub-backend %s does not support "
                "context parallelism",
                int(fused_attention_backend),
            )
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and window_size is not None
            and window_size[0] != -1
            and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
        ):
            logger.debug(
                "Disabling FusedAttention as sub-backend %s does not support "
                "sliding window attention",
                int(fused_attention_backend),
            )
            use_fused_attention = False
            fused_attention_backend = None
        if (
            use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
            and fu_core_attention_bias_type == "post_scale_bias"
            and fu_core_attention_bias_shape != "1hss"
        ):
            logger.debug(
                "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in"
                " [1, H, S, S] shape"
            )
            use_fused_attention = False
            fused_attention_backend = None

    # Filter: Determinism
    # backend                      | deterministic
    # ---------------------------------------------
    # FlashAttention               |
    #     flash-attn >=2.0, <2.4.1 | no
    #     flash-attn >=2.4.1       | yes
    # FusedAttention               |
    #     sub-backend 0            | yes
    #     sub-backend 1            | workspace optimization path and sm90+: yes;
    #                              | otherwise: no
    #     sub-backend 2            | no
    # UnfusedDotProductAttention   | yes
    if use_flash_attention and deterministic and not _flash_attn_2_4_1_plus:
        logger.warning(
            "Disabling FlashAttention as version <2.4.1 does not support deterministic "
            "execution. To use FlashAttention with deterministic behavior, "
            "please install flash-attn >= 2.4.1."
        )
        use_flash_attention = False
    if use_fused_attention and deterministic:
        # Reset the sub-backend together with the flag, like every other disable path, so
        # the return value and the log below never report a sub-backend for a disabled
        # FusedAttention.
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
            fused_attention_backend = None
        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
            and (
                device_compute_capability < (9, 0)
                or core_attention_bias_requires_grad
                or cudnn_version < (8, 9, 5)
            )
        ):
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
            fused_attention_backend = None

    # All available backends
    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
    logger.debug(
        "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
        " UnfusedDotProductAttention=%s}",
        bool(available_backends[0]),
        bool(available_backends[1]),
        (
            f" (sub-backend {int(fused_attention_backend)})"
            if fused_attention_backend is not None
            else ""
        ),
        bool(available_backends[2]),
    )

    # Select FusedAttention for performance
    # NOTE(review): only exactly sm90 is matched although the log message says "Hopper+";
    # confirm whether later architectures should also be included.
    if (
        use_flash_attention
        and use_fused_attention
        and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    ):
        if device_compute_capability == (9, 0):
            logger.debug(
                "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                "for performance reasons"
            )
            use_flash_attention = False

    # Selected backend: keep only the highest-priority survivor.
    if use_flash_attention:
        use_fused_attention = False
        use_unfused_attention = False
    elif use_fused_attention:
        use_unfused_attention = False
    selected_backend = "NoBackend"
    if use_flash_attention:
        selected_backend = "FlashAttention"
    elif use_fused_attention:
        selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
    elif use_unfused_attention:
        selected_backend = "UnfusedDotProductAttention"
    logger.debug("Selected backend = %s", selected_backend)

    global _attention_backends
    _attention_backends["use_flash_attention"] = use_flash_attention
    _attention_backends["use_fused_attention"] = use_fused_attention
    _attention_backends["fused_attention_backend"] = fused_attention_backend
    _attention_backends["use_unfused_attention"] = use_unfused_attention
    _attention_backends["backend_selection_requires_update"] = False

    return (
        use_flash_attention,
        use_fused_attention,
        fused_attention_backend,
        use_unfused_attention,
        available_backends,
    )


class InferenceParams:  # pylint: disable=too-few-public-methods
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
    """
    Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.
        """
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict in empty")

        for layer_number, inference_memory in self.key_value_memory_dict.items():
            inference_key_memory, inference_value_memory = inference_memory
            assert (
                len(batch_indices) == inference_key_memory.shape[1]
            )  # make sure batch size is the same
            new_inference_key_memory = inference_key_memory[:, batch_indices]
            new_inference_value_memory = inference_value_memory[:, batch_indices]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )
761

@torch.no_grad()
def get_swa_mask(
    window_size: Tuple[int, int],
    max_seqlen_q: int,
    max_seqlen_kv: int,
    attn_mask_type: str = "no_mask",
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
    device: Union[str, torch.device] = "cuda",
) -> Tuple[str, torch.Tensor]:
    """
    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
    For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
    and for other mask types, the bottom right corner.

    Parameters
    ----------
    window_size: Tuple[int, int]
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
        `attn_mask_type`.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
        "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
        default = `None`
        Boolean tensor(s) used to mask out attention softmax input (`True` = masked out).
    device: Union[str, torch.device], default = "cuda"
        Device on which the sliding window mask is built.

    Returns
    ----------
    attn_mask_type: str
        Always "`arbitrary`".
    attention_mask: torch.Tensor
        Combined `attention_mask` (input) and sliding window attention mask: a position
        is masked out (`True`) if either mask excludes it. The shape is
        [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None; else, the
        broadcast of the input mask's shape with [max_seqlen_q, max_seqlen_kv].
    """
    # -1 means "unbounded" on that side. Any bound >= max(seqlen_q, seqlen_kv) keeps every
    # position on that side of the diagonal, whichever length is larger. A bound of only
    # seqlen_q (or only seqlen_kv) would wrongly mask keys when the other length is larger.
    unbounded = max(max_seqlen_q, max_seqlen_kv)
    left = window_size[0] if window_size[0] != -1 else unbounded
    right = window_size[1] if window_size[1] != -1 else unbounded

    if attn_mask_type in ["causal"]:
        # Top-left alignment: query i is centered on key i.
        diagonal_offset = 0
    else:
        # Bottom-right alignment: the last query lines up with the last key.
        diagonal_offset = max_seqlen_kv - max_seqlen_q

    # `True` inside the window: keys j with i + offset - left <= j <= i + offset + right.
    in_window = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device=device)
    in_window = torch.triu(in_window, diagonal=diagonal_offset - left)
    in_window = torch.tril(in_window, diagonal=diagonal_offset + right)

    # Invert so `True` = masked out, the same convention as `attention_mask`.
    swa_mask = in_window.logical_not()
    if attention_mask is not None:
        # Masked if padded/masked in the input OR outside the window.
        # NOTE(review): the signature also allows a (q, kv) tuple of masks, which
        # torch.logical_or cannot take; confirm callers only pass a single tensor here.
        swa_mask = torch.logical_or(attention_mask, swa_mask)
    return "arbitrary", swa_mask


@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
    bottom_right_alignment: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return ALiBi slopes and the corresponding ALiBi bias.

    Both values are memoized in the module-level `_alibi_cache`. Slopes are recomputed
    only when `_alibi_cache["_alibi_slopes_require_update"]` is set. The bias is
    recomputed only when `_alibi_cache["_alibi_bias_require_update"]` is set. When a
    flag is not set, the corresponding cached value is returned unchanged.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the ALiBi bias to the bottom right corner of
        the matrix (`True`) or top left (`False`).

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
        then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
        `alibi_slopes` is in [batch_size, num_heads], then the bias is in
        [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].

    Raises
    ------
    ValueError
        If the ALiBi slopes are neither 1D nor 2D.
    """
    global _alibi_cache
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            # Default slopes: powers 1..n of 2**(-8/n), where n is the largest power of
            # two <= num_heads. Any remaining heads use odd powers 1, 3, 5, ... of 2**(-4/n).
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        slopes = _alibi_cache["_alibi_slopes"]
        assert slopes is not None, "ALiBi slopes can not be None!"
        if slopes.dim() == 1:
            # [num_heads] -> [1, num_heads, 1, 1]
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        elif slopes.dim() == 2:
            # [batch_size, num_heads] -> [batch_size, num_heads, 1, 1]
            slopes_shape = torch.Size([*slopes.shape[:], 1, 1])
        else:
            raise ValueError(
                "ALiBi slopes must be in shape [num_heads] or [batch_size, num_heads], "
                f"but got shape {list(slopes.shape)}."
            )

        # Column (key) term, shape [1, 1, 1, max_seqlen_kv]. The offset chooses whether the
        # zero-distance diagonal ends in the bottom-right or the top-left corner.
        if bottom_right_alignment:
            bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(
                1, 1, 1, max_seqlen_kv
            )
        else:
            bias = torch.arange(
                1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda"
            ).view(1, 1, 1, max_seqlen_kv)
        # Subtract the row (query) term -> signed distance from the aligned diagonal,
        # shape [1, 1, max_seqlen_q, max_seqlen_kv].
        bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(
            1, 1, max_seqlen_q, 1
        )
        # bias = -|distance| * slope, broadcast over heads (and batch for 2D slopes).
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    In `mask`, `True` marks padding tokens and `False` marks valid tokens. The result
    is created on the same device as `mask`.
    """
    mask = mask.squeeze(1).squeeze(1)
    # Number of valid (unmasked) tokens per sample.
    seqlens = mask.logical_not().sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # Leading 0 on mask's device; a hard-coded "cuda" would use the *current* device
    # and break torch.cat when mask lives on a different GPU (or on CPU).
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    return torch.cat((zero, cu_seqlens))

def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the flattened indices (`sample * max_seqlen + position`) of the valid
    tokens. The indices are padded at the end with the out-of-range value
    `batch_size * max_seqlen`.

    In `mask`, `True` marks padding tokens. Both outputs are on `mask`'s device.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape
    valid = mask.logical_not()

    cu_seqlens = valid.sum(dim=1).cumsum(dim=0).to(torch.int32)
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # nonzero() on the flattened mask -> [num_valid, 1] int64; add a dim -> [num_valid, 1, 1].
    indices = valid.reshape(-1).nonzero().unsqueeze(-1)

    # Pad to a fixed length with a sentinel one past the last flattened position.
    pad_amount = bs * seqlen - indices.shape[0]
    indices = F.pad(
        input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen)
    )

    return cu_seqlens, indices


def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the flattened indices
    (`sample * max_seqlen + position`) of the valid tokens in a batch, padded at the
    end with the out-of-range value `batch_size * max_seqlen`.

    Each sequence length is expected to be at most `max_seqlen`. The result is
    created on the same device as `cu_seqlens`.
    """
    bs = len(cu_seqlens) - 1
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

    # Position p of sample b is valid iff p < seqlens[b]. Flattening this [bs, max_seqlen]
    # grid row-major gives exactly b * max_seqlen + p in ascending order. nonzero() returns
    # exact int64 indices. This avoids both a float32 round-trip (exact only below 2**24)
    # and a per-token Python loop over device tensors.
    positions = torch.arange(max_seqlen, device=cu_seqlens.device)
    valid = positions.unsqueeze(0) < seqlens.unsqueeze(1)
    indices = valid.reshape(-1).nonzero().unsqueeze(-1)

    pad_amount = bs * max_seqlen - indices.shape[0]
    indices = F.pad(
        input=indices,
        pad=(0, 0, 0, 0, 0, pad_amount),
        mode="constant",
        value=float(bs * max_seqlen),
    )

    return indices

968

969
_cu_seqlens_cache = {}
970
971


972
973
974
975
976
977
978
979
980
981
def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length.

    """
982
983
984
985
986
987
988
989
990
991
    global _cu_seqlens_cache
    if (batch_size, max_seqlen) not in _cu_seqlens_cache:
        _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[(batch_size, max_seqlen)]
992
993


@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Packs the given tensor using the `indices`.

    Gathers rows of `tensor` along dim 0 at positions `indices`. A zero row is appended
    to `tensor` first, so an index equal to `tensor.shape[0]` (an out-of-range padding
    sentinel) yields a row of zeros.
    """
    padding_row = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
    )
    tensor = torch.cat((tensor, padding_row), dim=0)

    # torch.gather needs an index with the same rank as `tensor`. This assumes `indices`
    # is [N, 1, 1] and expands it to [N, shape[1], shape[2]].
    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    packed = torch.gather(tensor, 0, indices)
    return packed


@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack two tensors with the same `indices` by calling `pack_tensor` on each, in order.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pack three tensors with the same `indices` by calling `pack_tensor` on each, in order.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2), pack_tensor(indices, t3)


@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`.

    Scatters the rows of `tensor` into a zero tensor with `dim0` rows at positions
    `indices`. One extra scratch row (index `dim0`) absorbs rows whose index is the
    padding sentinel `dim0`. That row is dropped before returning.
    """
    # Expand indices (presumably [N, 1, 1]) to match `tensor`'s trailing dims for scatter_.
    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    unpacked = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
    )
    unpacked.scatter_(0, indices, tensor)
    # Discard the scratch row that collected the padding entries.
    unpacked = unpacked[0:-1, :, :]
    return unpacked


@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`: restore each tensor to `dim0` rows via `unpack_tensor`.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`: restore each tensor to `dim0` rows via `unpack_tensor`.
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    Forward packs 1 to 3 tensors with the same `indices`. Backward unpacks each
    incoming gradient back to the original size of dim 0.
    """

    @staticmethod
    def forward(
        ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        ctx.save_for_backward(indices)
        # Unpacked size of dim 0; backward needs it to restore gradient shapes.
        ctx.dim0 = tensors[0].shape[0]
        if len(tensors) == 1:
            return pack_tensor(indices, *tensors)
        if len(tensors) == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_3_tensors(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        (indices,) = ctx.saved_tensors
        # The leading None is the (non-existent) gradient for `indices`.
        if len(grad_outputs) == 1:
            return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
        if len(grad_outputs) == 2:
            return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs)
        return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs)


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    Forward scatters rows back into a tensor with `dim0` rows. Backward packs the
    incoming gradient with the same `indices`.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        ctx.save_for_backward(indices)
        return unpack_tensor(indices, dim0, tensor)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # No gradients for `indices` or `dim0`.
        return None, None, pack_tensor(indices, grad_output)


def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Sends `send_tensor` to rank `send_dst` and receives into `recv_tensor` from rank
    `recv_src` within `cp_group`, asynchronously.

    Even ranks issue the send before the receive and odd ranks the reverse. This is
    presumably so that neighboring ranks post matching operations and do not deadlock.

    If `batch_p2p_comm` is truthy, both operations go through
    `torch.distributed.batch_isend_irecv`. Otherwise they are issued individually with
    `isend`/`irecv`.

    Returns the list of request handles. Callers must `wait()` on them before reading
    `recv_tensor`.
    """
    send_recv_ops = []

    if batch_p2p_comm:
        # P2POp only describes an operation. The list order is the order in which the
        # batch issues them.
        if rank % 2 == 0:
            send_op = torch.distributed.P2POp(
                torch.distributed.isend, send_tensor, send_dst, cp_group
            )
            recv_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_tensor, recv_src, cp_group
            )
            send_recv_ops.append(send_op)
            send_recv_ops.append(recv_op)
        else:
            recv_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_tensor, recv_src, cp_group
            )
            send_op = torch.distributed.P2POp(
                torch.distributed.isend, send_tensor, send_dst, cp_group
            )
            send_recv_ops.append(recv_op)
            send_recv_ops.append(send_op)
        send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops)
    else:
        # isend/irecv start immediately, so the call order itself matters here.
        if rank % 2 == 0:
            send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
            recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
            send_recv_ops.append(send_op)
            send_recv_ops.append(recv_op)
        else:
            recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
            send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
            send_recv_ops.append(recv_op)
            send_recv_ops.append(send_op)
        send_recv_reqs = send_recv_ops

    return send_recv_reqs


@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Attention with context parallelism

    Adds `out_per_step * exp(softmax_lse_per_step - softmax_lse)` to `out` in place.
    This rescales one step's partial output by its share of the merged softmax
    normalizer.
    """
    # Move dim 2 of the lse tensors to `seq_dim` to match `out_per_step`'s layout.
    softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
    # Trailing singleton dim so the factor broadcasts over the last dim of the output.
    softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
    out_corrected = out_per_step * softmax_lse_corrected_exp
    out.add_(out_corrected)


@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism

    Updates `softmax_lse` in place to
    `log(exp(softmax_lse) + exp(softmax_lse_per_step))`.
    """
    # Numerically stable log-sum-exp of the two inputs. Unlike the manual
    # `max + log(1 + exp(min - max))`, logaddexp returns +/-inf rather than NaN when
    # both inputs are the same infinity (e.g. -inf - (-inf) = NaN).
    new_scale = torch.logaddexp(softmax_lse, softmax_lse_per_step)
    softmax_lse.copy_(new_scale)


@jit_fuser
def get_cu_seqlens_on_cp_rank(
    cu_seqlens, cu_seqlens_padded_on_cp_rank, cp_size, cp_rank, first_half, second_half
):
    """Compute cu_seqlens of a context parallelism rank

    Each sequence's per-rank padded length is split into two halves of `chunk_len`
    tokens. This rank's halves are chunk number `cp_rank` and chunk number
    `2 * cp_size - cp_rank - 1`. A chunk's valid-token count is the actual sequence
    length remaining past the chunk's start, clamped to [0, chunk_len]. Only the
    halves selected by `first_half` / `second_half` are counted. The per-sequence
    totals are returned as cumulative lengths with a leading 0.
    """
    actual_lens = cu_seqlens[1:] - cu_seqlens[:-1]
    chunk_lens = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
    lower_bound = torch.zeros_like(actual_lens)
    result = torch.zeros_like(cu_seqlens)
    if first_half:
        first_chunk_start = cp_rank * chunk_lens
        result[1:].add_((actual_lens - first_chunk_start).clamp(lower_bound, chunk_lens))
    if second_half:
        second_chunk_start = (2 * cp_size - cp_rank - 1) * chunk_lens
        result[1:].add_((actual_lens - second_chunk_start).clamp(lower_bound, chunk_lens))
    result.cumsum_(dim=0)
    return result


1222
class AttnFuncWithCP(torch.autograd.Function):
1223
    """
1224
1225
    Attention implementation with context parallelism.
    Split attention compute into multiple steps, and overlap current-step
1226
1227
1228
1229
    compute with next-step communication.
    """

    @staticmethod
1230
1231
1232
1233
1234
1235
1236
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
1237
        cu_seqlens_kv,
1238
        max_seqlen_q,
1239
        max_seqlen_kv,
1240
1241
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
        dropout_p,
        cp_group,
        cp_global_ranks,
        cp_stream,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    ):
1254
1255
1256
1257
1258
1259
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
        send_dst = cp_global_ranks[(rank + 1) % cp_size]
1260
        recv_src = cp_global_ranks[(rank - 1) % cp_size]
1261
1262
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1263
1264
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
1265

1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
        if qkv_format in ["bshd", "sbhd"]:
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        else:
            qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
        pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
        max_seqlen_q = max_seqlen_q // cp_size
        max_seqlen_kv = max_seqlen_kv // cp_size
        cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
        cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
        cu_seqlens_q_per_step = [None for _ in range(cp_size)]
        cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
1279

1280
        if causal:
1281
1282
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
1283
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
1284
1285
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
1286
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
1287
1288
1289
        total_tokens_kv = None if qkv_format != "thd" else k.shape[0]
        # remove padded tokens at the end
        k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]]
1290
        if attn_bias is not None:
1291
            assert len(attn_bias.shape) == 4, (
1292
1293
1294
1295
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
1296
1297
1298
1299
1300
1301
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
1302
1303
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
1304
1305
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
1306
            )
1307
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
1308
1309
1310
1311
1312
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_3_plus:
            fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None
1313
1314
        if _flash_attn_2_5_7_plus:
            fa_optional_forward_kwargs["block_table"] = None
1315

1316
1317
1318
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
1319
        attn_bias_inputs = [None, None]
1320
1321
1322
1323
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
1324
        attn_biases = [None for _ in range(cp_size)]
1325
1326
1327
1328
1329
1330
1331

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
1332
1333
1334
1335
        if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
        else:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
1336
1337
        send_recv_reqs = [[], []]

1338
        for i in range(cp_size + 1):
1339
            if i < cp_size:
1340
                with torch.cuda.stream(flash_attn_streams[i % 2]):
1341
                    # wait until KV is received
1342
                    for req in send_recv_reqs[(i + 1) % 2]:
1343
1344
                        req.wait()

1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

                    kv_inputs[i % 2] = p2p_comm_buffers[i]
1358
1359
                    if causal:
                        if i == 0:
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1372
                            if use_fused_attention:
1373
1374
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
1375
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
1376
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
1377
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
1378
                                        k.shape[0], -1, 2, *k.shape[-2:]
1379
                                    )
1380
1381
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
1382
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
1383
1384
1385
1386
                                    # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        -1, k.shape[2], 2, *k.shape[-2:]
                                    )
1387
                                elif qkv_format == "thd":
1388
                                    q_inputs[i % 2] = q
1389
1390
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1391
1392
1393
1394
1395
1396
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
1397
                                    ).contiguous()
1398
1399
1400
1401
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                    fused_attn_fwd(
                                        is_training,
                                        max_seqlen_q,
1402
1403
1404
                                        max_seqlen_kv,
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
1405
                                        q_inputs[i % 2],
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
                                        (
                                            kv_inputs[i % 2][..., 0, :, :]
                                            if qkv_format in ["bshd", "sbhd"]
                                            else kv_inputs[i % 2][0]
                                        ),
                                        (
                                            kv_inputs[i % 2][..., 1, :, :]
                                            if qkv_format in ["bshd", "sbhd"]
                                            else kv_inputs[i % 2][1]
                                        ),
1416
1417
1418
1419
1420
1421
1422
1423
                                        TE_DType[q.dtype],
                                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                        attn_scale=softmax_scale,
                                        dropout=dropout_p,
                                        qkv_layout=qkv_layout,
                                        attn_mask_type=attn_mask_type,
                                        attn_bias_type=attn_bias_type,
                                        attn_bias=attn_bias_inputs[i % 2],
1424
1425
                                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
1426
                                    )
1427
                                )
1428
1429
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
1430
1431
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
1432
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1433
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
1448
1449
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1450
                                    max_seqlen_q,
1451
                                    max_seqlen_kv,
1452
1453
1454
1455
1456
                                    dropout_p,
                                    softmax_scale,
                                    causal=True,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1457
                                )
1458
                        elif i <= rank:
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    False,
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
1476
                            if use_fused_attention:
1477
1478
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
1479
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
1480
1481
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous()
1482
1483
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
1484
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
1485
1486
                                    # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous()
1487
                                elif qkv_format == "thd":
1488
                                    q_inputs[i % 2] = q
1489
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
1490
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
1491
                                        kv_inputs[i % 2], cu_seqlens_kv_padded, 0
1492
                                    )
1493
1494
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1495
1496
1497
1498
1499
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                    fused_attn_fwd(
                                        is_training,
                                        max_seqlen_q,
1500
1501
1502
                                        max_seqlen_kv // 2,
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
1503
                                        q_inputs[i % 2],
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
                                        (
                                            kv_inputs[i % 2][..., 0, :, :]
                                            if qkv_format in ["bshd", "sbhd"]
                                            else kv_inputs[i % 2][0]
                                        ),
                                        (
                                            kv_inputs[i % 2][..., 1, :, :]
                                            if qkv_format in ["bshd", "sbhd"]
                                            else kv_inputs[i % 2][1]
                                        ),
1514
1515
1516
1517
1518
1519
1520
1521
                                        TE_DType[q.dtype],
                                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                        attn_scale=softmax_scale,
                                        dropout=dropout_p,
                                        qkv_layout=qkv_layout,
                                        attn_mask_type="padding" if padding else "no_mask",
                                        attn_bias_type=attn_bias_type,
                                        attn_bias=attn_bias_inputs[i % 2],
1522
1523
1524
1525
1526
                                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                                        cu_seqlens_kv_padded=(
                                            None
                                            if cu_seqlens_kv_padded is None
                                            else cu_seqlens_kv_padded // 2
1527
1528
                                        ),
                                    )
1529
                                )
1530
1531
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
1532
1533
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
1534
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1535
1536
                                if qkv_format == "thd":
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
1537
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
1538
                                        kv_inputs[i % 2], cu_seqlens_kv_padded, 0
1539
                                    )
1540
1541
                                else:
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
1542
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
1543
                                # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
1544
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
1545
1546
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
1560
1561
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1562
                                    max_seqlen_q,
1563
                                    max_seqlen_kv // 2,
1564
1565
1566
1567
1568
                                    dropout_p,
                                    softmax_scale,
                                    causal=False,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1569
1570
                                )
                        else:
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
                            if pad_between_seqs_q:
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
                                )
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
                            if pad_between_seqs_kv:
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    True,
                                )
                            else:
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1588
                            if use_fused_attention:
1589
1590
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
1591
                                    q_inputs[i % 2] = q[:, 1, ...].contiguous()
1592
                                    # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
1593
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
1594
                                        k.shape[0], -1, 2, *k.shape[-2:]
1595
                                    )
1596
1597
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
1598
                                    q_inputs[i % 2] = q[1].contiguous()
1599
1600
1601
1602
                                    # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        -1, k.shape[2], 2, *k.shape[-2:]
                                    )
1603
1604
                                elif qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
1605
1606
1607
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(
                                        q, cu_seqlens_q_padded, 1
                                    )
1608
1609
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1610
1611
1612
1613
1614
1615
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
1616
                                    ).contiguous()
1617
1618
1619
1620
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                    fused_attn_fwd(
                                        is_training,
                                        max_seqlen_q // 2,
1621
1622
1623
                                        max_seqlen_kv,
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
1624
                                        q_inputs[i % 2],
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
                                        (
                                            kv_inputs[i % 2][..., 0, :, :]
                                            if qkv_format in ["bshd", "sbhd"]
                                            else kv_inputs[i % 2][0]
                                        ),
                                        (
                                            kv_inputs[i % 2][..., 1, :, :]
                                            if qkv_format in ["bshd", "sbhd"]
                                            else kv_inputs[i % 2][1]
                                        ),
1635
1636
1637
1638
1639
1640
1641
1642
                                        TE_DType[q.dtype],
                                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                        attn_scale=softmax_scale,
                                        dropout=dropout_p,
                                        qkv_layout=qkv_layout,
                                        attn_mask_type="padding" if padding else "no_mask",
                                        attn_bias_type=attn_bias_type,
                                        attn_bias=attn_bias_inputs[i % 2],
1643
1644
1645
1646
                                        cu_seqlens_q_padded=(
                                            None
                                            if cu_seqlens_q_padded is None
                                            else cu_seqlens_q_padded // 2
1647
                                        ),
1648
                                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
1649
                                    )
1650
                                )
1651
1652
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
1653
                            else:
1654
1655
                                if qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
1656
1657
1658
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(
                                        q, cu_seqlens_q_padded, 1
                                    )
1659
1660
                                else:
                                    # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
1661
                                    q_inputs[i % 2] = (
1662
                                        q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
1663
                                    )
1664
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
1665
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
1666
1667
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
1681
1682
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1683
                                    max_seqlen_q // 2,
1684
                                    max_seqlen_kv,
1685
1686
1687
1688
1689
                                    dropout_p,
                                    softmax_scale,
                                    causal=False,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
1690
1691
                                )
                    else:
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
                        if pad_between_seqs_q:
                            cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                            )
                        else:
                            cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
                        if pad_between_seqs_kv:
                            cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_kv,
                                cu_seqlens_kv_padded,
                                cp_size,
                                (rank - i) % cp_size,
                                True,
                                True,
                            )
                        else:
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1709
                        if use_fused_attention:
1710
1711
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
1712
1713
1714
1715
1716
1717
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
1718
                                ).contiguous()
1719
1720
1721
1722
                            out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
1723
1724
1725
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1726
                                    q,
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
1737
1738
1739
1740
1741
1742
1743
1744
                                    TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
1745
1746
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
1747
                                )
1748
                            )
1749
1750
                            if len(rest) > 0:
                                attn_biases[i] = rest[0]
1751
                        else:
1752
                            # [b, sq, np, hn] -> [b*sq, np, hn]
1753
                            q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
1754
                            # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
                            kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                            (
                                _,
                                _,
                                _,
                                _,
                                out_per_step[i],
                                softmax_lse_per_step[i],
                                _,
                                rng_states[i],
                            ) = _flash_attn_forward(
                                q_inputs[i % 2],
                                kv_inputs[i % 2][0],
                                kv_inputs[i % 2][1],
1769
1770
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
1771
                                max_seqlen_q,
1772
                                max_seqlen_kv,
1773
1774
1775
1776
1777
                                dropout_p,
                                softmax_scale,
                                causal=False,
                                return_softmax=False,
                                **fa_optional_forward_kwargs,
1778
                            )
1779
1780
1781
1782

            if i > 0:
                # wait until fwd results correction of last step is done
                if i > 1:
1783
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
1784

1785
1786
                if use_fused_attention:
                    # [b, np, sq, 1] -> [b, np, sq]
1787
                    softmax_lse_per_step[i - 1].squeeze_(-1)
1788

1789
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
1790
                    if i == 1:
1791
                        out = torch.zeros_like(q)
1792
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
1793
                        if causal and qkv_format != "thd":
1794
1795
                            # [b, np, sq] -> [b, np, 2, sq//2]
                            softmax_lse_ = softmax_lse.view(
1796
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
1797
                            )
1798
1799
1800
1801
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
1802
                    else:
1803
                        if qkv_format == "thd":
1804
                            tex.thd_second_half_lse_correction(
1805
1806
1807
1808
                                softmax_lse,
                                softmax_lse_per_step[i - 1],
                                cu_seqlens_q_padded,
                                max_seqlen_q,
1809
                            )
1810
                        else:
1811
1812
1813
                            flash_attn_fwd_softmax_lse_correction(
                                softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1]
                            )
1814
1815

                if i < cp_size:
1816
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
1817
1818
1819
1820

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        softmax_lse = softmax_lse.to(torch.float)
1821
1822
        if qkv_format in ["bshd", "sbhd"]:
            seq_dim = qkv_format.index("s")
1823
        for i in range(cp_size):
1824
1825
1826
1827
1828
1829
            if qkv_format == "bshd":
                out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
                out_ = out[:, 1, ...]
            elif qkv_format == "sbhd":
                out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
                out_ = out[1]
1830

1831
            if i <= rank or not causal:
1832
                if qkv_format in ["bshd", "sbhd"]:
1833
1834
1835
1836
1837
1838
1839
                    flash_attn_fwd_out_correction(
                        out.view(*out_per_step[i].shape),
                        out_per_step[i],
                        seq_dim,
                        softmax_lse,
                        softmax_lse_per_step[i],
                    )
1840
                elif qkv_format == "thd":
1841
1842
1843
1844
1845
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
1846
                        cu_seqlens_q_padded,
1847
1848
                        False,
                    )
1849
1850
                else:
                    assert False, f"{qkv_format} is an unsupported qkv_format!"
1851
            else:
1852
                if qkv_format in ["bshd", "sbhd"]:
1853
1854
1855
1856
1857
1858
1859
                    flash_attn_fwd_out_correction(
                        out_,
                        out_per_step[i],
                        seq_dim,
                        softmax_lse_[..., 1, :],
                        softmax_lse_per_step[i],
                    )
1860
                elif qkv_format == "thd":
1861
1862
1863
1864
1865
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
1866
                        cu_seqlens_q_padded,
1867
1868
                        True,
                    )
1869
1870
                else:
                    assert False, f"{qkv_format} is an unsupported qkv_format!"
1871
1872

        kv = p2p_comm_buffers[-1]
1873
        if use_fused_attention:
1874
1875
1876
1877
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
1878
1879
        else:
            out = out.view(-1, *out.shape[-2:])
1880

1881
        ctx.save_for_backward(
1882
1883
1884
1885
            q,
            kv,
            out,
            softmax_lse,
1886
1887
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
1888
1889
            *cu_seqlens_q_per_step,
            *cu_seqlens_kv_per_step,
1890
1891
            *rng_states,
            *attn_biases,
1892
        )
1893
1894
1895
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
        ctx.dropout_p = dropout_p
1896
        ctx.total_tokens_kv = total_tokens_kv
1897
        ctx.max_seqlen_q = max_seqlen_q
1898
        ctx.max_seqlen_kv = max_seqlen_kv
1899
        ctx.softmax_scale = softmax_scale
1900
        ctx.qkv_format = qkv_format
1901
        ctx.attn_mask_type = attn_mask_type
1902
1903
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
1904
        ctx.deterministic = deterministic
1905
        ctx.use_fused_attention = use_fused_attention
1906
1907
1908
1909
1910
1911
        return out

    @staticmethod
    def backward(ctx, dout):
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
1912
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
1913
1914
1915
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1916
1917
1918
1919
1920
1921
        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
        cu_seqlens_q_per_step = ctx.saved_tensors[6 : 6 + cp_size]
        cu_seqlens_kv_per_step = ctx.saved_tensors[6 + cp_size : 6 + cp_size * 2]
        rng_states = ctx.saved_tensors[6 + cp_size * 2 : 6 + cp_size * 3]
        attn_biases = ctx.saved_tensors[6 + cp_size * 3 : 6 + cp_size * 4]

1922
1923
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
1924
1925
1926
1927
        if ctx.qkv_format in ["bshd", "sbhd"]:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
        else:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
1928

1929
        if attn_biases[0] is not None:
1930
1931
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
1932
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
1933
1934
1935
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
1936
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
1937
1938
1939
1940
            )
        else:
            attn_dbias = None

1941
        if causal:
1942
            if ctx.qkv_format == "thd":
1943
1944
1945
                softmax_lse_ = tex.thd_read_second_half_lse(
                    softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q
                )
1946
1947
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
1948
1949
1950
                softmax_lse_ = softmax_lse.view(
                    *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
                )
1951
1952
1953
1954
1955
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
                if ctx.use_fused_attention:
                    # [b, np, sq//2] -> [b, np, sq//2, 1]
                    softmax_lse_.unsqueeze_(-1)

1956
1957
1958
        if ctx.use_fused_attention:
            # [b, np, sq] -> [b, np, sq, 1]
            softmax_lse.unsqueeze_(-1)
1959
1960
1961
1962
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        # Flash Attn outputs
        dq = torch.empty_like(q)
1963
1964
        if ctx.qkv_format == "thd" and causal:
            dq[cu_seqlens_q_padded[-1] :].fill_(0)
1965

1966
1967
1968
1969
        p2p_comm_buffers = [
            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
        ]
1970
1971
1972
        p2p_comm_buffers[0][0].copy_(kv)
        send_recv_reqs = []

1973
1974
1975
1976
1977
1978
        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

1979
1980
1981
1982
1983
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

1984
1985
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
1986
1987
1988
            if i == 0:
                send_tensor = send_tensor[0]
                recv_tensor = recv_tensor[0]
1989
            if i == (cp_size - 1):
1990
1991
1992
                send_tensor = send_tensor[1]
                recv_tensor = recv_tensor[1]

1993
1994
1995
            send_recv_reqs = flash_attn_p2p_communicate(
                rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
            )
1996

1997
            kv = p2p_comm_buffers[i % 2][0]
1998
            # In reversed order of fwd
1999
            if causal:
2000
                if i == (cp_size - 1):
2001
                    if ctx.use_fused_attention:
2002
2003
2004
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
2005
2006
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
2007
2008
2009
2010
2011
2012
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
2013
2014
                            # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                            kv_ = kv.view(-1, *kv.shape[-4:])
2015
2016
2017
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
2018
2019
                        elif ctx.qkv_format == "thd":
                            q_, kv_, out_, dout_ = q, kv, out, dout
2020
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2021
                        if attn_dbias is not None:
2022
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2023
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2024
                            ctx.max_seqlen_q,
2025
2026
2027
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2028
                            q_,
2029
2030
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2031
2032
2033
2034
2035
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[kv.dtype],
                            aux_ctx_tensors,
2036
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
2037
2038
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2039
2040
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2041
                            qkv_layout=qkv_layout,
2042
                            attn_mask_type=ctx.attn_mask_type,
2043
                            attn_bias_type=ctx.attn_bias_type,
2044
2045
2046
2047
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
2048
                        dq_ = torch.zeros_like(q_)
2049
2050
2051
2052
2053
2054
2055
2056
2057
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, 0]
                        _flash_attn_backward(
2058
2059
2060
2061
2062
2063
2064
2065
2066
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2067
2068
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2069
                            ctx.max_seqlen_q,
2070
                            ctx.max_seqlen_kv,
2071
2072
2073
2074
2075
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            True,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
2076
                        )
2077
                elif i >= (cp_size - rank - 1):
2078
                    if ctx.use_fused_attention:
2079
2080
2081
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
2082
2083
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                            kv_ = kv[:, 0, ...].contiguous()
2084
2085
2086
2087
2088
2089
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
2090
2091
                            # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                            kv_ = kv[0].contiguous()
2092
2093
2094
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
2095
2096
2097
                        elif ctx.qkv_format == "thd":
                            q_, out_, dout_ = q, out, dout
                            # [2, t, np, hn] -> [2, t/2, np, hn]
2098
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
2099
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2100
                        if attn_dbias is not None:
2101
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2102
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2103
                            ctx.max_seqlen_q,
2104
2105
2106
                            ctx.max_seqlen_kv // 2,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2107
                            q_,
2108
2109
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2110
2111
2112
2113
2114
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[kv.dtype],
                            aux_ctx_tensors,
2115
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
2116
2117
2118
2119
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=(
                                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
                            ),
2120
2121
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2122
                            qkv_layout=qkv_layout,
2123
                            attn_mask_type="padding" if padding else "no_mask",
2124
                            attn_bias_type=ctx.attn_bias_type,
2125
2126
2127
2128
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
2129
                        dq_ = torch.zeros_like(q_)
2130
2131
                        if ctx.qkv_format == "thd":
                            # [2, t, np, hn] -> [2, t/2, np, hn]
2132
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
2133
2134
2135
                        else:
                            # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
2136
2137
2138
2139
2140
2141
2142
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
2143
2144
2145
2146
2147
2148
2149
2150
2151
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2152
2153
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2154
                            ctx.max_seqlen_q,
2155
                            ctx.max_seqlen_kv // 2,
2156
2157
2158
2159
2160
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            False,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
2161
2162
2163
                        )
                else:
                    if ctx.use_fused_attention:
2164
2165
2166
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous()
2167
2168
                            # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
2169
2170
2171
2172
2173
2174
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous()
                            dout_ = dout[:, 1, ...].contiguous()
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            q_ = q[1].contiguous()
2175
2176
                            # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                            kv_ = kv.view(-1, *kv.shape[-4:])
2177
2178
2179
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            out_ = out[1].contiguous()
                            dout_ = dout[1].contiguous()
2180
2181
                        elif ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
2182
2183
2184
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
2185
                            kv_ = kv
2186
                        aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
2187
                        if attn_dbias is not None:
2188
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2189
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2190
                            ctx.max_seqlen_q // 2,
2191
2192
2193
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2194
                            q_,
2195
2196
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2197
2198
2199
2200
2201
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[kv.dtype],
                            aux_ctx_tensors,
2202
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
2203
2204
2205
2206
                            cu_seqlens_q_padded=(
                                None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
                            ),
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2207
2208
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2209
                            qkv_layout=qkv_layout,
2210
                            attn_mask_type="padding" if padding else "no_mask",
2211
                            attn_bias_type=ctx.attn_bias_type,
2212
2213
                        )
                    else:
2214
2215
                        if ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
2216
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1)
2217
2218
2219
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
2220
                        dq_ = torch.zeros_like(q_)
2221
2222
2223
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
2224
                        if ctx.qkv_format == "thd":
2225
2226
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
2227
2228
2229
2230
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                            dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
2231
2232
2233
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
2234
2235
2236
2237
2238
2239
2240
2241
2242
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse_,
                            dq_,
                            dkv_[0],
                            dkv_[1],
2243
2244
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2245
                            ctx.max_seqlen_q // 2,
2246
                            ctx.max_seqlen_kv,
2247
2248
2249
2250
2251
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            False,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
2252
2253
2254
                        )
            else:
                if ctx.use_fused_attention:
2255
                    aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2256
                    if attn_dbias is not None:
2257
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2258
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2259
                        ctx.max_seqlen_q,
2260
2261
2262
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
2263
                        q,
2264
2265
                        kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
2266
2267
2268
2269
2270
                        out,
                        dout,
                        TE_DType[q.dtype],
                        TE_DType[kv.dtype],
                        aux_ctx_tensors,
2271
                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
2272
2273
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2274
2275
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
2276
                        qkv_layout=qkv_layout,
2277
                        attn_mask_type=ctx.attn_mask_type,
2278
                        attn_bias_type=ctx.attn_bias_type,
2279
2280
2281
                    )
                else:
                    # [b, sq, np, hn] -> [b*sq, np, hn]
2282
                    q_ = q.view(-1, *q.shape[-2:])
2283
                    dq_ = torch.zeros_like(q_)
2284
                    # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
2285
2286
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
2287
                    # [b, sq, np, hn] -> [b*sq, np, hn]
2288
2289
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
2290
2291
                    if _flash_attn_2_3_plus:
                        fa_optional_backward_kwargs["window_size"] = [-1, -1]
2292
                    _flash_attn_backward(
2293
2294
2295
2296
2297
2298
2299
2300
2301
                        dout_,
                        q_,
                        kv_[0],
                        kv_[1],
                        out_,
                        softmax_lse,
                        dq_,
                        dkv_[0],
                        dkv_[1],
2302
2303
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
2304
                        ctx.max_seqlen_q,
2305
                        ctx.max_seqlen_kv,
2306
2307
2308
                        ctx.dropout_p,
                        ctx.softmax_scale,
                        False,
2309
                        rng_state=rng_states[cp_size - i - 1],
2310
                        **fa_optional_backward_kwargs,
2311
2312
                    )

2313
            if i >= (cp_size - rank - 1) or not causal:
2314
2315
2316
2317
                # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                dq_ = dq_.view(*dq.shape)
            else:
2318
2319
2320
2321
2322
2323
                if ctx.qkv_format == "bshd":
                    # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                    dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
                elif ctx.qkv_format == "sbhd":
                    # [b*sq//2, np, hn] -> [sq//2, b, np, hn]
                    dq_ = dq_.view(-1, *dq.shape[-3:])
2324

2325
            if causal:
2326
                if i > (cp_size - rank - 1):
2327
                    dq.add_(dq_)
2328
2329
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
2330
2331
                        dq.copy_(dq_)
                    else:
2332
2333
2334
2335
2336
2337
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
2338
                        elif ctx.qkv_format == "thd":
2339
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
2340
                elif i > 0:
2341
2342
2343
2344
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
2345
                    elif ctx.qkv_format == "thd":
2346
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
2347
                else:
2348
2349
2350
2351
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
2352
                    elif ctx.qkv_format == "thd":
2353
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
2354
2355
2356
2357
2358
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
2359

2360
            if attn_dbias is not None:
2361
                idx = (rank + i + 1) % cp_size
2362
                if i == (cp_size - 1) or not causal:
2363
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
2364
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2365
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
2366
2367
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
2368
2369
2370
2371
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
2372
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2373
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
2374
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
2375

2376
2377
2378
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
2379

2380
            dkv = p2p_comm_buffers[(i + 1) % 2][1]
2381
2382
            if ctx.use_fused_attention:
                dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
2383
2384
2385
2386
                if ctx.qkv_format in ["bshd", "sbhd"]:
                    # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
                    # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
                    dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
2387
            if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
2388
2389
2390
2391
2392
2393
                if ctx.qkv_format == "bshd":
                    # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                    dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
                elif ctx.qkv_format == "sbhd":
                    # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:])
2394
2395
2396
2397
            else:
                # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
                # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                dkv_ = dkv_.view(*dkv.shape)
2398

2399
            if causal:
2400
                if i == (cp_size - 1):
2401
                    if rank == 0:
2402
2403
2404
2405
2406
2407
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
2408
                        elif ctx.qkv_format == "thd":
2409
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
2410
2411
                    else:
                        dkv.add_(dkv_)
2412
2413
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
2414
2415
2416
2417
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
2418
                        elif ctx.qkv_format == "thd":
2419
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
2420
                    else:
2421
2422
2423
2424
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
2425
                        elif ctx.qkv_format == "thd":
2426
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
2427
2428
2429
2430
2431
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
2432
2433
2434
2435
2436
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

2437
        if causal:
2438
2439
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
2440
                dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
2441
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
2442
                dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
2443
2444
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
2445
                dq = dq.view(-1, *dq.shape[-3:])
2446
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
2447
2448
2449
2450
2451
2452
2453
2454
2455
                dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])

        if ctx.qkv_format == "thd":
            dkv_ = torch.empty(
                2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device
            )
            dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv)
            dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
            dkv = dkv_
2456
2457
2458
2459
2460

        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)

2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480
2481
2482
2483
        return (
            None,
            dq,
            dkv[0],
            dkv[1],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            attn_dbias,
            None,
            None,
        )
2484
2485
2486


def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_kv,
    max_seqlen_q,
    max_seqlen_kv,
    cu_seqlens_q_padded,
    cu_seqlens_kv_padded,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
) -> torch.Tensor:
    """Attention implementation with context parallelism.

    Checks that the requested layout / backend / mask / bias combination is
    one the context-parallel path handles, then hands every argument, in the
    same positional order, to ``AttnFuncWithCP.apply`` and returns its result.

    Parameters
    ----------
    qkv_format: str, default = "bshd"
        Must be one of "bshd", "sbhd" or "thd".
        "sbhd" additionally requires ``use_fused_attention``.
    attn_mask_type: str, default = "causal"
        With "thd" and ``use_fused_attention``, must be "padding" or
        "padding_causal".
    attn_bias: optional, default = None
        Only accepted together with ``use_fused_attention`` and a mask type
        that does not contain "padding".
    cu_seqlens_q_padded, cu_seqlens_kv_padded:
        Must both be provided (not None).
    use_fused_attention: bool, default = False
        Selects the fused-attention backend instead of FlashAttention.

    The remaining arguments are not inspected here; they are forwarded
    unchanged to ``AttnFuncWithCP.apply`` (see that function for their
    meaning).

    Returns
    -------
    torch.Tensor
        Whatever ``AttnFuncWithCP.apply`` produces.

    Raises
    ------
    AssertionError
        When any of the checks above fails (only while assertions are
        enabled, i.e. not under ``python -O``).
    """
    # Only these three tensor layouts are implemented for context parallelism.
    assert qkv_format in ("bshd", "sbhd", "thd"), (
        f"QKV format of {qkv_format} is not supported with context parallelism!"
    )

    # The sbhd layout is reachable only through the fused-attention backend.
    assert use_fused_attention or qkv_format != "sbhd", (
        "FlashAttention does not support sbhd format!"
    )

    # Fused attention on packed (thd) inputs needs a padding-aware mask.
    assert not (
        qkv_format == "thd"
        and use_fused_attention
        and attn_mask_type not in ("padding", "padding_causal")
    ), (
        f"Context parallelism is not supported for {attn_mask_type} mask type and "
        f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!"
    )

    # A bias tensor is only wired up for fused attention without padding masks.
    assert not (
        attn_bias is not None
        and (not use_fused_attention or "padding" in attn_mask_type)
    ), (
        """Attention bias is only supported with FusedAttention and "causal" """
        """or "no_mask" mask types!"""
    )

    # Both padded cumulative-sequence-length tensors are mandatory here.
    assert not (
        cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None
    ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!"

    # Positional order must match AttnFuncWithCP.forward exactly.
    return AttnFuncWithCP.apply(
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        cp_group,
        cp_global_ranks,
        cp_stream,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    )


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float
            Percent of rotary dimension to use for rotary position embeddings.
        seq_len_interpolation_factor: int
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int
            pre-trained max_position_embeddings before position interpolation
        rotary_base: float, default = 10000.0
            Base of the geometric progression of inverse frequencies:
            `inv_freq[i] = 1 / rotary_base ** (2 * i / dim)`.
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
        # One inverse frequency per pair of channels: exponents 0, 2, 4, ... divided by dim.
        inv_freq = 1.0 / (
            self.rotary_base
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            Tensor of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]` holding
            `position * inv_freq`, with the frequency table repeated twice along the last
            dimension.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        if (
            self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None
        ):
            if (
                max_seq_len
                > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
            ):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        # outer product: [seq_length] x [len(inv_freq)] -> [seq_length, len(inv_freq)]
        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        # the same frequency table is used for both halves of the rotated channels, so the
        # last dimension becomes 2 * len(inv_freq) (== dim for even dim)
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))

class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Apply RoPE to `t` using the `tex` fused kernels.

        `freqs` is upcast to float32 before the kernel call. `bshd` input is transposed
        to `sbhd` for the kernel and the result transposed back. `cu_seqlens` is only
        consumed by the `thd` path. Raises ValueError for any other `tensor_format`.
        """
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        # cu_seqlens may be None (non-thd formats); save_for_backward accepts None entries.
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Propagate the gradient through the fused RoPE kernel.

        Only `t` receives a gradient; `freqs`, `tensor_format` and `cu_seqlens` get None.
        """
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # Exactly one gradient per forward input: t, freqs, tensor_format, cu_seqlens.
        return grad_input, None, None, None


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Rotate the two halves of the last dimension: `[a, b] -> [-b, a]`.

    The last dimension is viewed as `(2, d // 2)`; the second half is negated and moved
    to the front, the first half moved to the back.
    """
    half = x.shape[-1] // 2
    front, back = x.view(*x.shape[:-1], 2, half).unbind(dim=-2)
    return torch.cat((back.neg(), front), dim=-1)


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.

    Returns
    -------
    torch.Tensor
        Tensor with the same shape as `t`. In the non-fused path the first `d2` channels
        of the last dimension are rotated and the remaining `d - d2` channels are passed
        through unchanged.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)


class _SplitAlongDim(torch.autograd.Function):
    """Split a tensor along one dimension, with a zero-copy backward when possible.

    Forward returns the chunks from `torch.split`; for a `Float8Tensor` input the raw
    `_data` is split and each chunk re-wrapped with `Float8Tensor.make_like`.
    Backward concatenates the incoming gradients along the split dimension. When the
    gradients are already adjacent views into one storage, a single view spanning all
    of them is returned instead of calling `torch.cat`.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
    ) -> Tuple[torch.Tensor, ...]:
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        if isinstance(mixed_x_layer, Float8Tensor):
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)

    @staticmethod
    def _concat_as_view(
        tensors: List[torch.Tensor],
        split_dim: int,
        split_sizes: List[int],
    ) -> Optional[torch.Tensor]:
        """Return a zero-copy concatenation of `tensors` along `split_dim`, or None.

        The view is only built when every tensor shares the first tensor's storage and
        strides, has the expected shape, and starts exactly where the previous chunk ends.
        The chunk offsets are measured with the stride of `split_dim`, which is the same
        stride the returned view uses, so the view aliases exactly the chunks' memory.
        """
        first = tensors[0]
        strides = first.stride()
        storage = first.untyped_storage()
        data_ptr = storage.data_ptr()
        base_offset = first.storage_offset()

        expected_shape = list(first.shape)
        expected_offset = base_offset
        for tensor, size in zip(tensors, split_sizes):
            expected_shape[split_dim] = size
            if (
                tensor.stride() != strides
                or list(tensor.shape) != expected_shape
                or tensor.untyped_storage().data_ptr() != data_ptr
                or tensor.storage_offset() != expected_offset
            ):
                return None
            expected_offset += size * strides[split_dim]

        new_shape = list(first.shape)
        new_shape[split_dim] = sum(split_sizes)
        ret = torch.Tensor().to(device=first.device, dtype=first.dtype)
        ret.set_(storage, base_offset, new_shape, strides)
        return ret

    @staticmethod
    def backward(ctx, *grad_outputs):
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + dims) % dims

        if isinstance(grad_outputs[0], Float8Tensor):
            # Work on the raw FP8 buffers consistently (storage, strides and offsets),
            # then re-wrap the result like the first gradient.
            grad_outputs_data = [x._data for x in grad_outputs]
            ret = _SplitAlongDim._concat_as_view(grad_outputs_data, split_dim, split_sizes)
            if ret is None:
                ret = torch.cat(grad_outputs_data, dim=split_dim)
            return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None

        ret = _SplitAlongDim._concat_as_view(list(grad_outputs), split_dim, split_sizes)
        if ret is None:
            ret = torch.cat(grad_outputs, dim=split_dim)
        return ret, None, None


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        Computes `softmax(scale * Q K^T [+ bias] [masked]) V` with explicit batched matmuls.
        Only `sbhd`- and `bshd`-based layouts are handled; `bshd` inputs are transposed to
        `sbhd` internally. Key/value heads are repeated to match the query heads for GQA.
        Returns a tensor of shape `[sq, b, np * hn]` (sbhd) or `[b, sq, np * hn]` (bshd).
        """
        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
        # The 4-D indexing below has no thd (packed, 3-D) variant; fail early with a clear
        # message instead of mis-sized allocations and an IndexError further down.
        assert qkv_format in (
            "sbhd",
            "bshd",
        ), f"UnfusedDotProductAttention does not support qkv_format = {qkv_format}!"
        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]

        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            # repeat each kv head so every query head has a matching key/value head
            key_layer = key_layer.repeat_interleave(
                query_layer.shape[2] // key_layer.shape[2], dim=2
            )
            value_layer = value_layer.repeat_interleave(
                query_layer.shape[2] // value_layer.shape[2], dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            # bias is added before scaling
            matmul_result = (
                matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias
            ).view(-1, output_size[2], output_size[3])
            matmul_result *= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = get_alibi(
                    output_size[1],
                    output_size[2],
                    output_size[3],
                    alibi_slopes=alibi_slopes,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            # bias is added after scaling
            matmul_result = (
                (
                    matmul_result.view(
                        output_size[0], output_size[1], output_size[2], output_size[3]
                    )
                    + core_attention_bias
                )
                .view(-1, output_size[2], output_size[3])
                .to(dtype=query_layer.dtype)
            )

        else:
            # Without this, the uninitialized torch.empty buffer above would be used as scores.
            raise ValueError(f"Unsupported core_attention_bias_type: {core_attention_bias_type}")

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        # NOTE(review): presumably lets the softmax undo the 1/layer_number folded into
        # `scale` above; confirm against FusedScaleMaskSoftmax.
        qk_layer_rescale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            attention_scores, attention_mask, attn_mask_type, qk_layer_rescale
        )

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)

        elif qkv_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """This class converts QKV from interleaved (s, b, ...) layout
    to separate contiguous q, k, v tensors in (b, s, ...) layout."""

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # All inputs received are non-contiguous tensors.
        # The `query_layer` tensor is used to access the
        # full memory region of the QKV tensor.
        # key_layer/value_layer are not read here, but are passed in so that autograd
        # routes their gradients through backward().
        qkv = tex.fa_prepare_fwd(query_layer)
        q, k, v = split_tensor_along_dim(qkv, 0, 3)
        query_layer = torch.squeeze(q, 0)
        key_layer = torch.squeeze(k, 0)
        value_layer = torch.squeeze(v, 0)
        return query_layer, key_layer, value_layer

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # Re-interleave the three gradients into one buffer, then hand back per-input views
        # split along the last dimension.
        dqkv = tex.fa_prepare_bwd(dq, dk, dv)
        dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
        return dq, dk, dv

def get_qkv_layout(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qkv_format: str = "sbhd",
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of tokens in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    ----------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
        Query tensor; the input itself, or a `.contiguous()` copy when the original memory
        layout was not recognized.
    k: torch.Tensor
        Key tensor, same rule as `q`.
    v: torch.Tensor
        Value tensor, same rule as `q`.

    Raises
    ------
    Exception
        If no supported layout is found even after making `q`, `k` and `v` contiguous.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def run_iteratively(q, k, v):
        # do q, k, v (resp. k, v) live in one storage?
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        # k and v may have different head sizes, so compare strides normalized by the last
        # dim size: k_stride / k_d == v_stride / v_d, cross-multiplied to stay exact.
        k_strides, v_strides = k.stride()[:-1], v.stride()[:-1]
        check_strides_kv = len(k_strides) == len(v_strides) and all(
            k_stride * v.shape[-1] == v_stride * k.shape[-1]
            for k_stride, v_stride in zip(k_strides, v_strides)
        )

        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        check_shapes_kv = shape[:-1] == v.shape[:-1]

        # interleaved in the last-but-one position: chunks start d elements apart
        last_dim_size = q.shape[-1]
        check_last_dim_offsets_qkv = all(
            i * last_dim_size == x.storage_offset() for i, x in enumerate([q, k, v])
        )
        last_dim_size = k.shape[-1]
        check_last_dim_offsets_kv = all(
            i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v])
        )

        # interleaved before the head dim: chunks start h * d elements apart
        last_two_dims_size = q.shape[-1] * q.shape[-2]
        check_last_two_dims_offsets_qkv = all(
            i * last_two_dims_size == x.storage_offset() for i, x in enumerate([q, k, v])
        )
        last_two_dims_size = k.shape[-1] * k.shape[-2]
        check_last_two_dims_offsets_kv = all(
            i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v])
        )

        if (
            check_ptrs_qkv
            and check_strides_qkv
            and check_shapes_qkv
            and check_last_two_dims_offsets_qkv
            and not check_last_dim_offsets_qkv
        ):
            # sb3hd, bs3hd, t3hd
            qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
        elif (
            check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv
        ):
            # sbh3d, bsh3d, th3d
            qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
        elif (
            check_ptrs_kv
            and check_strides_kv
            and check_shapes_kv
            and check_last_two_dims_offsets_kv
            and not check_last_dim_offsets_kv
        ):
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_last_dim_offsets_kv:
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
        elif check_strides_kv and check_shapes_kv:
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            qkv_layout = "_".join([qkv_format] * 3)
        else:
            qkv_layout = "not_supported"

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        raise Exception("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v

def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Check if sliding window size is compliant with attention mask type.
    If not, set it to the appropriate size.

         attn_mask_type                              |   window_size
    -------------------------------------------------------------------------
    no_mask, padding, arbitrary                      | (-1, -1) or (>=0, >=0)
    causal, padding_causal                           | (-1,  0) or (>=0, 0)
    causal_bottom_right, padding_causal_bottom_right | (-1,  0) or (>=0, 0)

    A `None` window is replaced by the default for the mask type silently. A warning is
    only emitted when a user-provided window is actually changed, e.g. `(-1, -1)` with a
    causal mask becomes `(-1, 0)`, and `(-1, 0)` with a non-causal mask becomes `(-1, -1)`.

    Returns
    -------
    window_size: Tuple[int, int]
        The compliant window size.

    Raises
    ------
    AssertionError
        If `window_size` cannot be made compliant, or `attn_mask_type` is unknown.
    """
    orig_window_size = window_size
    if "causal" in attn_mask_type:
        if orig_window_size is None:
            window_size = (-1, 0)
        elif orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]:
            window_size = (-1, 0)
            if orig_window_size[1] != 0:
                warnings.warn(
                    "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type="
                    + attn_mask_type
                )
        elif orig_window_size[0] >= 0:
            window_size = (orig_window_size[0], 0)
            if orig_window_size[1] != 0:
                warnings.warn(
                    "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type="
                    + attn_mask_type
                )
        else:
            assert False, (
                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
            )
    elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
        if orig_window_size is None:
            window_size = (-1, -1)
        elif orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]:
            window_size = (-1, -1)
            if orig_window_size[1] != -1:
                warnings.warn(
                    "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type="
                    + attn_mask_type
                )
        elif orig_window_size[0] < 0 or orig_window_size[1] < 0:
            assert False, (
                "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
            )
    else:
        assert False, "Invalid attn_mask_type: " + attn_mask_type
    return window_size

3271

3272
class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention

    Parameters
    ----------
    softmax_scale : float
        Scale applied inside the softmax; forwarded as ``softmax_scale`` to flash-attn
        (or to the context-parallel implementation).
    attention_dropout : float, default = 0.0
        Dropout probability; only applied while ``self.training`` is True.
    attention_dropout_ctx : Callable, default = nullcontext
        Zero-argument callable returning a context manager that wraps the attention
        kernel call.
    attention_type : str, default = "self"
        ``"self"`` makes padded inputs share one set of cu_seqlens for Q and KV; any
        other value takes the separate-Q/KV (cross-attention) path.
    layer_number : int, optional
        Stored as ``self.layer_number`` (1 when None); not read by ``forward``.
    deterministic : bool, default = False
        Forwarded to flash-attn >= 2.4.1 and to the context-parallel implementation.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        # Refuse to build against a flash-attn release outside the supported range.
        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
        assert (
            _flash_attn_version <= _flash_attn_max_version
        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
    ) -> torch.Tensor:
        """flash-attn fprop

        Runs flash-attn on Q/K/V, or the context-parallel implementation when
        ``cp_group`` spans more than one rank.

        Parameters
        ----------
        query_layer, key_layer, value_layer : torch.Tensor
            FP16/BF16 CUDA tensors in the format implied by ``qkv_layout``
            (``sbhd``, ``bshd`` or ``thd`` once the packing digits are stripped).
        attention_mask : optional
            Only read for ``padding`` mask types when cu_seqlens are not supplied: a
            single mask for self-attention, a ``(q_mask, kv_mask)`` pair otherwise.
        qkv_layout : str
            Must be one of ``QKVLayouts``.
        cu_seqlens_q, cu_seqlens_kv : optional
            Cumulative sequence lengths. Required for ``thd``; for sbhd/bshd they are
            derived from ``attention_mask`` (padding) or built for full-length sequences.
        max_seqlen_q, max_seqlen_kv : optional
            Used for ``thd`` (derived from cu_seqlens when None); for sbhd/bshd they are
            recomputed from the tensor shapes and any passed value is ignored.
        attn_mask_type : str
            Any type containing ``"causal"`` enables causal masking; types containing
            ``"padding"`` pack away padded tokens before the kernel call (sbhd/bshd).
        window_size : optional
            Sliding window ``(left, right)``. None means "no sliding window":
            ``(-1, 0)`` for causal mask types, ``(-1, -1)`` otherwise.
        alibi_slopes : optional
            Forwarded to flash-attn >= 2.4; must be None with context parallelism.
        cp_group, cp_global_ranks, cp_stream : optional
            Context-parallel group, its global ranks and a stream, forwarded to
            ``attn_forward_func_with_cp``.

        Returns
        -------
        torch.Tensor
            Output with all dims after the sequence/batch dims flattened:
            ``[s, b, -1]`` for sbhd, ``[b, s, -1]`` for bshd, ``[t, -1]`` for thd.
        """

        assert (
            query_layer.dtype in [torch.float16, torch.bfloat16]
            and key_layer.dtype in [torch.float16, torch.bfloat16]
            and value_layer.dtype in [torch.float16, torch.bfloat16]
        ), "FlashAttention currently only supports FP16 and BF16."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        # None is the declared default, but flash-attn >= 2.3 expects a (left, right)
        # tuple and the context-parallel check below compares against tuples, so map
        # None to the "no sliding window" value appropriate for the mask type.
        if window_size is None:
            window_size = (-1, 0) if "causal" in attn_mask_type else (-1, -1)

        cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group)
        context_parallel = cp_size > 1

        # Strip packing digits from the first layout group, e.g. "sbh3d" -> "sbhd",
        # "bshd_bs2hd" -> "bshd", "t3hd" -> "thd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format == "sbhd":
            # flash-attn is called on batch-first tensors below (shape[0] is read as the
            # batch size), so sbhd inputs are converted here.
            # For now just 128, will make it more general in the future
            if (
                query_layer.shape[-1] == 128
                and query_layer.shape[0] * query_layer.shape[1] >= 512
                and qkv_layout == "sbh3d"
            ):
                # NOTE(review): presumably a fused transpose for this common case that
                # yields batch-first q/k/v like the else-branch; confirm.
                query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                    query_layer, key_layer, value_layer
                )
            else:
                query_layer, key_layer, value_layer = [
                    x.transpose(0, 1).contiguous() for x in (query_layer, key_layer, value_layer)
                ]
        elif qkv_format in ["bshd", "thd"]:
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]

        # Only meaningful for sbhd/bshd (token count for thd, where it is unused).
        batch_size = query_layer.shape[0]

        if qkv_format in ["sbhd", "bshd"]:
            # Each CP rank holds 1/cp_size of the sequence; report the global length.
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size

            if not context_parallel:
                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.view(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    # Self-attention: Q and KV share the same valid-token layout.
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer)
            else:
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            # Longest sequence = largest gap between consecutive cumulative offsets.
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel:
            assert window_size in (
                (-1, -1),
                (-1, 0),
            ), "Sliding window attention is not supported with context parallelism."
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    # NOTE(review): cu_seqlens are passed a second time, presumably in the
                    # padded-offset slots (no separate padded offsets here); confirm.
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    softmax_scale=self.softmax_scale,
                    # sbhd inputs were already transposed to batch-first above.
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic,
                )
        else:
            # Function-scope import (presumably to avoid an import cycle).
            from .cpu_offload import CPUOffloadEnabled

            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            with self.attention_dropout_ctx():
                # Only pass keyword arguments the installed flash-attn version accepts.
                fa_optional_forward_kwargs = {}
                if _flash_attn_2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                if _flash_attn_2_5_7_plus:
                    fa_optional_forward_kwargs["block_table"] = None
                output = flash_attn_forward_func(
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    softmax_scale=self.softmax_scale,
                    causal="causal" in attn_mask_type,
                    **fa_optional_forward_kwargs,
                )

        # Scatter packed (valid-token-only) output back into the padded [b * s] layout.
        if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            output = (
                output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1).contiguous()
            )
        elif qkv_format == "bshd":
            # (bs)hd -> bs(hd)
            output = output.view(batch_size, max_seqlen_q // cp_size, -1).contiguous()
        elif qkv_format == "thd":
            # thd -> t(hd)
            output = output.view(output.shape[0], -1).contiguous()

        return output
3505

3506

3507
def _combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a particular dimension

    Builds a view (no copy) over ``tensors[0]``'s storage with a new dimension of size
    ``len(tensors)`` inserted at ``dim``. Only ``tensors[0]``'s storage, offset, shape
    and strides are read. The result is therefore only correct when the inputs are
    equally spaced slices of one buffer (e.g. q/k/v split out of a packed qkv tensor),
    so stepping ``stride[dim - 1] // len(tensors)`` elements moves from one input to
    the next. ``Float8Tensor`` inputs are handled by viewing their raw ``_data`` and
    re-wrapping the result with ``Float8Tensor.make_like``.

    NOTE(review): ``dim == 0`` reads ``stride[-1]``; callers presumably never pass 0.
    """

    num_tensors = len(tensors)
    new_shape = list(tensors[0].shape)
    new_shape.insert(dim, num_tensors)
    new_stride = list(tensors[0].stride())
    # Exact integer division (strides are ints); avoids float rounding of int(a / b).
    new_stride.insert(dim, new_stride[dim - 1] // num_tensors)

    is_fp8 = isinstance(tensors[0], Float8Tensor)
    # For FP8, alias the underlying uint8 buffer rather than the wrapper subclass.
    base = tensors[0]._data if is_fp8 else tensors[0]
    combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=base.dtype)
    combined_tensor.set_(
        base.untyped_storage(),
        base.storage_offset(),
        new_shape,
        new_stride,
    )
    if is_fp8:
        combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor)

    return combined_tensor
3534

3535

3536
3537
3538
3539
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input

    ``forward`` runs the fused-attention kernel on a single packed ``qkv`` tensor,
    either in FP8 (when ``fp8`` is set) or in the input precision. ``backward`` returns
    one gradient per ``forward`` input: ``dqkv`` for ``qkv`` and, for bias types other
    than ``no_bias``/``alibi``, ``dbias`` for ``attn_bias``; all others are None.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen,
        cu_seqlens,
        cu_seqlens_padded,
        qkv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if fp8:
            logger.debug("Running forward in FP8")
            if fp8_meta["recipe"].fp8_mha:
                assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
                # Input is already FP8: dequantize with the tensor's own scale-inverse.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
            # FP8 always runs on the FP8 backend, whatever the caller selected.
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.split("_"))
            assert (
                qkv_group == 1
            ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}."
            if fp8_meta["recipe"].fp8_mha:
                qkv_fp8 = qkv._data
            else:
                # Quantize via a 2D view (flattening the trailing three dims), then
                # restore the original shape.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv_fp8 = cast_to_fp8(
                    qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(qkv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            if fp8_meta["recipe"].fp8_mha:
                # FP8 MHA keeps the output in FP8, wrapped as a Float8Tensor.
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=qkv.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # Backward will not run in FP8 (NVTE_FP8_DPA_BWD=0), so save
                # high-precision copies of qkv and out for it.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv = cast_from_fp8(
                    qkv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV,
                    fp8_dtype_forward,
                    TE_DType[qkv.dtype],
                ).view(qkv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            # Snapshot forward scales so backward sees the values used here even if
            # fp8_meta is updated in between.
            fp8_tensors = (
                qkv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            logger.debug("Running forward in %s", qkv.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_padded,
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            fp8_tensors = (None, None, None, None)
            out_save = out_ret

        # Backward runs in FP8 only if forward did and NVTE_FP8_DPA_BWD is not "0".
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
        ctx.save_for_backward(
            *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        # A non-FP8 backward always uses the F16 arbitrary-seqlen backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            # Keep the wrapper for its scale-inverse / dtype; work on the raw FP8 data.
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            qkv,
            out,
            cu_seqlens,
            cu_seqlens_padded,
            qkv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dqkv = torch.empty_like(qkv)

            def maybe_contiguous(x):
                return x.contiguous() if x.stride(-1) != 1 else x

            # NOTE(review): indexing dim 1 assumes the "3" (q/k/v) axis is dim 1 of qkv
            # on this path; confirm against the layouts that enable use_FAv2_bwd.
            d_out, q, k, v, out = [
                maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out)
            ]
            # Writes dq/dk/dv directly into the slices of dqkv.
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dqkv[:, 0],
                dqkv[:, 1],
                dqkv[:, 2],
                cu_seqlens,
                cu_seqlens,
                ctx.max_seqlen,
                ctx.max_seqlen,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            dqkv = dqkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dqkv = Float8Tensor(
                            data=dqkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        dqkv_c_fp8 = dqkv_fp8.view(
                            -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                        )
                        dqkv = cast_from_fp8(
                            dqkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dqkv_fp8.shape)
                else:
                    logger.debug("Running backward in %s", qkv.dtype)
                    # FP8 MHA forward with a non-FP8 backward: dequantize the raw
                    # uint8 gradient through its Float8Tensor wrapper.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(qkv.dtype)
                    dqkv, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # One gradient per forward() input (20 inputs): qkv is input 4 and attn_bias is
        # input 6. A bias gradient only exists for trainable bias types (no_bias/alibi
        # have none); `rest` is only read in that case.
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        return (None, None, None, None, dqkv, None, dbias) + (None,) * 13
3907

3908

3909
3910
3911
3912
class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input.

    ``q`` is passed separately while ``k`` and ``v`` arrive packed together in
    ``kv``. The forward pass calls ``fused_attn_fwd_kvpacked``; the backward
    pass calls ``fused_attn_bwd_kvpacked``, or ``flash_attn_cuda_bwd`` when
    ``use_FAv2_bwd`` was requested in forward.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        q,
        kv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        window_size,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
        deterministic,
    ):
        """Run the fused attention forward pass and stash state for backward.

        When ``fp8`` is set, inputs are (or are cast to) FP8 and the FP8
        backend is used. With ``fp8_meta["recipe"].fp8_mha`` the inputs must
        already be ``Float8Tensor`` and the output is returned as a
        ``Float8Tensor``. Otherwise the FP8 output is cast back to
        ``qkv_dtype``.

        Returns:
            The attention output (``out_ret``).
        """
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        # NVTE_FP8_DPA_BWD=0 keeps an FP8 forward but makes backward take the
        # non-FP8 path (see ``ctx.fp8`` below); read it once per call.
        fp8_dpa_bwd = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

        if fp8:
            logger.debug("Running forward in FP8")
            fp8_mha = fp8_meta["recipe"].fp8_mha
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_mha:
                assert isinstance(q, Float8Tensor) and isinstance(
                    kv, Float8Tensor
                ), "q/kv must be Float8Tensors for FP8 MHA."
                # Inputs are already quantized: reuse their scale and raw FP8 data.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, kv_fp8 = q._data, kv._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                assert qkv_group == 2, (
                    "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
                    f"but found {qkv_layout}."
                )
                q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view(
                    q.shape
                )
                # Flatten the trailing three kv dims into one before casting.
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv_fp8 = cast_to_fp8(
                    kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(kv.shape)

            out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                kv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )

            if fp8_mha:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret

            if fp8_mha and not fp8_dpa_bwd:
                # Backward will take the non-FP8 path, so keep dequantized
                # copies of q, kv and out for it.
                q = cast_from_fp8(
                    q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype]
                ).view(q.shape)
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv = cast_from_fp8(
                    kv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV,
                    fp8_dtype_forward,
                    TE_DType[kv.dtype],
                ).view(kv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)

            # Snapshot the forward scales so backward sees the values used here.
            fp8_tensors = (
                q_fp8,
                kv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            logger.debug("Running forward in %s", q.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                kv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                window_size,
                rng_gen,
            )
            out_save = out_ret
            # Placeholders keep the saved-tensor layout identical to the FP8 case.
            fp8_tensors = (None, None, None, None, None)

        ctx.fp8 = fp8 and fp8_dpa_bwd
        # FP8 backward consumes the FP8 copies; otherwise save high-precision q/kv/out.
        qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.window_size = window_size
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.deterministic = deterministic

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for ``q``, ``kv`` and, when a bias was used, ``attn_bias``.

        Returns one entry per ``forward`` input (``ctx`` excluded); every
        non-differentiable input receives ``None``.
        """
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            q,
            kv,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            kv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()

        # Only the fused kernels below produce a bias gradient; the
        # FlashAttention-2 backward takes no bias input.
        dbias = None
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)

            def maybe_contiguous(x):
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [
                maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)
            ]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dkv[:, 0],
                dkv[:, 1],
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Trim any head-dim padding back to d_out's head dim.
            dq = dq[..., : d_out.shape[-1]]
            dkv = dkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        kv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                        dkv = Float8Tensor(
                            data=dkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        dq = cast_from_fp8(
                            dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dq_fp8.shape)
                        dkv_c_fp8 = dkv_fp8.view(
                            -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                        )
                        dkv = cast_from_fp8(
                            dkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dkv_fp8.shape)
                else:
                    logger.debug("Running backward in %s", q.dtype)
                    # FP8 MHA with NVTE_FP8_DPA_BWD=0: dequantize the incoming gradient.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dkv, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        kv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
            # no_bias / alibi: no bias tensor to differentiate; otherwise the
            # kernel's first extra output is dbias.
            if ctx.attn_bias_type not in ["no_bias", "alibi"]:
                dbias = rest[0]

        # One gradient per forward() input, in forward()'s argument order.
        return (
            None,  # is_training
            None,  # max_seqlen_q
            None,  # max_seqlen_kv
            None,  # cu_seqlens_q
            None,  # cu_seqlens_kv
            None,  # cu_seqlens_q_padded
            None,  # cu_seqlens_kv_padded
            dq,  # q
            dkv,  # kv
            None,  # qkv_dtype
            dbias,  # attn_bias
            None,  # attn_scale
            None,  # dropout_p
            None,  # fast_zero_fill
            None,  # qkv_layout
            None,  # attn_bias_type
            None,  # attn_mask_type
            None,  # window_size
            None,  # rng_gen
            None,  # fused_attention_backend
            None,  # use_FAv2_bwd
            None,  # fp8
            None,  # fp8_meta
            None,  # deterministic
        )

4344

4345
4346
4347
4348
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
4349
4350
4351
4352
4353
4354
4355
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
4356
4357
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
4358
4359
4360
4361
4362
4363
4364
4365
4366
4367
4368
        q,
        k,
        v,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
4369
        window_size,
4370
4371
4372
4373
4374
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
4375
        deterministic,
4376
    ):
4377
        logger = logging.getLogger("FusedAttnFunc")
4378
        if fp8:
4379
            logger.debug("Running forward in FP8")
4380
4381
4382
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
4383
4384
                assert (
                    isinstance(q, Float8Tensor)
4385
                    and isinstance(k, Float8Tensor)
4386
4387
                    and isinstance(v, Float8Tensor)
                ), "q/k/v must be Float8Tensors for FP8 MHA."
4388
4389
4390
4391
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
4392
                qkv_group = len(qkv_layout.split("_"))
4393
                if qkv_group == 1:
4394
4395
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
4396
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
4397
4398
4399
4400
                    qkv_fp8 = cast_to_fp8(
                        qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(qkv.shape)
                    q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1])
4401
4402
                    q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
                if qkv_group == 2:
4403
4404
4405
4406
4407
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    dim = qkv_layout.split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
4408
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
4409
4410
4411
4412
                    kv_fp8 = cast_to_fp8(
                        kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(kv.shape)
                    k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1])
4413
4414
                    k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
                if qkv_group == 3:
4415
4416
4417
4418
4419
4420
4421
4422
4423
                    q_fp8 = cast_to_fp8(
                        q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(q.shape)
                    k_fp8 = cast_to_fp8(
                        k, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(k.shape)
                    v_fp8 = cast_to_fp8(
                        v, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(v.shape)
4424
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
4425
4426
4427
4428
4429
4430
4431
4432
4433
4434
4435
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
4436
4437
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
4438
4439
4440
4441
4442
4443
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
4444
4445
4446
4447
4448
4449
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
4450
                window_size,
4451
4452
                rng_gen,
            )
4453
            if fp8_meta["recipe"].fp8_mha:
4454
4455
                out_ret = Float8Tensor(
                    data=out_fp8,
4456
4457
4458
4459
4460
4461
4462
4463
4464
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
4465
4466
4467
4468
4469
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
4470
4471
4472
4473
            out_save = out_ret

            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # 1: qkv packed, 2: kv packed, 3: qkv separate
4474
                qkv_group = len(qkv_layout.split("_"))
4475
                if qkv_group == 1:
4476
4477
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
4478
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
4479
4480
                    qkv_no_fp8 = cast_from_fp8(
                        qkv_c._data,
4481
                        fp8_meta["scaling_fwd"],
4482
4483
4484
4485
4486
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[qkv.dtype],
                    ).view(qkv.shape)
                    q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1])
4487
4488
                    q, k, v = [x.squeeze(dim) for x in [q, k, v]]
                if qkv_group == 2:
4489
4490
                    q = cast_from_fp8(
                        q._data,
4491
                        fp8_meta["scaling_fwd"],
4492
4493
4494
4495
4496
4497
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[q.dtype],
                    ).view(q.shape)
                    dim = qkv_layout.split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
4498
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
4499
4500
                    kv_no_fp8 = cast_from_fp8(
                        kv_c._data,
4501
                        fp8_meta["scaling_fwd"],
4502
4503
4504
4505
4506
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[kv.dtype],
                    ).view(kv.shape)
                    k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1])
4507
4508
                    k, v = [x.squeeze(dim) for x in [k, v]]
                if qkv_group == 3:
4509
4510
                    q = cast_from_fp8(
                        q._data,
4511
                        fp8_meta["scaling_fwd"],
4512
4513
4514
4515
4516
4517
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[q.dtype],
                    ).view(q.shape)
                    k = cast_from_fp8(
                        k._data,
4518
                        fp8_meta["scaling_fwd"],
4519
4520
4521
4522
4523
4524
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[k.dtype],
                    ).view(k.shape)
                    v = cast_from_fp8(
                        v._data,
4525
                        fp8_meta["scaling_fwd"],
4526
4527
4528
4529
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[v.dtype],
                    ).view(v.shape)
4530
4531
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
4532
4533
4534
4535
4536
4537
4538
4539
4540
4541
4542
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)

            fp8_tensors = (
                q_fp8,
                k_fp8,
                v_fp8,
                out_fp8,
4543
                fp8_meta["scaling_fwd"].scale.clone(),
4544
4545
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
4546
        else:
4547
            logger.debug("Running forward in %s", q.dtype)
4548
            out_ret, aux_ctx_tensors = fused_attn_fwd(
4549
4550
4551
4552
4553
4554
4555
4556
4557
4558
4559
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
4560
4561
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
4562
4563
4564
4565
4566
4567
4568
4569
4570
4571
4572
4573
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
4574
                window_size,
4575
4576
                rng_gen,
            )
4577
4578
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None, None)
4579

4580
        from .cpu_offload import CPUOffloadEnabled
4581

4582
        if CPUOffloadEnabled:
4583
            tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
4584
            qkv_layout = "sbhd_sbhd_sbhd"
4585
4586
4587
4588
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

4589
4590
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
4591
4592
4593
4594
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
4595
4596
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
4597
4598
4599
            *fp8_tensors,
            *aux_ctx_tensors,
        )
4600
        ctx.fp8_meta = fp8_meta
4601
4602
4603
4604
4605
4606
4607
4608
4609
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
4610
        ctx.window_size = window_size
4611
        ctx.fused_attention_backend = (
4612
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
4613
        )
4614
        ctx.use_FAv2_bwd = use_FAv2_bwd
4615
        ctx.deterministic = deterministic
4616

4617
        return out_ret
4618
4619
4620

    @staticmethod
    def backward(ctx, d_out):
        """Backward pass of fused attention.

        Computes gradients w.r.t. q, k and v using one of three paths:

        * FlashAttention-2's backward kernel (``flash_attn_cuda_bwd``) when
          ``ctx.use_FAv2_bwd`` is set;
        * cuDNN fused attention in FP8 (``fused_attn_bwd`` with FP8 tensors and
          scaling metadata) when ``ctx.fp8`` is set;
        * cuDNN fused attention in the input precision otherwise.

        When the FP8 MHA recipe is active, ``d_out`` must be a ``Float8Tensor``. In
        that case the returned dq/dk/dv are ``Float8Tensor`` as well, provided the FP8
        path runs.

        Returns:
            A tuple with one gradient slot per positional input of ``forward``.
            Slots 7-9 hold dq, dk and dv. Slot 11 holds the attention-bias gradient
            when the bias type is neither ``no_bias`` nor ``alibi``. All other slots
            are ``None``.
        """
        logger = logging.getLogger("FusedAttnFunc")

        # NOTE(review): FusedAttention.forward defaults fp8_meta to None, so the recipe
        # lookup is guarded here. Confirm that FusedAttnFunc.forward also tolerates None.
        fp8_mha = ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha
        d_out_f8tensor = None
        if fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            # Keep the wrapper for its scale_inv/dtype and work on the raw FP8 payload.
            d_out_f8tensor = d_out
            d_out = d_out._data
        d_out = d_out.contiguous()

        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        # Star-unpacking yields a list, so the first aux tensor can be replaced in place.
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()

        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors

            def _last_dim_contiguous(x):
                # Only copy when the innermost dimension is strided.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [_last_dim_contiguous(x) for x in (d_out, q, k, v, out)]
            # Allocate the gradient buffers from the normalized inputs, so they get the
            # same unit last-dim stride that the inputs were just given.
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            out_last_dim = d_out.shape[-1]
            dq = dq[..., :out_last_dim]
            dk = dk[..., :out_last_dim]
            dv = dv[..., :out_last_dim]
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    scaling_bwd = ctx.fp8_meta["scaling_bwd"]

                    if fp8_mha:
                        # d_out is already FP8; record its inverse scale in the dO slot.
                        d_out_fp8 = d_out
                        scaling_bwd.scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            scaling_bwd,
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)

                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        scaling_bwd.scale_inv[META_DO],  # d_scale_do
                        scaling_bwd.scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        scaling_bwd.scale[META_DP],  # q_scale_dp
                        scaling_bwd.scale[META_DQKV],  # q_scale_dqkv
                        scaling_bwd.amax_history[0][META_DP],  # amax_dp
                        scaling_bwd.amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

                    if fp8_mha:
                        # Return the gradients still in FP8, tied to the dQKV scaling slot.
                        dq, dk, dv = [
                            Float8Tensor(
                                data=grad_fp8,
                                fp8_meta=ctx.fp8_meta,
                                fp8_meta_forward=False,
                                fp8_meta_index=META_DQKV,
                                fp8_dtype=fp8_dtype_backward,
                                dtype=d_out_f8tensor.dtype,
                            )
                            for grad_fp8 in (dq_fp8, dk_fp8, dv_fp8)
                        ]
                    else:

                        def _dequantize(t_fp8, num_inner_dims):
                            """Cast an FP8 dQKV tensor to ctx.qkv_dtype, keeping its shape.

                            The tensor is viewed as 2D, flattening its trailing
                            ``num_inner_dims`` dimensions, for the cast.
                            """
                            return cast_from_fp8(
                                t_fp8.view(-1, math.prod(t_fp8.shape[-num_inner_dims:])),
                                scaling_bwd,
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(t_fp8.shape)

                        # Number of "_"-separated groups in the layout string.
                        qkv_group = len(ctx.qkv_layout.split("_"))
                        if qkv_group == 1:
                            # Packed q/k/v: combine along the "3" dim, cast once, split back.
                            dim = ctx.qkv_layout.find("3")
                            dqkv = _dequantize(_combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim), 3)
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1])
                            dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
                        elif qkv_group == 2:
                            # Separate q, packed k/v along the "2" dim of the second group.
                            dq = _dequantize(dq_fp8, 2)
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv = _dequantize(_combine_tensors([dk_fp8, dv_fp8], dim), 3)
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1])
                            dk, dv = [x.squeeze(dim) for x in [dk, dv]]
                        elif qkv_group == 3:
                            dq = _dequantize(dq_fp8, 2)
                            dk = _dequantize(dk_fp8, 2)
                            dv = _dequantize(dv_fp8, 2)
                else:
                    logger.debug("Running backward in %s", q.dtype)
                    if d_out.dtype == torch.uint8:
                        # FP8 payload outside the FP8 path: dequantize to the input dtype.
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        # The ten FP8 scale/amax arguments are unused outside FP8.
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # One gradient slot per positional forward() input: q/k/v are inputs 7-9 and
        # the attention bias is input 11 (see FusedAttnFunc.apply in FusedAttention).
        # NOTE(review): 27 slots are returned while FusedAttention passes 25 inputs;
        # autograd accepts trailing extra None entries.
        grads = [None] * 27
        grads[7:10] = [dq, dk, dv]
        # For bias types other than no_bias/alibi, fused_attn_bwd's first extra
        # output is the bias gradient.
        if ctx.attn_bias_type not in ["no_bias", "alibi"]:
            grads[11] = rest[0]
        return tuple(grads)


class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    FP8 attention (``fp8=True`` in :meth:`forward`) requires the ``NVTE_FP8``
    cuDNN sub-backend.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        self.logger = logging.getLogger("FusedAttention")
        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        # Context manager entered around every attention call in forward().
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # FlashAttention-2 backward is opt-in via NVTE_FUSED_ATTN_USE_FAv2_BWD=1 and is
        # only enabled on devices with compute capability (9, 0).
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        def remove_extra_states_check(module, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading TransformerEngine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in TransformerEngine 2.0.
            """
            # Iterate over snapshots. Removing from a list while iterating it skips
            # the element that follows each removed one.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """fused attention fprop

        Validates the inputs. For ``sbhd``/``bshd`` formats, derives ``max_seqlen`` from
        the tensor shapes and ``cu_seqlens`` from ``attention_mask`` (padding masks) or
        from the full sequence length. Then dispatches to ``attn_forward_func_with_cp``
        when context parallelism is active, or to ``FusedAttnFunc`` otherwise.

        Returns:
            The attention output with its last two dimensions flattened into one.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        supported_dtypes = (torch.float16, torch.bfloat16, torch.uint8)
        assert all(
            x.dtype in supported_dtypes for x in (query_layer, key_layer, value_layer)
        ), "FusedAttention only supports FP16, BF16 and FP8 (uint8) data types."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group)
        context_parallel = cp_size > 1

        # Letters of the first layout group, e.g. "sbh3d" -> "sbhd", "bshd_bs2hd" -> "bshd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format in ["sbhd", "bshd"]:
            seq_dim = 0 if qkv_format == "sbhd" else 1
            batch_size = query_layer.shape[1 - seq_dim]
            # Scale the local sequence length by the CP world size
            # (presumably each CP rank holds 1/cp_size of the sequence).
            max_seqlen_q = query_layer.shape[seq_dim] * cp_size
            max_seqlen_kv = key_layer.shape[seq_dim] * cp_size
            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        # Self-attention: one mask shared by q and kv.
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        # Cross-attention: attention_mask is a (q_mask, kv_mask) pair.
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                # No padding: every sequence spans the full length.
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        if qkv_format == "thd":
            assert (
                max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
            ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"

        # Without explicit padded offsets, fall back to the unpadded cu_seqlens.
        if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
            cu_seqlens_q_padded = cu_seqlens_q
            cu_seqlens_kv_padded = cu_seqlens_kv

        qkv_dtype = TE_DType[query_layer.dtype]

        # FlashAttention-2 backward is only used with no bias on the arbitrary-seqlen backend.
        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if context_parallel:
            assert (
                fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    use_fused_attention=True,
                )
        else:
            with self.attention_dropout_ctx():
                if fp8:
                    assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                        f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                        " is required for FP8 attention!"
                    )
                    assert (
                        fp8_meta is not None
                    ), "FP8 metadata fp8_meta is required for FP8 attention!"
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_dtype,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                    self.deterministic,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)


class DotProductAttention(TransformerEngineBaseModule):
5179
5180
5181
5182
5183
5184
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

5185
        Argument :attr:`attention_mask` in the `forward` call is only used when
5186
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
5187
5188
5189

    .. warning::

5190
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
5191
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
5192
5193
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
5194
5195
5196
5197
5198

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
5199
5200
5201
5202
    k_channels : int
                number of channels per attention head in key.
    v_channels : Optional[int] = None
                number of channels per attention head in value.
5203
5204
5205
5206
5207
5208
5209
5210
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
5211
5212
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
5213
    attn_mask_type: str, default = `causal`
5214
                   type of attention mask passed into softmax operation, options are "`no_mask`",
5215
5216
5217
5218
5219
5220
5221
5222
5223
5224
5225
5226
5227
5228
5229
5230
5231
5232
5233
5234
5235
5236
5237
5238
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
                   "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
                   "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
                   are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
                   `forward` method. It is useful for cases involving compilation/tracing, e.g.
                   ONNX export, and the forward arg is useful for dynamically changing mask types,
                   e.g. a different mask for training and inference.
                   1. For "`no_mask`", no attention mask is applied.
                   2. For "`causal`", "`causal_bottom_right`", or the causal mask in
                   "`padding_causal`" and "`padding_causal_bottom_right`", TransformerEngine
                   calculates and applies an upper triangular mask to the softmax input.
                   No user input is needed. Causal masks without the "`bottom_right`" appendix align
                   the diagonal line to the top left corner of the softmax matrix. With
                   "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
                   often used in inference/KV caching.
                   3. For "`padding`", or the padding mask in "`padding_causal`" and
                   "`padding_causal_bottom_right`", users need to provide the locations of padded
                   tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
                   [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
                   in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
                   cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
                   [batch_size, 1, 1, max_seqlen_kv]).
                   4. For "`arbitrary`", users need to provide a mask that is broadcastable to
                   the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
5239
5240
5241
5242
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
5243
5244
5245
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
5246
                be overridden by :attr:`window_size` in `forward` as well.
5247
5248
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
5249
5250
5251
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
5252
5253
5254
5255
5256
5257
5258
5259
5260
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
               `h` the number of heads, `d` head size, and `t` the total number of sequences
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
5261
               For that, please use `get_qkv_layout` to obtain the layout information.
5262
5263
5264
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
                `1.0 / math.sqrt(k_channels)`.
5265
5266
5267
5268
5269
5270
5271
5272
5273

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
5274
5275
5276
5277
5278
5279
5280
5281
5282
    cp_group : ProcessGroup, default = `None`
              context parallel process group.
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
5283
5284
5285
5286
5287
    """

    def __init__(
        self,
        num_attention_heads: int,
        k_channels: int,
        v_channels: Optional[int] = None,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        softmax_scale: Optional[float] = None,
    ) -> None:
        """Configure the module and build its three attention backends.

        Parameters are described in the class docstring. In summary, this:

        * normalizes ``attn_mask_type`` (``","`` becomes ``"_"`` and
          ``"causal_padding"`` becomes ``"padding_causal"``) and resolves
          ``window_size`` via ``check_set_window_size``;
        * sets the tensor-parallel size from ``tp_group`` when one is given,
          otherwise from ``tp_size``;
        * instantiates ``FlashAttention``, ``FusedAttention`` and
          ``UnfusedDotProductAttention``, which share the softmax scale
          (default ``1 / sqrt(k_channels)``), the dropout probability and the
          dropout context.

        Side effects: for cuDNN versions in ``[8.9.5, 9.0.0)`` this may write
        ``NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`` and
        ``CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT`` into ``os.environ``. It also
        registers a ``load_state_dict`` post-hook for backward compatibility
        with older checkpoints.

        Raises:
            AssertionError: if ``num_attention_heads`` is not divisible by the
                number of GQA groups, or ``attention_type`` is not in
                ``AttnTypes``.
        """
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.qkv_format = qkv_format
        # Accept comma-separated spellings, e.g. "padding,causal" -> "padding_causal".
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

        self.hidden_size_per_attention_head = k_channels
        self.v_channels = k_channels if v_channels is None else v_channels

        # Without an explicit GQA group count, every head is its own group (plain MHA).
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"

        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(k_channels)

        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )

        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = get_cudnn_version()
        if (8, 9, 5) <= cudnn_version < (9, 0, 0):
            if self.deterministic:
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        self.flash_attention = FlashAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        self.fused_attention = FusedAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale, **attn_kwargs, layer_number=layer_number
        )

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
            when loading older TransformerEngine checkpoints. Will phase out
            this hook in TransformerEngine 2.0.
            """
            # Rebuild the list in place (the hook must mutate the object it was
            # given) instead of calling ``remove`` while iterating over it:
            # removing during iteration skips the element that follows each
            # removed key, so consecutive matching keys were left behind.
            incompatible_keys.missing_keys[:] = [
                key
                for key in incompatible_keys.missing_keys
                if "core_attention._extra_state" not in key
            ]

        self.register_load_state_dict_post_hook(remove_extra_states_check)

5418
5419
5420
5421
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Run ``attention_func`` with activation checkpointing.

        The call is routed through ``checkpoint`` without distributing saved
        activations. It uses this module's RNG state tracker getter and
        tensor-parallel group, and forwards all extra positional and keyword
        arguments to ``attention_func``.

        Returns:
            Whatever ``checkpoint`` returns for the wrapped attention call.
        """

        def _recompute(*args, **kwargs):
            # Thin closure handed to ``checkpoint`` so the attention call can be re-run.
            return attention_func(*args, **kwargs)

        return checkpoint(
            _recompute,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

5440
5441
5442
5443
5444
5445
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Store the context-parallel configuration on this module; call this
        before running the forward pass.

        Parameters
        ----------
        cp_group : ProcessGroup
                  process group used for context parallelism.
        cp_global_ranks : List[int]
                         global ranks of the GPUs that belong to `cp_group`.
        cp_stream : torch.cuda.Stream
                   CUDA stream used for context-parallel execution.
        """
        # Values are stored as given; no validation happens here.
        self.cp_group, self.cp_global_ranks, self.cp_stream = (
            cp_group,
            cp_global_ranks,
            cp_stream,
        )

5463
    @no_torch_dynamo(recursive=False)
5464
5465
5466
5467
5468
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
5469
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
5470
5471
5472
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
5473
5474
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
5475
5476
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
5477
        attn_mask_type: Optional[str] = None,
5478
        window_size: Optional[Tuple[int, int]] = None,
5479
        checkpoint_core_attention: bool = False,
5480
5481
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
5482
        alibi_slopes: Optional[torch.Tensor] = None,
5483
        fast_zero_fill: bool = True,
5484
        inference_params: Optional[InferenceParams] = None,
5485
        is_first_microbatch: Optional[bool] = None,
5486
5487
5488
5489
5490
5491
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

5492
5493
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
5494

5495
5496
        .. note::

5497
5498
5499
5500
5501
5502
5503
5504
5505
5506
5507
5508
5509
5510
5511
5512
5513
5514
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://github.com/Dao-AILab/flash-attention>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
            and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
5515
5516
5517
5518
5519
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
            optimizations in FusedAttention. When unset, TransformerEngine determines the code path
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
5520

5521
5522
5523
5524
5525
5526
5527
5528
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
5529
5530
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
5531
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
5532
5533
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
5534
5535
5536
5537
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
5538
5539
5540
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
5541
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
5542
5543
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
5544
5545
5546
5547
5548
5549
5550
5551
5552
5553
5554
5555
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for
                   `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_q_padded = cu_seqlens_q`.
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_kv_padded = cu_seqlens_kv`.
5556
5557
5558
5559
5560
5561
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
5562
5563
5564
5565
5566
5567
5568
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
                       'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
                       'arbitrary'}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
                       are equivalent. By default, causal masks are aligned to the top left corner
                       of the softmax matrix. When "`bottom_right`" is specified in the mask type,
                       causal masks are aligned to the bottom right corner.
5569
        window_size: Optional[Tuple[int, int]], default = `None`
5570
                    Sliding window size for local attention.
5571
5572
5573
5574
5575
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
5576
        core_attention_bias_type: str, default = `no_bias`
5577
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
5578
        core_attention_bias: Optional[torch.Tensor], default = `None`
5579
5580
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
5581
5582
5583
5584
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
5585
        fast_zero_fill: bool, default = `True`
5586
                    Whether to use the fast path to set output tensors to 0 or not.
5587
5588
5589
5590
5591
5592
5593
5594
5595
5596
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
5597
5598
5599
5600
5601
5602
5603
5604
5605
5606
5607
5608
5609
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
5610
        """
5611
5612
5613
5614
5615
5616
5617
5618
5619
5620
5621
        with self.prepare_forward(
            query_layer,
            is_first_microbatch,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:

            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
5622
5623
5624
5625
                        self.logger.WARNING(
                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                            """fp8_meta["recipe"].fp8_mha=True"""
                        )
5626
5627
5628
5629
5630
5631
5632
5633
5634
5635
5636

            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
5637

5638
5639
5640
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
5641
5642
5643
            assert (
                query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
5644
5645
5646
            assert (
                key_layer.shape[:-1] == value_layer.shape[:-1]
            ), "Keys and values must have the same batch size, sequence length and number of heads!"
5647

5648
5649
5650
5651
5652
5653
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
5654
            assert (
5655
5656
5657
5658
5659
5660
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
5661

5662
5663
5664
5665
            if window_size is None:
                window_size = self.window_size
            window_size = check_set_window_size(attn_mask_type, window_size)

5666
5667
5668
5669
5670
5671
5672
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
5673

5674
5675
            if qkv_format is None:
                qkv_format = self.qkv_format
5676

5677
5678
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
5679

5680
5681
5682
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
5683

5684
5685
5686
5687
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]
5688

5689
5690
5691
                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)
5692

5693
5694
5695
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)
5696

5697
5698
5699
5700
5701
5702
5703
5704
5705
                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
5706

5707
5708
5709
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
5710

5711
5712
                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()
5713
5714

            assert (
5715
5716
5717
5718
5719
5720
5721
5722
5723
5724
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
5725
                assert all(
5726
5727
5728
5729
5730
5731
5732
5733
5734
5735
5736
5737
5738
5739
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
                if max_seqlen_q is None:
5740
5741
5742
5743
                    if cu_seqlens_q_padded is not None:
                        seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
                    else:
                        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
5744
5745
                    max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
                if max_seqlen_kv is None:
5746
5747
5748
5749
                    if cu_seqlens_kv_padded is not None:
                        seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
                    else:
                        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
5750
                    max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
5751
                batch_size = len(cu_seqlens_q) - 1
5752

5753
5754
5755
            cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group)
            context_parallel = cp_size > 1

5756
            if qkv_format in ["sbhd", "bshd"]:
5757
                assert all(
5758
5759
5760
5761
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
5762
                    batch_size = query_layer.shape[1]
5763
5764
                if qkv_format == "bshd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
5765
                    batch_size = query_layer.shape[0]
5766
5767
                max_seqlen_q *= cp_size
                max_seqlen_kv *= cp_size
5768
5769
5770
5771
5772
5773
5774
5775
5776
5777
5778
5779
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                        the sequence dimention in 'query_layer'!"""
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                        the sequence dimention in 'key_layer' and 'value_layer'!"""
5780
5781
5782
5783
5784
5785
5786
5787
5788
5789
5790
5791
5792
5793
5794
5795
5796
5797
5798
5799
5800
5801
                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        if max_seqlen_q == max_seqlen_kv:
                            cu_seqlens_q = get_cu_seqlens(attention_mask)
                            cu_seqlens_kv = cu_seqlens_q
                        else:
                            cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                            cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
                    else:
                        cu_seqlens_q = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
                        cu_seqlens_kv = _get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
5802

5803
5804
5805
5806
5807
            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
5808
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout(
5809
5810
5811
                    query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                )
            else:
5812
                qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout(
5813
5814
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )
5815

5816
5817
5818
5819
5820
5821
5822
5823
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
5824
            bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
5825
5826
5827
5828
5829
5830
5831
5832
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
5833
                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
5834
5835
5836
5837
5838
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

5839
5840
            core_attention_bias_shape = None
            if core_attention_bias is not None:
5841
                if (
5842
5843
                    core_attention_bias.shape[0] == batch_size
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
5844
                ):
5845
5846
5847
5848
5849
5850
5851
5852
5853
5854
5855
5856
5857
5858
5859
5860
5861
5862
5863
5864
5865
5866
5867
5868
                    core_attention_bias_shape = "bhss"
                elif (
                    core_attention_bias.shape[0] == 1
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
                ):
                    core_attention_bias_shape = "1hss"
                elif (
                    core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
                ):
                    core_attention_bias_shape = "b1ss"
                elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
                    core_attention_bias_shape = "11ss"
                else:
                    assert (
                        False
                    ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

            pad_between_seqs = (
                cu_seqlens_q_padded is not None
                and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
            ) or (
                cu_seqlens_kv_padded is not None
                and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
            )
5869

5870
            attention_params = AttentionParams(
5871
5872
5873
5874
5875
5876
5877
5878
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
                num_heads=query_layer.shape[-2],
                num_gqa_groups=key_layer.shape[-2],
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
5879
5880
                head_dim_qk=query_layer.shape[-1],
                head_dim_v=value_layer.shape[-1],
5881
5882
5883
5884
5885
5886
5887
5888
5889
5890
5891
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias_shape=core_attention_bias_shape,
                core_attention_bias_requires_grad=(
                    core_attention_bias.requires_grad if core_attention_bias is not None else False
                ),
                pad_between_seqs=pad_between_seqs,
                attention_dropout=self.attention_dropout,
                context_parallel=context_parallel,
5892
5893
                deterministic=self.deterministic,
                is_training=self.training,
5894
5895
5896
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
            )
5897
5898
5899
5900
5901
5902
5903
5904
5905
5906
5907
5908
5909
5910
5911
5912
5913
5914
5915
5916
5917
            global _attention_backends
            if (
                _attention_backends["attention_params"] is None
                or attention_params != _attention_backends["attention_params"]
            ):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
                (
                    use_flash_attention,
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
                    _,
                ) = get_attention_backend(attention_params)
                if use_flash_attention:
                    self.logger.info("Running with FlashAttention backend")
                elif use_fused_attention:
                    self.logger.info(
                        "Running with FusedAttention backend (sub-backend %s)",
                        int(fused_attention_backend),
5918
                    )
5919
5920
5921
5922
5923
5924
5925
                elif use_unfused_attention:
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]
5926

5927
5928
5929
5930
5931
5932
5933
5934
5935
5936
5937
5938
5939
5940
5941
5942
5943
5944
5945
5946
5947
5948
5949
5950
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
5951
                )
5952

5953
            if use_fused_attention:
5954
5955
                fu_core_attention_bias_type = core_attention_bias_type
                fu_core_attention_bias = core_attention_bias
5956
5957
5958
                if core_attention_bias_type == "alibi" and (
                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
                ):
5959
5960
5961
5962
5963
5964
5965
                    fu_core_attention_bias_type = "post_scale_bias"
                    _, fu_core_attention_bias = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                        bias_dtype=query_layer.dtype,
5966
                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
5967
                    )
5968
5969
5970
5971
5972
5973
5974
5975
5976
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
5977
5978
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
5979
5980
5981
5982
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
5983
                        window_size=window_size,
5984
5985
5986
5987
5988
5989
5990
5991
5992
5993
5994
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
                    )
                return self.fused_attention(
5995
5996
5997
5998
5999
6000
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
6001
6002
                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
6003
6004
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
6005
6006
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
6007
                    window_size=window_size,
6008
                    fused_attention_backend=fused_attention_backend,
6009
6010
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
6011
6012
6013
6014
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
6015
6016
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
6017
                )
6018

6019
            from .cpu_offload import CPUOffloadEnabled
6020

6021
6022
6023
6024
6025
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
6026

6027
            if use_unfused_attention:
6028
6029
6030
6031
6032
6033
                if window_size is not None and (
                    window_size[0] != -1 or window_size[1] not in [-1, 0]
                ):
                    attn_mask_type, attention_mask = get_swa_mask(
                        window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
                    )
6034
6035
6036
6037
6038
6039
6040
6041
6042
6043
6044
6045
6046
6047
6048
6049
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
6050
6051
6052
                    query_layer,
                    key_layer,
                    value_layer,
6053
6054
6055
6056
6057
6058
6059
6060
6061
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )
6062

6063
            raise Exception("No dot product attention support for the provided inputs!")
6064
6065


6066
6067
6068
6069
6070
6071
6072
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

6073
6074
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes '"padding"' or `"arbitrary"`.
6075

6076
6077
6078
6079
6080
6081
6082
6083
6084
6085
6086
6087
6088
6089
6090
6091
6092
6093
6094
6095
6096
6097
6098
6099
6100
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
6101
6102
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                   'padding_causal_bottom_right','arbitrary'},
6103
                   default = `causal`
6104
6105
6106
6107
6108
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
6109
6110
6111
6112
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
6113
6114
6115
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
6116
                be overridden by :attr:`window_size` in `forward` as well.
6117
6118
6119
6120
6121
6122
6123
6124
6125
6126
6127
6128
6129
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
6130
6131
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
6132
6133
6134
6135
6136
6137
6138
6139
6140
6141
6142
6143
6144
6145
6146
6147
6148
6149
6150
6151
6152
6153
6154
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
6155
6156
6157
6158
6159
6160
6161
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
6162
            For that, please use `get_qkv_layout` to gain the layout information.
6163
6164
6165
6166
6167
6168
6169
6170
6171
6172
6173
6174
6175
6176
6177
6178
6179
6180
6181
6182
6183
6184
6185
6186
6187
6188
6189
6190
6191
6192
6193
6194
6195
6196
6197
6198
6199
6200
6201
6202

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 are used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
6203
6204
6205
6206
6207
6208
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
6209
6210
6211
6212
6213
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
6214
        layer_number: Optional[int] = None,
6215
        attn_mask_type: str = "causal",
6216
        window_size: Optional[Tuple[int, int]] = None,
6217
6218
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
6219
        num_gqa_groups: Optional[int] = None,
6220
6221
6222
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
6223
        params_dtype: Optional[torch.dtype] = None,
6224
        return_bias: bool = False,
6225
6226
6227
6228
6229
6230
6231
6232
6233
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
Jaemin Choi's avatar
Jaemin Choi committed
6234
        ub_overlap_rs_dgrad: bool = False,
6235
6236
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
6237
        bias: bool = True,
6238
        normalization: str = "LayerNorm",
6239
        device: Union[torch.device, str] = "cuda",
6240
        qkv_format: str = "sbhd",
6241
6242
    ) -> None:
        super().__init__()
6243

6244
        self.qkv_format = qkv_format
6245
        self.attn_mask_type = attn_mask_type
6246
        self.window_size = check_set_window_size(attn_mask_type, window_size)
6247
        self.layer_number = layer_number
6248
6249
6250
6251
6252
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
6253
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
6254
        self.num_attention_heads = num_attention_heads
6255
6256
6257
6258
6259
6260
6261
6262
        self.return_bias = return_bias

        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()
6263
6264
6265
6266
6267

        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

6268
6269
6270
        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"
6271
6272
6273
6274
6275
6276

        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
6277
6278
6279
6280
6281
6282
6283
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
6284
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
6285
6286
6287
6288

        self.hidden_size_per_attention_head = kv_channels
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
6289
6290
6291
6292
6293
6294
6295

        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
6296
            "params_dtype": self.params_dtype,
6297
            "device": device,
6298
6299
6300
6301
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

cyanguwa's avatar
cyanguwa committed
6302
        if self.attention_type == "self":
6303
6304
            parameters_split = None
            if not fuse_qkv_params:
6305
6306
6307
6308
6309
6310
6311
                parameters_split = collections.OrderedDict(
                    [
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    ]
                )
6312
6313
6314
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
6315
                    self.hidden_size_q + 2 * self.hidden_size_kv,
6316
6317
6318
6319
6320
6321
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
cyanguwa's avatar
cyanguwa committed
6322
                    parameters_split=parameters_split,
6323
6324
6325
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
Jaemin Choi's avatar
Jaemin Choi committed
6326
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
6327
                    ub_overlap_ag=ub_overlap_ag,
6328
                    normalization=normalization,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
6329
                    ub_name="qkv",
6330
6331
6332
6333
6334
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
6335
                    self.hidden_size_q + 2 * self.hidden_size_kv,
6336
6337
6338
6339
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
cyanguwa's avatar
cyanguwa committed
6340
                    parameters_split=parameters_split,
6341
6342
                    **common_gemm_kwargs,
                )
cyanguwa's avatar
cyanguwa committed
6343
        elif self.attention_type == "cross":
6344
6345
6346
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
6347
                    self.hidden_size_q,
6348
6349
6350
6351
6352
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
6353
                    parameters_split=("query",) if not fuse_qkv_params else None,
6354
6355
6356
6357
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
Jaemin Choi's avatar
Jaemin Choi committed
6358
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
6359
                    ub_overlap_ag=ub_overlap_ag,
6360
                    normalization=normalization,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
6361
                    ub_name="qkv",
6362
6363
6364
6365
6366
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
6367
                    self.hidden_size_q,
6368
6369
6370
6371
6372
6373
6374
6375
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
6376
                2 * self.hidden_size_kv,
6377
6378
6379
6380
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
6381
                parameters_split=("key", "value") if not fuse_qkv_params else None,
6382
6383
6384
6385
6386
6387
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
6388
            self.hidden_size_per_attention_head,
6389
6390
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
6391
            qkv_format=self.qkv_format,
6392
6393
6394
6395
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
6396
            layer_number=self.layer_number,
6397
            attention_type=self.attention_type,
6398
6399
6400
6401
        )

        # Linear
        self.proj = Linear(
6402
            self.hidden_size_q,
6403
6404
6405
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
6406
            return_bias=return_bias,
6407
            parallel_mode="row" if set_parallel_mode else None,
6408
6409
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
6410
            ub_name="proj",
6411
6412
6413
6414
            **common_gemm_kwargs,
        )

    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Allocate one inference cache buffer (used for either keys or values).

        Parameters
        ----------
        inference_max_sequence_len : int
                                    number of sequence positions the buffer can hold.
        batch_size : int
                    number of sequences the buffer can hold.
        dtype : torch.dtype
               element type of the buffer.

        Returns
        -------
        torch.Tensor
            Uninitialized tensor (``torch.empty``) on the current CUDA device with shape
            ``[inference_max_sequence_len, batch_size, num_gqa_groups_per_partition,
            hidden_size_per_attention_head]``.
        """
        # One slot per (position, sequence, local KV head group, head channel).
        cache_shape = (
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
        )
        target_device = torch.cuda.current_device()
        return torch.empty(cache_shape, dtype=dtype, device=target_device)

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Record the tensor parallel process group to use for this module.
        Call this before running the forward pass.

        Only ``self.tp_group`` is updated here. Submodules are not touched, and
        ``self.tp_size`` is not recomputed.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        # NOTE(review): the projection submodules received ``tp_group`` through
        # ``common_gemm_kwargs`` at construction time. Changing the group here does
        # not reach them unless a parent module also propagates the call.
        self.tp_group = tp_group

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Forward the context parallel settings to every descendant module that
        exposes a ``set_context_parallel_group`` method. Call this before running
        the forward pass.

        All descendants returned by ``self.modules()`` are visited, not only direct
        children. If a descendant's own implementation also recurses, nested
        modules may receive the call more than once.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        # ``modules()`` yields this module first. Discard it so that this method
        # does not call itself forever.
        descendants = iter(self.modules())
        next(descendants, None)

        for submodule in descendants:
            if not hasattr(submodule, "set_context_parallel_group"):
                continue
            submodule.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                       'padding_causal_bottom_right','arbitrary'},
                       default = `None`
                       type of attention mask passed into softmax operation. When `None`,
                       the mask type configured at construction is used. By default,
                       causal masks are aligned to the top left corner of the softmax matrix.
                       When "`bottom_right`" is specified in the mask type, causal masks are
                       aligned to the bottom right corner.
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention. When `None`, the window size
                    configured at construction is used. In either case the value is passed
                    through `check_set_window_size` together with `attn_mask_type`.
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`. Keys and values are projected from this tensor
             when the module performs cross-attention.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        inference_params: Optional[InferenceParams], default = `None`
                         Inference state. When given (and `layer_number` is set), key and
                         value cache buffers for this layer are allocated in
                         `inference_params.key_value_memory_dict` on first use. Rotary
                         embeddings are sliced starting at
                         `inference_params.sequence_len_offset`. The object is also
                         forwarded to core attention.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.

        Returns
        -------
        torch.Tensor or Tuple[torch.Tensor, ...]
             The output of the output projection. If `return_bias` was set at
             construction, the projection bias is appended. If the module applies
             the input layernorm and `return_layernorm_output` is set, the layernorm
             output is appended last. When there is only one output, it is returned
             as a bare tensor instead of a 1-tuple.

        Raises
        ------
        ValueError
             If rotary embeddings are used during inference with a `qkv_format`
             other than `sbhd` or `bshd`.
        """
        # hidden_states: [sq, b, h]
        # (shape comments below are written for the `sbhd` layout)

        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
        if window_size is None:
            window_size = self.window_size
        window_size = check_set_window_size(attn_mask_type, window_size)

        if "padding" in attn_mask_type and attention_mask is not None:
            # Self-attention passes one mask tensor; cross-attention passes a
            # (q_mask, kv_mask) tuple. Check each tensor once instead of looping
            # over the batch dimension of a single tensor.
            masks = (
                attention_mask if isinstance(attention_mask, (tuple, list)) else (attention_mask,)
            )
            for mask in masks:
                assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"

        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

        # =================================================
        # Pre-allocate memory for key-values for inference
        # =================================================

        if inference_params and self.layer_number is not None:
            # Allocate this layer's cache once. Later calls find it already present
            # in the dict; core attention receives `inference_params` below
            # (presumably it reads the cache from there).
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
                )
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )

        # ======================
        # Query, Key, and Value
        # ======================

        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    is_first_module_in_mha=True,  # specific to FP8 MHA
                )

            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
            if self.qkv_weight_interleaved:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along third last dimension
                split_dim = -3

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, ng, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
                )
            else:
                # ONNX export path uses the plain torch.split
                query_layer, key_layer, value_layer = torch.split(
                    mixed_x_layer,
                    (num_queries_per_key_value, 1, 1),
                    dim=split_dim,
                )

            # query: -> [sq, b, np, hn]
            # key, value: -> [sq, b, ng, hn]
            query_layer, key_layer, value_layer = (
                x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                for x in (query_layer, key_layer, value_layer)
            )

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
                is_first_module_in_mha=True,  # specific to FP8 MHA
            )

            if self.qkv_weight_interleaved:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, ng, 2 * hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, 2 * ng, hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    2 * self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_kv_layer,
                    split_dim,
                    mixed_kv_layer.shape[split_dim] // 2,
                )
            else:
                # ONNX export path uses the plain torch.split
                key_layer, value_layer = torch.split(
                    mixed_kv_layer,
                    mixed_kv_layer.shape[split_dim] // 2,
                    dim=split_dim,
                )
            key_layer, value_layer = (
                x.reshape(
                    x.size(0),
                    x.size(1),
                    -1,
                    self.hidden_size_per_attention_head,
                )
                for x in (key_layer, value_layer)
            )

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    is_first_module_in_mha=True,  # specific to FP8 MHA
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================

        if rotary_pos_emb is not None:
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
            # duplicate the pos_emb for self attention
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = (rotary_pos_emb,) * 2

            q_pos_emb, k_pos_emb = rotary_pos_emb

            # During inference, keep only the embedding rows for the positions
            # processed in this step, starting at the cached sequence offset.
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)
                else:
                    # Previously this fell through and raised UnboundLocalError.
                    raise ValueError(
                        f"qkv_format={self.qkv_format!r} is not supported for rotary "
                        "position embeddings during inference; expected 'sbhd' or 'bshd'."
                    )

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)

        # ===========================
        # Core attention computation
        # ===========================

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            qkv_format=self.qkv_format,
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
            window_size=window_size,
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            alibi_slopes=alibi_slopes,
            fast_zero_fill=fast_zero_fill,
            inference_params=inference_params,
        )

        # ===================
        # Output. [sq, b, h]
        # ===================

        projection_output = self.proj(
            context_layer,
            is_first_microbatch=is_first_microbatch,
        )

        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
        if self.input_layernorm and self.return_layernorm_output:
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]